diff --git a/plugin/output/elasticsearch/README.md b/plugin/output/elasticsearch/README.md index bcf034035..03df043be 100755 --- a/plugin/output/elasticsearch/README.md +++ b/plugin/output/elasticsearch/README.md @@ -170,5 +170,18 @@ Process ES response and report errors, if any.
+**`ban_period`** *`cfg.Duration`* *`default=10s`* + +Period for which an endpoint is banned after a failed request (send error or a 429/502/503/504 response). +If set to 0, or if only one endpoint is configured, the circuit breaker is disabled. + +
+ +**`reconnect_interval`** *`cfg.Duration`* *`default=5s`* + +Interval at which banned endpoints are checked and restored once their ban period has expired. Must be greater than 0. + +
+
*Generated using [__insane-doc__](https://github.com/vitkovskii/insane-doc)* \ No newline at end of file diff --git a/plugin/output/elasticsearch/elasticsearch.go b/plugin/output/elasticsearch/elasticsearch.go index 4bdaea55e..35861d6c8 100644 --- a/plugin/output/elasticsearch/elasticsearch.go +++ b/plugin/output/elasticsearch/elasticsearch.go @@ -49,8 +49,9 @@ type Plugin struct { mu *sync.Mutex // plugin metrics - sendErrorMetric *metric.CounterVec - indexingErrorsMetric *metric.Counter + sendErrorMetric *metric.CounterVec + indexingErrorsMetric *metric.Counter + bannedEndpointsMetric *metric.Gauge router *pipeline.Router } @@ -203,6 +204,19 @@ type Config struct { // > // > Process ES response and report errors, if any. ProcessResponse bool `json:"process_response" default:"true"` // * + + // > @3@4@5@6 + // > + // > Period for which addresses will be banned in case of unavailability. + // > If set to 0, circuit breaker is disabled. + BanPeriod cfg.Duration `json:"ban_period" default:"10s" parse:"duration"` // * + BanPeriod_ time.Duration + + // > @3@4@5@6 + // > + // > Interval for checking banned endpoints availability. 
+ ReconnectInterval cfg.Duration `json:"reconnect_interval" default:"5s" parse:"duration"` // * + ReconnectInterval_ time.Duration } type KeepAliveConfig struct { @@ -243,8 +257,17 @@ func (p *Plugin) Start(config pipeline.AnyConfig, params *pipeline.OutputPluginP if len(p.config.IndexValues) == 0 { p.config.IndexValues = append(p.config.IndexValues, "@time") } + if p.config.ReconnectInterval_ < 1 { + p.logger.Fatal("'reconnect_interval' can't be <1") + } + if p.config.BanPeriod_ < 0 { + p.logger.Fatal("'ban_period' cant't be <0") + } - p.prepareClient() + ctx, cancel := context.WithCancel(context.Background()) + p.cancel = cancel + + p.prepareClient(ctx) p.maintenance(nil) @@ -295,9 +318,6 @@ func (p *Plugin) Start(config pipeline.AnyConfig, params *pipeline.OutputPluginP onError, ) - ctx, cancel := context.WithCancel(context.Background()) - p.cancel = cancel - p.batcher.Start(ctx) } @@ -313,17 +333,26 @@ func (p *Plugin) Out(event *pipeline.Event) { func (p *Plugin) registerMetrics(ctl *metric.Ctl) { p.sendErrorMetric = ctl.RegisterCounterVec("output_elasticsearch_send_error_total", "Total elasticsearch send errors", "status_code") p.indexingErrorsMetric = ctl.RegisterCounter("output_elasticsearch_index_error_total", "Number of elasticsearch indexing errors") + p.bannedEndpointsMetric = ctl.RegisterGauge( + "output_elasticsearch_banned_endpoints_count", + "Current number of endpoints banned by circuit breaker", + ) + p.bannedEndpointsMetric.Set(0) } -func (p *Plugin) prepareClient() { +func (p *Plugin) prepareClient(ctx context.Context) { config := &xhttp.ClientConfig{ Endpoints: prepareEndpoints(p.config.Endpoints, p.config.IngestPipeline), ConnectionTimeout: p.config.ConnectionTimeout_ * 2, AuthHeader: p.getAuthHeader(), + BanPeriod: p.config.BanPeriod_, + ReconnectInterval: p.config.ReconnectInterval_, KeepAlive: &xhttp.ClientKeepAliveConfig{ MaxConnDuration: p.config.KeepAlive.MaxConnDuration_, MaxIdleConnDuration: p.config.KeepAlive.MaxIdleConnDuration_, }, 
+ Logger: p.logger, + BannedEndpointsMetric: p.bannedEndpointsMetric, } if p.config.CACert != "" { config.TLS = &xhttp.ClientTLSConfig{ @@ -335,7 +364,7 @@ func (p *Plugin) prepareClient() { } var err error - p.client, err = xhttp.NewClient(config) + p.client, err = xhttp.NewClient(ctx, config) if err != nil { p.logger.Fatal("can't create http client", zap.Error(err)) } diff --git a/plugin/output/http/README.md b/plugin/output/http/README.md index 3c3810aaa..7c6a51b66 100755 --- a/plugin/output/http/README.md +++ b/plugin/output/http/README.md @@ -144,4 +144,17 @@ After a non-retryable write error, fall with a non-zero exit code or not
+**`ban_period`** *`cfg.Duration`* *`default=10s`* + +Period for which an endpoint is banned after a failed request (send error or a 429/502/503/504 response). +If set to 0, or if only one endpoint is configured, the circuit breaker is disabled. + +
+ +**`reconnect_interval`** *`cfg.Duration`* *`default=5s`* + +Interval at which banned endpoints are checked and restored once their ban period has expired. Must be greater than 0. + +
+
*Generated using [__insane-doc__](https://github.com/vitkovskii/insane-doc)* \ No newline at end of file diff --git a/plugin/output/http/http.go b/plugin/output/http/http.go index 03edf9c51..ce4b98a09 100644 --- a/plugin/output/http/http.go +++ b/plugin/output/http/http.go @@ -44,7 +44,8 @@ type Plugin struct { mu *sync.Mutex // plugin metrics - sendErrorMetric *metric.CounterVec + sendErrorMetric *metric.CounterVec + bannedEndpointsMetric *metric.Gauge router *pipeline.Router } @@ -176,6 +177,19 @@ type Config struct { // > // > After a non-retryable write error, fall with a non-zero exit code or not Strict bool `json:"strict" default:"false"` // * + + // > @3@4@5@6 + // > + // > Period for which addresses will be banned in case of unavailability. + // > If set to 0, circuit breaker is disabled. + BanPeriod cfg.Duration `json:"ban_period" default:"10s" parse:"duration"` // * + BanPeriod_ time.Duration + + // > @3@4@5@6 + // > + // > Interval for checking banned endpoints availability. + ReconnectInterval cfg.Duration `json:"reconnect_interval" default:"5s" parse:"duration"` // * + ReconnectInterval_ time.Duration } type KeepAliveConfig struct { @@ -212,13 +226,23 @@ func (p *Plugin) Start(config pipeline.AnyConfig, params *pipeline.OutputPluginP p.registerMetrics(params.MetricCtl) p.mu = &sync.Mutex{} + if p.config.ReconnectInterval_ < 1 { + p.logger.Fatal("'reconnect_interval' can't be <1") + } + if p.config.BanPeriod_ < 0 { + p.logger.Fatal("'ban_period' cant't be <0") + } + var err error p.encoder, err = NewEncoder(p.config.Encoding) if err != nil { p.logger.Fatal("can't create encoder", zap.Error(err)) } - p.prepareClient() + ctx, cancel := context.WithCancel(context.Background()) + p.cancel = cancel + + p.prepareClient(ctx) p.logger.Info("starting batcher", zap.Duration("timeout", p.config.BatchFlushTimeout_)) @@ -267,9 +291,6 @@ func (p *Plugin) Start(config pipeline.AnyConfig, params *pipeline.OutputPluginP onError, ) - ctx, cancel := 
context.WithCancel(context.Background()) - p.cancel = cancel - p.batcher.Start(ctx) } @@ -284,17 +305,26 @@ func (p *Plugin) Out(event *pipeline.Event) { func (p *Plugin) registerMetrics(ctl *metric.Ctl) { p.sendErrorMetric = ctl.RegisterCounterVec("output_http_send_error_total", "Total HTTP send errors", "status_code") + p.bannedEndpointsMetric = ctl.RegisterGauge( + "output_http_banned_endpoints_count", + "Current number of endpoints banned by circuit breaker", + ) + p.bannedEndpointsMetric.Set(0) } -func (p *Plugin) prepareClient() { +func (p *Plugin) prepareClient(ctx context.Context) { config := &xhttp.ClientConfig{ Endpoints: p.prepareEndpoints(), ConnectionTimeout: p.config.ConnectionTimeout_ * 2, AuthHeader: p.getAuthHeader(), + BanPeriod: p.config.BanPeriod_, + ReconnectInterval: p.config.ReconnectInterval_, KeepAlive: &xhttp.ClientKeepAliveConfig{ MaxConnDuration: p.config.KeepAlive.MaxConnDuration_, MaxIdleConnDuration: p.config.KeepAlive.MaxIdleConnDuration_, }, + Logger: p.logger, + BannedEndpointsMetric: p.bannedEndpointsMetric, } if p.config.CACert != "" { config.TLS = &xhttp.ClientTLSConfig{ @@ -306,7 +336,7 @@ func (p *Plugin) prepareClient() { } var err error - p.client, err = xhttp.NewClient(config) + p.client, err = xhttp.NewClient(ctx, config) if err != nil { p.logger.Fatal("can't create http client", zap.Error(err)) } diff --git a/plugin/output/loki/README.md b/plugin/output/loki/README.md index 1ce47cee1..26b9fec02 100644 --- a/plugin/output/loki/README.md +++ b/plugin/output/loki/README.md @@ -149,5 +149,18 @@ Multiplier for exponential increase of retention between retries
+**`ban_period`** *`cfg.Duration`* *`default=10s`* + +Period for which an endpoint is banned after a failed request (send error or a 429/502/503/504 response). +If set to 0, or if only one endpoint is configured, the circuit breaker is disabled. + +
+ +**`reconnect_interval`** *`cfg.Duration`* *`default=5s`* + +Interval at which banned endpoints are checked and restored once their ban period has expired. Must be greater than 0. + +
+
*Generated using [__insane-doc__](https://github.com/vitkovskii/insane-doc)* \ No newline at end of file diff --git a/plugin/output/loki/loki.go b/plugin/output/loki/loki.go index 160054675..ce3653e87 100644 --- a/plugin/output/loki/loki.go +++ b/plugin/output/loki/loki.go @@ -178,6 +178,19 @@ type Config struct { // > // > Multiplier for exponential increase of retention between retries RetentionExponentMultiplier int `json:"retention_exponentially_multiplier" default:"2"` // * + + // > @3@4@5@6 + // > + // > Period for which addresses will be banned in case of unavailability. + // > If set to 0, circuit breaker is disabled. + BanPeriod cfg.Duration `json:"ban_period" default:"10s" parse:"duration"` // * + BanPeriod_ time.Duration + + // > @3@4@5@6 + // > + // > Interval for checking banned endpoints availability. + ReconnectInterval cfg.Duration `json:"reconnect_interval" default:"5s" parse:"duration"` // * + ReconnectInterval_ time.Duration } type AuthStrategy byte @@ -232,7 +245,8 @@ type Plugin struct { batcher *pipeline.RetriableBatcher // plugin metrics - sendErrorMetric *metric.CounterVec + sendErrorMetric *metric.CounterVec + bannedEndpointsMetric *metric.Gauge labels map[string]string @@ -259,7 +273,18 @@ func (p *Plugin) Start(config pipeline.AnyConfig, params *pipeline.OutputPluginP p.labels = p.parseLabels() - p.prepareClient() + if p.config.ReconnectInterval_ < 1 { + p.logger.Fatal("'reconnect_interval' can't be <1") + } + if p.config.BanPeriod_ < 0 { + p.logger.Fatal("'ban_period' cant't be <0") + } + + ctx, cancel := context.WithCancel(context.Background()) + p.ctx = ctx + p.cancel = cancel + + p.prepareClient(ctx) batcherOpts := &pipeline.BatcherOptions{ PipelineName: params.PipelineName, @@ -303,10 +328,6 @@ func (p *Plugin) Start(config pipeline.AnyConfig, params *pipeline.OutputPluginP onError, ) - ctx, cancel := context.WithCancel(context.Background()) - p.ctx = ctx - p.cancel = cancel - p.batcher.Start(ctx) } @@ -428,22 +449,31 @@ func (p 
*Plugin) send(root *insaneJSON.Root) (int, error) { func (p *Plugin) registerMetrics(ctl *metric.Ctl) { p.sendErrorMetric = ctl.RegisterCounterVec("output_loki_send_error_total", "Total Loki send errors", "status_code") + p.bannedEndpointsMetric = ctl.RegisterGauge( + "output_loki_banned_endpoints_count", + "Current number of endpoints banned by circuit breaker", + ) + p.bannedEndpointsMetric.Set(0) } -func (p *Plugin) prepareClient() { +func (p *Plugin) prepareClient(ctx context.Context) { config := &xhttp.ClientConfig{ Endpoints: []string{fmt.Sprintf("%s/loki/api/v1/push", p.config.Address)}, ConnectionTimeout: p.config.ConnectionTimeout_ * 2, AuthHeader: p.getAuthHeader(), CustomHeaders: p.getCustomHeaders(), + BanPeriod: p.config.BanPeriod_, + ReconnectInterval: p.config.ReconnectInterval_, KeepAlive: &xhttp.ClientKeepAliveConfig{ MaxConnDuration: p.config.KeepAlive.MaxConnDuration_, MaxIdleConnDuration: p.config.KeepAlive.MaxIdleConnDuration_, }, + Logger: p.logger, + BannedEndpointsMetric: p.bannedEndpointsMetric, } var err error - p.client, err = xhttp.NewClient(config) + p.client, err = xhttp.NewClient(ctx, config) if err != nil { p.logger.Fatal("can't create http client", zap.Error(err)) } diff --git a/plugin/output/splunk/README.md b/plugin/output/splunk/README.md index b9df5013f..e6b8f2913 100755 --- a/plugin/output/splunk/README.md +++ b/plugin/output/splunk/README.md @@ -153,5 +153,18 @@ or the "event" key with any of its subkeys.
+**`ban_period`** *`cfg.Duration`* *`default=10s`* + +Period for which an endpoint is banned after a failed request (send error or a 429/502/503/504 response). +If set to 0, or if only one endpoint is configured, the circuit breaker is disabled. + +
+ +**`reconnect_interval`** *`cfg.Duration`* *`default=5s`* + +Interval at which banned endpoints are checked and restored once their ban period has expired. Must be greater than 0. + +
+
*Generated using [__insane-doc__](https://github.com/vitkovskii/insane-doc)* \ No newline at end of file diff --git a/plugin/output/splunk/splunk.go b/plugin/output/splunk/splunk.go index 090ae5048..6b3a5796f 100644 --- a/plugin/output/splunk/splunk.go +++ b/plugin/output/splunk/splunk.go @@ -96,7 +96,8 @@ type Plugin struct { cancel context.CancelFunc // plugin metrics - sendErrorMetric *metric.CounterVec + sendErrorMetric *metric.CounterVec + bannedEndpointsMetric *metric.Gauge router *pipeline.Router } @@ -202,6 +203,19 @@ type Config struct { // > Supports copying whole original event, but does not allow to copy directly to the output root // > or the "event" key with any of its subkeys. CopyFields []CopyField `json:"copy_fields" slice:"true"` // * + + // > @3@4@5@6 + // > + // > Period for which addresses will be banned in case of unavailability. + // > If set to 0, circuit breaker is disabled. + BanPeriod cfg.Duration `json:"ban_period" default:"10s" parse:"duration"` // * + BanPeriod_ time.Duration + + // > @3@4@5@6 + // > + // > Interval for checking banned endpoints availability. 
+ ReconnectInterval cfg.Duration `json:"reconnect_interval" default:"5s" parse:"duration"` // * + ReconnectInterval_ time.Duration } type KeepAliveConfig struct { @@ -235,7 +249,18 @@ func (p *Plugin) Start(config pipeline.AnyConfig, params *pipeline.OutputPluginP p.avgEventSize = params.PipelineSettings.AvgEventSize p.config = config.(*Config) p.registerMetrics(params.MetricCtl) - p.prepareClient() + + if p.config.ReconnectInterval_ < 1 { + p.logger.Fatal("'reconnect_interval' can't be <1") + } + if p.config.BanPeriod_ < 0 { + p.logger.Fatal("'ban_period' cant't be <0") + } + + ctx, cancel := context.WithCancel(context.Background()) + p.cancel = cancel + + p.prepareClient(ctx) for _, cf := range p.config.CopyFields { if cf.To == "" { @@ -296,9 +321,6 @@ func (p *Plugin) Start(config pipeline.AnyConfig, params *pipeline.OutputPluginP onError, ) - ctx, cancel := context.WithCancel(context.Background()) - p.cancel = cancel - p.batcher.Start(ctx) } @@ -317,13 +339,20 @@ func (p *Plugin) registerMetrics(ctl *metric.Ctl) { "Total splunk send errors", "status_code", ) + p.bannedEndpointsMetric = ctl.RegisterGauge( + "output_splunk_banned_endpoints_count", + "Current number of endpoints banned by circuit breaker", + ) + p.bannedEndpointsMetric.Set(0) } -func (p *Plugin) prepareClient() { +func (p *Plugin) prepareClient(ctx context.Context) { config := &xhttp.ClientConfig{ Endpoints: []string{p.config.Endpoint}, ConnectionTimeout: p.config.RequestTimeout_, AuthHeader: "Splunk " + p.config.Token, + BanPeriod: p.config.BanPeriod_, + ReconnectInterval: p.config.ReconnectInterval_, KeepAlive: &xhttp.ClientKeepAliveConfig{ MaxConnDuration: p.config.KeepAlive.MaxConnDuration_, MaxIdleConnDuration: p.config.KeepAlive.MaxIdleConnDuration_, @@ -332,13 +361,15 @@ func (p *Plugin) prepareClient() { // TODO: make this configuration option and false by default InsecureSkipVerify: true, }, + Logger: p.logger.Desugar(), + BannedEndpointsMetric: p.bannedEndpointsMetric, } if 
p.config.UseGzip { config.GzipCompressionLevel = p.config.GzipCompressionLevel } var err error - p.client, err = xhttp.NewClient(config) + p.client, err = xhttp.NewClient(ctx, config) if err != nil { p.logger.Fatal("can't create http client", zap.Error(err)) } diff --git a/plugin/output/splunk/splunk_test.go b/plugin/output/splunk/splunk_test.go index 40626e076..fec6de177 100644 --- a/plugin/output/splunk/splunk_test.go +++ b/plugin/output/splunk/splunk_test.go @@ -1,6 +1,7 @@ package splunk import ( + "context" "io" "net/http" "net/http/httptest" @@ -54,7 +55,7 @@ func TestSplunk(t *testing.T) { }, logger: zap.NewExample().Sugar(), } - plugin.prepareClient() + plugin.prepareClient(context.Background()) batch := pipeline.NewPreparedBatch([]*pipeline.Event{ {Root: input}, @@ -185,7 +186,7 @@ func TestCopyFields(t *testing.T) { copyFieldsPaths: tt.copyFields, logger: zap.NewExample().Sugar(), } - plugin.prepareClient() + plugin.prepareClient(context.Background()) batch := pipeline.NewPreparedBatch([]*pipeline.Event{ {Root: input}, diff --git a/xhttp/circuit_breaker.go b/xhttp/circuit_breaker.go new file mode 100644 index 000000000..ebc91f717 --- /dev/null +++ b/xhttp/circuit_breaker.go @@ -0,0 +1,176 @@ +package xhttp + +import ( + "context" + "math/rand" + "sync" + "time" + + "github.com/ozontech/file.d/metric" + "github.com/ozontech/file.d/xtime" + "github.com/valyala/fasthttp" + "go.uber.org/zap" +) + +type endpoint struct { + uri *fasthttp.URI + banUntil time.Time +} + +type circuitBreaker struct { + endpoints []endpoint + activeEndpoints []int + idxByURI map[string]int + banPeriod time.Duration + + logger *zap.Logger + bannedEndpointsMetric *metric.Gauge + + mu sync.RWMutex + nowFn func() time.Time +} + +func newCircuitBreaker( + ctx context.Context, + logger *zap.Logger, + uris []*fasthttp.URI, + banPeriod, reconnectInterval time.Duration, + bannedEndpointsMetric *metric.Gauge, +) *circuitBreaker { + if banPeriod <= 0 || len(uris) == 1 { + logger.Info( + 
"circuit breaker disabled", + zap.Duration("ban_period", banPeriod), + zap.Int("endpoints_count", len(uris)), + ) + + return nil + } + + cb := &circuitBreaker{ + endpoints: make([]endpoint, 0, len(uris)), + activeEndpoints: make([]int, 0, len(uris)), + idxByURI: make(map[string]int, len(uris)), + banPeriod: banPeriod, + logger: logger, + bannedEndpointsMetric: bannedEndpointsMetric, + nowFn: xtime.GetInaccurateTime, + } + + for i, uri := range uris { + cb.endpoints = append(cb.endpoints, endpoint{uri: uri}) + cb.idxByURI[uri.String()] = i + cb.activeEndpoints = append(cb.activeEndpoints, i) + } + + logger.Info( + "circuit breaker enabled", + zap.Duration("ban_period", banPeriod), + zap.Duration("reconnect_interval", reconnectInterval), + zap.Int("endpoints_count", len(uris)), + ) + + go cb.checkBannedEndpoints(ctx, reconnectInterval) + + return cb +} + +func (cb *circuitBreaker) updateBannedEndpointsMetric() { + if cb.bannedEndpointsMetric == nil { + return + } + + cb.bannedEndpointsMetric.Set(float64(len(cb.endpoints) - len(cb.activeEndpoints))) +} + +func (cb *circuitBreaker) getEndpoint() *fasthttp.URI { + cb.mu.RLock() + defer cb.mu.RUnlock() + + if len(cb.activeEndpoints) == 0 { + return nil + } + + idx := rand.Intn(len(cb.activeEndpoints)) + return cb.endpoints[cb.activeEndpoints[idx]].uri +} + +func (cb *circuitBreaker) banEndpoint(uri *fasthttp.URI) { + cb.mu.Lock() + defer cb.mu.Unlock() + + idx := cb.idxByURI[uri.String()] + cb.endpoints[idx].banUntil = cb.nowFn().Add(cb.banPeriod) + + for i, activeIdx := range cb.activeEndpoints { + if activeIdx == idx { + cb.activeEndpoints[i] = cb.activeEndpoints[len(cb.activeEndpoints)-1] + cb.activeEndpoints = cb.activeEndpoints[:len(cb.activeEndpoints)-1] + break + } + } + + cb.logger.Info( + "endpoint banned", + zap.String("endpoint", uri.String()), + zap.Duration("ban_period", cb.banPeriod), + zap.Int("active_endpoints_count", len(cb.activeEndpoints)), + zap.Int("banned_endpoints_count", 
len(cb.endpoints)-len(cb.activeEndpoints)), + ) + + cb.updateBannedEndpointsMetric() +} + +func (cb *circuitBreaker) restoreBannedEndpoints() { + cb.mu.RLock() + if len(cb.endpoints) == len(cb.activeEndpoints) { + cb.mu.RUnlock() + return + } + cb.mu.RUnlock() + + cb.mu.Lock() + defer cb.mu.Unlock() + + hasRestoredEndpoints := false + now := cb.nowFn() + for i := range cb.endpoints { + e := &cb.endpoints[i] + if !e.banUntil.IsZero() && now.After(e.banUntil) { + e.banUntil = time.Time{} + cb.activeEndpoints = append(cb.activeEndpoints, i) + hasRestoredEndpoints = true + + cb.logger.Info( + "endpoint restored", + zap.String("endpoint", e.uri.String()), + zap.Int("active_endpoints_count", len(cb.activeEndpoints)), + zap.Int("banned_endpoints_count", len(cb.endpoints)-len(cb.activeEndpoints)), + ) + } + } + + if hasRestoredEndpoints { + cb.updateBannedEndpointsMetric() + } +} + +func (cb *circuitBreaker) checkBannedEndpoints(ctx context.Context, reconnectInterval time.Duration) { + ticker := time.NewTicker(reconnectInterval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + cb.restoreBannedEndpoints() + } + } +} + +func (cb *circuitBreaker) setNowFn(nowFn func() time.Time) { + cb.mu.Lock() + defer cb.mu.Unlock() + cb.nowFn = nowFn +} diff --git a/xhttp/circuit_breaker_test.go b/xhttp/circuit_breaker_test.go new file mode 100644 index 000000000..2ccafc0cb --- /dev/null +++ b/xhttp/circuit_breaker_test.go @@ -0,0 +1,261 @@ +package xhttp + +import ( + "sync" + "testing" + "testing/synctest" + "time" + + "github.com/stretchr/testify/require" + "go.uber.org/zap" +) + +const ( + opBanEndpoint = iota + 1 + opSleep +) + +var ( + defaultEndpoints = []string{"http://localhost:19200", "http://localhost:19201", "http://localhost:19202"} + defaultWorkerCount = 50 +) + +type cbStep struct { + operation int + idxEp int + duration time.Duration +} + +func TestNewCircuitBreaker(t *testing.T) { + cases := []struct { + name string + banPeriod 
time.Duration + endpoints []string + disabled bool + }{ + { + name: "ban_period_zero", + banPeriod: 0, + endpoints: defaultEndpoints, + disabled: true, + }, + { + name: "single_endpoint", + banPeriod: 2 * time.Second, + endpoints: defaultEndpoints[0:1], + disabled: true, + }, + { + name: "two_and_more_endpoints", + banPeriod: 3 * time.Second, + endpoints: defaultEndpoints, + disabled: false, + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + uris, err := parseEndpoints(tt.endpoints) + require.NoError(t, err) + + ctx := t.Context() + cb := newCircuitBreaker(ctx, zap.NewNop(), uris, tt.banPeriod, 5*time.Minute, nil) + + if tt.disabled { + require.Nil(t, cb, "circuit breaker must be disabled with these parameters") + return + } + + require.NotNil(t, cb) + require.Len(t, cb.activeEndpoints, len(tt.endpoints)) + require.Equal(t, cb.banPeriod, tt.banPeriod) + for i := range cb.endpoints { + require.Zero(t, cb.endpoints[i].banUntil, "endpoint[%d]: banUntil should be zero after creation", i) + } + }) + } +} + +func TestCircuitBreakerScenarios(t *testing.T) { + cases := []struct { + name string + endpoints []string + banPeriod time.Duration + steps []cbStep + wantActive []int + wantBanUntil map[int]time.Duration + }{ + { + name: "ban_removes_endpoint_from_active", + endpoints: defaultEndpoints, + banPeriod: 10 * time.Second, + steps: []cbStep{{operation: opBanEndpoint, idxEp: 1}}, + wantActive: []int{0, 2}, + wantBanUntil: map[int]time.Duration{ + 1: 10 * time.Second, + }, + }, + { + name: "ban_all_endpoint", + endpoints: defaultEndpoints, + banPeriod: 10 * time.Second, + steps: []cbStep{ + {operation: opBanEndpoint, idxEp: 0}, + {operation: opBanEndpoint, idxEp: 1}, + {operation: opBanEndpoint, idxEp: 2}, + }, + wantActive: nil, + }, + { + name: "ban_refreshes_ban_until", + endpoints: defaultEndpoints, + banPeriod: 10 * time.Second, + steps: []cbStep{ + {operation: opBanEndpoint, idxEp: 0}, + {operation: opSleep, duration: 5 * 
time.Second}, + {operation: opBanEndpoint, idxEp: 0}, + }, + wantActive: []int{1, 2}, + wantBanUntil: map[int]time.Duration{ + 0: 15 * time.Second, + }, + }, + { + name: "does_not_restore_before_ban_period", + endpoints: defaultEndpoints, + banPeriod: 40 * time.Second, + steps: []cbStep{ + {operation: opBanEndpoint, idxEp: 0}, + {operation: opSleep, duration: 31 * time.Second}, + }, + wantActive: []int{1, 2}, + }, + { + name: "restores_after_ban_period", + endpoints: defaultEndpoints, + banPeriod: 25 * time.Second, + steps: []cbStep{ + {operation: opBanEndpoint, idxEp: 0}, + {operation: opSleep, duration: 31 * time.Second}, + }, + wantActive: []int{0, 1, 2}, + }, + { + name: "partially_restores_expired_endpoints", + endpoints: defaultEndpoints, + banPeriod: 10 * time.Second, + steps: []cbStep{ + {operation: opBanEndpoint, idxEp: 0}, + {operation: opSleep, duration: 25 * time.Second}, + {operation: opBanEndpoint, idxEp: 1}, + {operation: opSleep, duration: 10 * time.Second}, + }, + wantActive: []int{0, 2}, + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + synctest.Test(t, func(t *testing.T) { + ctx := t.Context() + + uris, err := parseEndpoints(tt.endpoints) + require.NoError(t, err) + + cb := newCircuitBreaker(ctx, zap.NewNop(), uris, tt.banPeriod, 30*time.Second, nil) + require.NotNil(t, cb) + cb.setNowFn(time.Now) + + start := time.Now() + for _, s := range tt.steps { + switch s.operation { + case opBanEndpoint: + cb.banEndpoint(cb.endpoints[s.idxEp].uri) + case opSleep: + time.Sleep(s.duration) + } + } + + cb.mu.RLock() + activeEp := append([]int{}, cb.activeEndpoints...) 
+ banUntil := make([]time.Time, len(cb.endpoints)) + for i := range cb.endpoints { + banUntil[i] = cb.endpoints[i].banUntil + } + cb.mu.RUnlock() + + require.ElementsMatch(t, tt.wantActive, activeEp) + + for idx, dur := range tt.wantBanUntil { + if dur == 0 { + require.Zero(t, banUntil[idx], "endpoint[%d]: banUntil should be zero", idx) + } + require.Equal(t, start.Add(dur), banUntil[idx]) + } + }) + }) + } +} + +func TestCircuitBreakerFullCycle(t *testing.T) { + t.Parallel() + + synctest.Test(t, func(t *testing.T) { + uris, err := parseEndpoints(defaultEndpoints) + require.NoError(t, err) + + ctx := t.Context() + cb := newCircuitBreaker(ctx, zap.NewNop(), uris, 10*time.Second, 3*time.Second, nil) + require.NotNil(t, cb) + cb.setNowFn(time.Now) + + ep0, ep1, ep2 := cb.endpoints[0].uri.String(), cb.endpoints[1].uri.String(), cb.endpoints[2].uri.String() + require.ElementsMatch(t, []string{ep0, ep1, ep2}, pickedURIs(cb, defaultWorkerCount)) + + cb.banEndpoint(cb.endpoints[0].uri) + require.ElementsMatch(t, []string{ep1, ep2}, pickedURIs(cb, defaultWorkerCount)) + time.Sleep(5 * time.Second) + + cb.banEndpoint(cb.endpoints[1].uri) + require.ElementsMatch(t, []string{ep2}, pickedURIs(cb, defaultWorkerCount)) + time.Sleep(8 * time.Second) + + require.ElementsMatch(t, []string{ep0, ep2}, pickedURIs(cb, defaultWorkerCount)) + + time.Sleep(7 * time.Second) + require.ElementsMatch(t, []string{ep0, ep1, ep2}, pickedURIs(cb, defaultWorkerCount)) + }) +} + +func pickedURIs(cb *circuitBreaker, workers int) []string { + var ( + wg sync.WaitGroup + mu sync.Mutex + seen = make(map[string]struct{}) + ) + + for range workers { + wg.Go(func() { + uri := cb.getEndpoint() + if uri == nil { + return + } + + mu.Lock() + seen[uri.String()] = struct{}{} + mu.Unlock() + }) + } + wg.Wait() + + out := make([]string, 0, len(seen)) + for k := range seen { + out = append(out, k) + } + + return out +} diff --git a/xhttp/client.go b/xhttp/client.go index f9ea6e6a9..560d68289 100644 --- 
a/xhttp/client.go +++ b/xhttp/client.go @@ -1,13 +1,16 @@ package xhttp import ( + "context" "fmt" "math/rand" "net/http" "time" + "github.com/ozontech/file.d/metric" "github.com/ozontech/file.d/xtls" "github.com/valyala/fasthttp" + "go.uber.org/zap" ) const gzipContentEncoding = "gzip" @@ -23,24 +26,29 @@ type ClientKeepAliveConfig struct { } type ClientConfig struct { - Endpoints []string - ConnectionTimeout time.Duration - AuthHeader string - CustomHeaders map[string]string - GzipCompressionLevel string - TLS *ClientTLSConfig - KeepAlive *ClientKeepAliveConfig + Endpoints []string + ConnectionTimeout time.Duration + AuthHeader string + CustomHeaders map[string]string + GzipCompressionLevel string + TLS *ClientTLSConfig + KeepAlive *ClientKeepAliveConfig + BanPeriod time.Duration + ReconnectInterval time.Duration + Logger *zap.Logger + BannedEndpointsMetric *metric.Gauge } type Client struct { client *fasthttp.Client endpoints []*fasthttp.URI + cb *circuitBreaker authHeader string customHeaders map[string]string gzipCompressionLevel int } -func NewClient(cfg *ClientConfig) (*Client, error) { +func NewClient(ctx context.Context, cfg *ClientConfig) (*Client, error) { client := &fasthttp.Client{ ReadTimeout: cfg.ConnectionTimeout, WriteTimeout: cfg.ConnectionTimeout, @@ -70,8 +78,16 @@ func NewClient(cfg *ClientConfig) (*Client, error) { } return &Client{ - client: client, - endpoints: endpoints, + client: client, + endpoints: endpoints, + cb: newCircuitBreaker( + ctx, + cfg.Logger, + endpoints, + cfg.BanPeriod, + cfg.ReconnectInterval, + cfg.BannedEndpointsMetric, + ), authHeader: cfg.AuthHeader, customHeaders: cfg.CustomHeaders, gzipCompressionLevel: parseGzipCompressionLevel(cfg.GzipCompressionLevel), @@ -89,16 +105,15 @@ func (c *Client) DoTimeout( resp := fasthttp.AcquireResponse() defer fasthttp.ReleaseResponse(resp) - var endpoint *fasthttp.URI - if len(c.endpoints) == 1 { - endpoint = c.endpoints[0] - } else { - endpoint = 
c.endpoints[rand.Int()%len(c.endpoints)] + endpoint := c.getEndpoint() + if endpoint == nil { + return 0, fmt.Errorf("no available endpoints") } c.prepareRequest(req, endpoint, method, contentType, body) if err := c.client.DoTimeout(req, resp, timeout); err != nil { + c.banEndpoint(endpoint) return 0, fmt.Errorf("can't send request to %s: %w", endpoint.String(), err) } @@ -106,6 +121,9 @@ func (c *Client) DoTimeout( statusCode := resp.Header.StatusCode() if !(http.StatusOK <= statusCode && statusCode <= http.StatusAccepted) { + if shouldBanEndpoint(statusCode) { + c.banEndpoint(endpoint) + } return statusCode, fmt.Errorf("response status from %s isn't OK: status=%d, body=%s", endpoint.String(), statusCode, string(respContent)) } @@ -168,3 +186,32 @@ func parseGzipCompressionLevel(level string) int { return -1 } } + +func (c *Client) getEndpoint() *fasthttp.URI { + if c.cb != nil { + return c.cb.getEndpoint() + } + + if len(c.endpoints) == 0 { + return nil + } + return c.endpoints[rand.Intn(len(c.endpoints))] +} + +func (c *Client) banEndpoint(endpoint *fasthttp.URI) { + if c.cb != nil { + c.cb.banEndpoint(endpoint) + } +} + +func shouldBanEndpoint(statusCode int) bool { + switch statusCode { + case http.StatusBadGateway, + http.StatusServiceUnavailable, + http.StatusGatewayTimeout, + http.StatusTooManyRequests: + return true + default: + return false + } +}