diff --git a/plugin/output/elasticsearch/README.md b/plugin/output/elasticsearch/README.md
index bcf034035..03df043be 100755
--- a/plugin/output/elasticsearch/README.md
+++ b/plugin/output/elasticsearch/README.md
@@ -170,5 +170,18 @@ Process ES response and report errors, if any.
+**`ban_period`** *`cfg.Duration`* *`default=10s`*
+
+Period for which addresses will be banned in case of unavailability.
+If set to 0, circuit breaker is disabled.
+
+
+
+**`reconnect_interval`** *`cfg.Duration`* *`default=5s`*
+
+Interval for checking banned endpoints availability.
+
+
+
*Generated using [__insane-doc__](https://github.com/vitkovskii/insane-doc)*
\ No newline at end of file
diff --git a/plugin/output/elasticsearch/elasticsearch.go b/plugin/output/elasticsearch/elasticsearch.go
index 4bdaea55e..35861d6c8 100644
--- a/plugin/output/elasticsearch/elasticsearch.go
+++ b/plugin/output/elasticsearch/elasticsearch.go
@@ -49,8 +49,9 @@ type Plugin struct {
mu *sync.Mutex
// plugin metrics
- sendErrorMetric *metric.CounterVec
- indexingErrorsMetric *metric.Counter
+ sendErrorMetric *metric.CounterVec
+ indexingErrorsMetric *metric.Counter
+ bannedEndpointsMetric *metric.Gauge
router *pipeline.Router
}
@@ -203,6 +204,19 @@ type Config struct {
// >
// > Process ES response and report errors, if any.
ProcessResponse bool `json:"process_response" default:"true"` // *
+
+ // > @3@4@5@6
+ // >
+ // > Period for which addresses will be banned in case of unavailability.
+ // > If set to 0, circuit breaker is disabled.
+ BanPeriod cfg.Duration `json:"ban_period" default:"10s" parse:"duration"` // *
+ BanPeriod_ time.Duration
+
+ // > @3@4@5@6
+ // >
+ // > Interval for checking banned endpoints availability.
+ ReconnectInterval cfg.Duration `json:"reconnect_interval" default:"5s" parse:"duration"` // *
+ ReconnectInterval_ time.Duration
}
type KeepAliveConfig struct {
@@ -243,8 +257,17 @@ func (p *Plugin) Start(config pipeline.AnyConfig, params *pipeline.OutputPluginP
if len(p.config.IndexValues) == 0 {
p.config.IndexValues = append(p.config.IndexValues, "@time")
}
+ if p.config.ReconnectInterval_ < 1 {
+ p.logger.Fatal("'reconnect_interval' can't be <1")
+ }
+ if p.config.BanPeriod_ < 0 {
+ p.logger.Fatal("'ban_period' cant't be <0")
+ }
- p.prepareClient()
+ ctx, cancel := context.WithCancel(context.Background())
+ p.cancel = cancel
+
+ p.prepareClient(ctx)
p.maintenance(nil)
@@ -295,9 +318,6 @@ func (p *Plugin) Start(config pipeline.AnyConfig, params *pipeline.OutputPluginP
onError,
)
- ctx, cancel := context.WithCancel(context.Background())
- p.cancel = cancel
-
p.batcher.Start(ctx)
}
@@ -313,17 +333,26 @@ func (p *Plugin) Out(event *pipeline.Event) {
func (p *Plugin) registerMetrics(ctl *metric.Ctl) {
p.sendErrorMetric = ctl.RegisterCounterVec("output_elasticsearch_send_error_total", "Total elasticsearch send errors", "status_code")
p.indexingErrorsMetric = ctl.RegisterCounter("output_elasticsearch_index_error_total", "Number of elasticsearch indexing errors")
+ p.bannedEndpointsMetric = ctl.RegisterGauge(
+ "output_elasticsearch_banned_endpoints_count",
+ "Current number of endpoints banned by circuit breaker",
+ )
+ p.bannedEndpointsMetric.Set(0)
}
-func (p *Plugin) prepareClient() {
+func (p *Plugin) prepareClient(ctx context.Context) {
config := &xhttp.ClientConfig{
Endpoints: prepareEndpoints(p.config.Endpoints, p.config.IngestPipeline),
ConnectionTimeout: p.config.ConnectionTimeout_ * 2,
AuthHeader: p.getAuthHeader(),
+ BanPeriod: p.config.BanPeriod_,
+ ReconnectInterval: p.config.ReconnectInterval_,
KeepAlive: &xhttp.ClientKeepAliveConfig{
MaxConnDuration: p.config.KeepAlive.MaxConnDuration_,
MaxIdleConnDuration: p.config.KeepAlive.MaxIdleConnDuration_,
},
+ Logger: p.logger,
+ BannedEndpointsMetric: p.bannedEndpointsMetric,
}
if p.config.CACert != "" {
config.TLS = &xhttp.ClientTLSConfig{
@@ -335,7 +364,7 @@ func (p *Plugin) prepareClient() {
}
var err error
- p.client, err = xhttp.NewClient(config)
+ p.client, err = xhttp.NewClient(ctx, config)
if err != nil {
p.logger.Fatal("can't create http client", zap.Error(err))
}
diff --git a/plugin/output/http/README.md b/plugin/output/http/README.md
index 3c3810aaa..7c6a51b66 100755
--- a/plugin/output/http/README.md
+++ b/plugin/output/http/README.md
@@ -144,4 +144,17 @@ After a non-retryable write error, fall with a non-zero exit code or not
+**`ban_period`** *`cfg.Duration`* *`default=10s`*
+
+Period for which addresses will be banned in case of unavailability.
+If set to 0, circuit breaker is disabled.
+
+
+
+**`reconnect_interval`** *`cfg.Duration`* *`default=5s`*
+
+Interval for checking banned endpoints availability.
+
+
+
*Generated using [__insane-doc__](https://github.com/vitkovskii/insane-doc)*
\ No newline at end of file
diff --git a/plugin/output/http/http.go b/plugin/output/http/http.go
index 03edf9c51..ce4b98a09 100644
--- a/plugin/output/http/http.go
+++ b/plugin/output/http/http.go
@@ -44,7 +44,8 @@ type Plugin struct {
mu *sync.Mutex
// plugin metrics
- sendErrorMetric *metric.CounterVec
+ sendErrorMetric *metric.CounterVec
+ bannedEndpointsMetric *metric.Gauge
router *pipeline.Router
}
@@ -176,6 +177,19 @@ type Config struct {
// >
// > After a non-retryable write error, fall with a non-zero exit code or not
Strict bool `json:"strict" default:"false"` // *
+
+ // > @3@4@5@6
+ // >
+ // > Period for which addresses will be banned in case of unavailability.
+ // > If set to 0, circuit breaker is disabled.
+ BanPeriod cfg.Duration `json:"ban_period" default:"10s" parse:"duration"` // *
+ BanPeriod_ time.Duration
+
+ // > @3@4@5@6
+ // >
+ // > Interval for checking banned endpoints availability.
+ ReconnectInterval cfg.Duration `json:"reconnect_interval" default:"5s" parse:"duration"` // *
+ ReconnectInterval_ time.Duration
}
type KeepAliveConfig struct {
@@ -212,13 +226,23 @@ func (p *Plugin) Start(config pipeline.AnyConfig, params *pipeline.OutputPluginP
p.registerMetrics(params.MetricCtl)
p.mu = &sync.Mutex{}
+ if p.config.ReconnectInterval_ < 1 {
+ p.logger.Fatal("'reconnect_interval' can't be <1")
+ }
+ if p.config.BanPeriod_ < 0 {
+ p.logger.Fatal("'ban_period' cant't be <0")
+ }
+
var err error
p.encoder, err = NewEncoder(p.config.Encoding)
if err != nil {
p.logger.Fatal("can't create encoder", zap.Error(err))
}
- p.prepareClient()
+ ctx, cancel := context.WithCancel(context.Background())
+ p.cancel = cancel
+
+ p.prepareClient(ctx)
p.logger.Info("starting batcher", zap.Duration("timeout", p.config.BatchFlushTimeout_))
@@ -267,9 +291,6 @@ func (p *Plugin) Start(config pipeline.AnyConfig, params *pipeline.OutputPluginP
onError,
)
- ctx, cancel := context.WithCancel(context.Background())
- p.cancel = cancel
-
p.batcher.Start(ctx)
}
@@ -284,17 +305,26 @@ func (p *Plugin) Out(event *pipeline.Event) {
func (p *Plugin) registerMetrics(ctl *metric.Ctl) {
p.sendErrorMetric = ctl.RegisterCounterVec("output_http_send_error_total", "Total HTTP send errors", "status_code")
+ p.bannedEndpointsMetric = ctl.RegisterGauge(
+ "output_http_banned_endpoints_count",
+ "Current number of endpoints banned by circuit breaker",
+ )
+ p.bannedEndpointsMetric.Set(0)
}
-func (p *Plugin) prepareClient() {
+func (p *Plugin) prepareClient(ctx context.Context) {
config := &xhttp.ClientConfig{
Endpoints: p.prepareEndpoints(),
ConnectionTimeout: p.config.ConnectionTimeout_ * 2,
AuthHeader: p.getAuthHeader(),
+ BanPeriod: p.config.BanPeriod_,
+ ReconnectInterval: p.config.ReconnectInterval_,
KeepAlive: &xhttp.ClientKeepAliveConfig{
MaxConnDuration: p.config.KeepAlive.MaxConnDuration_,
MaxIdleConnDuration: p.config.KeepAlive.MaxIdleConnDuration_,
},
+ Logger: p.logger,
+ BannedEndpointsMetric: p.bannedEndpointsMetric,
}
if p.config.CACert != "" {
config.TLS = &xhttp.ClientTLSConfig{
@@ -306,7 +336,7 @@ func (p *Plugin) prepareClient() {
}
var err error
- p.client, err = xhttp.NewClient(config)
+ p.client, err = xhttp.NewClient(ctx, config)
if err != nil {
p.logger.Fatal("can't create http client", zap.Error(err))
}
diff --git a/plugin/output/loki/README.md b/plugin/output/loki/README.md
index 1ce47cee1..26b9fec02 100644
--- a/plugin/output/loki/README.md
+++ b/plugin/output/loki/README.md
@@ -149,5 +149,18 @@ Multiplier for exponential increase of retention between retries
+**`ban_period`** *`cfg.Duration`* *`default=10s`*
+
+Period for which addresses will be banned in case of unavailability.
+If set to 0, circuit breaker is disabled.
+
+
+
+**`reconnect_interval`** *`cfg.Duration`* *`default=5s`*
+
+Interval for checking banned endpoints availability.
+
+
+
*Generated using [__insane-doc__](https://github.com/vitkovskii/insane-doc)*
\ No newline at end of file
diff --git a/plugin/output/loki/loki.go b/plugin/output/loki/loki.go
index 160054675..ce3653e87 100644
--- a/plugin/output/loki/loki.go
+++ b/plugin/output/loki/loki.go
@@ -178,6 +178,19 @@ type Config struct {
// >
// > Multiplier for exponential increase of retention between retries
RetentionExponentMultiplier int `json:"retention_exponentially_multiplier" default:"2"` // *
+
+ // > @3@4@5@6
+ // >
+ // > Period for which addresses will be banned in case of unavailability.
+ // > If set to 0, circuit breaker is disabled.
+ BanPeriod cfg.Duration `json:"ban_period" default:"10s" parse:"duration"` // *
+ BanPeriod_ time.Duration
+
+ // > @3@4@5@6
+ // >
+ // > Interval for checking banned endpoints availability.
+ ReconnectInterval cfg.Duration `json:"reconnect_interval" default:"5s" parse:"duration"` // *
+ ReconnectInterval_ time.Duration
}
type AuthStrategy byte
@@ -232,7 +245,8 @@ type Plugin struct {
batcher *pipeline.RetriableBatcher
// plugin metrics
- sendErrorMetric *metric.CounterVec
+ sendErrorMetric *metric.CounterVec
+ bannedEndpointsMetric *metric.Gauge
labels map[string]string
@@ -259,7 +273,18 @@ func (p *Plugin) Start(config pipeline.AnyConfig, params *pipeline.OutputPluginP
p.labels = p.parseLabels()
- p.prepareClient()
+ if p.config.ReconnectInterval_ < 1 {
+ p.logger.Fatal("'reconnect_interval' can't be <1")
+ }
+ if p.config.BanPeriod_ < 0 {
+ p.logger.Fatal("'ban_period' cant't be <0")
+ }
+
+ ctx, cancel := context.WithCancel(context.Background())
+ p.ctx = ctx
+ p.cancel = cancel
+
+ p.prepareClient(ctx)
batcherOpts := &pipeline.BatcherOptions{
PipelineName: params.PipelineName,
@@ -303,10 +328,6 @@ func (p *Plugin) Start(config pipeline.AnyConfig, params *pipeline.OutputPluginP
onError,
)
- ctx, cancel := context.WithCancel(context.Background())
- p.ctx = ctx
- p.cancel = cancel
-
p.batcher.Start(ctx)
}
@@ -428,22 +449,31 @@ func (p *Plugin) send(root *insaneJSON.Root) (int, error) {
func (p *Plugin) registerMetrics(ctl *metric.Ctl) {
p.sendErrorMetric = ctl.RegisterCounterVec("output_loki_send_error_total", "Total Loki send errors", "status_code")
+ p.bannedEndpointsMetric = ctl.RegisterGauge(
+ "output_loki_banned_endpoints_count",
+ "Current number of endpoints banned by circuit breaker",
+ )
+ p.bannedEndpointsMetric.Set(0)
}
-func (p *Plugin) prepareClient() {
+func (p *Plugin) prepareClient(ctx context.Context) {
config := &xhttp.ClientConfig{
Endpoints: []string{fmt.Sprintf("%s/loki/api/v1/push", p.config.Address)},
ConnectionTimeout: p.config.ConnectionTimeout_ * 2,
AuthHeader: p.getAuthHeader(),
CustomHeaders: p.getCustomHeaders(),
+ BanPeriod: p.config.BanPeriod_,
+ ReconnectInterval: p.config.ReconnectInterval_,
KeepAlive: &xhttp.ClientKeepAliveConfig{
MaxConnDuration: p.config.KeepAlive.MaxConnDuration_,
MaxIdleConnDuration: p.config.KeepAlive.MaxIdleConnDuration_,
},
+ Logger: p.logger,
+ BannedEndpointsMetric: p.bannedEndpointsMetric,
}
var err error
- p.client, err = xhttp.NewClient(config)
+ p.client, err = xhttp.NewClient(ctx, config)
if err != nil {
p.logger.Fatal("can't create http client", zap.Error(err))
}
diff --git a/plugin/output/splunk/README.md b/plugin/output/splunk/README.md
index b9df5013f..e6b8f2913 100755
--- a/plugin/output/splunk/README.md
+++ b/plugin/output/splunk/README.md
@@ -153,5 +153,18 @@ or the "event" key with any of its subkeys.
+**`ban_period`** *`cfg.Duration`* *`default=10s`*
+
+Period for which addresses will be banned in case of unavailability.
+If set to 0, circuit breaker is disabled.
+
+
+
+**`reconnect_interval`** *`cfg.Duration`* *`default=5s`*
+
+Interval for checking banned endpoints availability.
+
+
+
*Generated using [__insane-doc__](https://github.com/vitkovskii/insane-doc)*
\ No newline at end of file
diff --git a/plugin/output/splunk/splunk.go b/plugin/output/splunk/splunk.go
index 090ae5048..6b3a5796f 100644
--- a/plugin/output/splunk/splunk.go
+++ b/plugin/output/splunk/splunk.go
@@ -96,7 +96,8 @@ type Plugin struct {
cancel context.CancelFunc
// plugin metrics
- sendErrorMetric *metric.CounterVec
+ sendErrorMetric *metric.CounterVec
+ bannedEndpointsMetric *metric.Gauge
router *pipeline.Router
}
@@ -202,6 +203,19 @@ type Config struct {
// > Supports copying whole original event, but does not allow to copy directly to the output root
// > or the "event" key with any of its subkeys.
CopyFields []CopyField `json:"copy_fields" slice:"true"` // *
+
+ // > @3@4@5@6
+ // >
+ // > Period for which addresses will be banned in case of unavailability.
+ // > If set to 0, circuit breaker is disabled.
+ BanPeriod cfg.Duration `json:"ban_period" default:"10s" parse:"duration"` // *
+ BanPeriod_ time.Duration
+
+ // > @3@4@5@6
+ // >
+ // > Interval for checking banned endpoints availability.
+ ReconnectInterval cfg.Duration `json:"reconnect_interval" default:"5s" parse:"duration"` // *
+ ReconnectInterval_ time.Duration
}
type KeepAliveConfig struct {
@@ -235,7 +249,18 @@ func (p *Plugin) Start(config pipeline.AnyConfig, params *pipeline.OutputPluginP
p.avgEventSize = params.PipelineSettings.AvgEventSize
p.config = config.(*Config)
p.registerMetrics(params.MetricCtl)
- p.prepareClient()
+
+ if p.config.ReconnectInterval_ < 1 {
+ p.logger.Fatal("'reconnect_interval' can't be <1")
+ }
+ if p.config.BanPeriod_ < 0 {
+ p.logger.Fatal("'ban_period' cant't be <0")
+ }
+
+ ctx, cancel := context.WithCancel(context.Background())
+ p.cancel = cancel
+
+ p.prepareClient(ctx)
for _, cf := range p.config.CopyFields {
if cf.To == "" {
@@ -296,9 +321,6 @@ func (p *Plugin) Start(config pipeline.AnyConfig, params *pipeline.OutputPluginP
onError,
)
- ctx, cancel := context.WithCancel(context.Background())
- p.cancel = cancel
-
p.batcher.Start(ctx)
}
@@ -317,13 +339,20 @@ func (p *Plugin) registerMetrics(ctl *metric.Ctl) {
"Total splunk send errors",
"status_code",
)
+ p.bannedEndpointsMetric = ctl.RegisterGauge(
+ "output_splunk_banned_endpoints_count",
+ "Current number of endpoints banned by circuit breaker",
+ )
+ p.bannedEndpointsMetric.Set(0)
}
-func (p *Plugin) prepareClient() {
+func (p *Plugin) prepareClient(ctx context.Context) {
config := &xhttp.ClientConfig{
Endpoints: []string{p.config.Endpoint},
ConnectionTimeout: p.config.RequestTimeout_,
AuthHeader: "Splunk " + p.config.Token,
+ BanPeriod: p.config.BanPeriod_,
+ ReconnectInterval: p.config.ReconnectInterval_,
KeepAlive: &xhttp.ClientKeepAliveConfig{
MaxConnDuration: p.config.KeepAlive.MaxConnDuration_,
MaxIdleConnDuration: p.config.KeepAlive.MaxIdleConnDuration_,
@@ -332,13 +361,15 @@ func (p *Plugin) prepareClient() {
// TODO: make this configuration option and false by default
InsecureSkipVerify: true,
},
+ Logger: p.logger.Desugar(),
+ BannedEndpointsMetric: p.bannedEndpointsMetric,
}
if p.config.UseGzip {
config.GzipCompressionLevel = p.config.GzipCompressionLevel
}
var err error
- p.client, err = xhttp.NewClient(config)
+ p.client, err = xhttp.NewClient(ctx, config)
if err != nil {
p.logger.Fatal("can't create http client", zap.Error(err))
}
diff --git a/plugin/output/splunk/splunk_test.go b/plugin/output/splunk/splunk_test.go
index 40626e076..fec6de177 100644
--- a/plugin/output/splunk/splunk_test.go
+++ b/plugin/output/splunk/splunk_test.go
@@ -1,6 +1,7 @@
package splunk
import (
+ "context"
"io"
"net/http"
"net/http/httptest"
@@ -54,7 +55,7 @@ func TestSplunk(t *testing.T) {
},
logger: zap.NewExample().Sugar(),
}
- plugin.prepareClient()
+ plugin.prepareClient(context.Background())
batch := pipeline.NewPreparedBatch([]*pipeline.Event{
{Root: input},
@@ -185,7 +186,7 @@ func TestCopyFields(t *testing.T) {
copyFieldsPaths: tt.copyFields,
logger: zap.NewExample().Sugar(),
}
- plugin.prepareClient()
+ plugin.prepareClient(context.Background())
batch := pipeline.NewPreparedBatch([]*pipeline.Event{
{Root: input},
diff --git a/xhttp/circuit_breaker.go b/xhttp/circuit_breaker.go
new file mode 100644
index 000000000..ebc91f717
--- /dev/null
+++ b/xhttp/circuit_breaker.go
@@ -0,0 +1,176 @@
+package xhttp
+
+import (
+ "context"
+ "math/rand"
+ "sync"
+ "time"
+
+ "github.com/ozontech/file.d/metric"
+ "github.com/ozontech/file.d/xtime"
+ "github.com/valyala/fasthttp"
+ "go.uber.org/zap"
+)
+
+type endpoint struct {
+ uri *fasthttp.URI
+ banUntil time.Time
+}
+
+type circuitBreaker struct {
+ endpoints []endpoint
+ activeEndpoints []int
+ idxByURI map[string]int
+ banPeriod time.Duration
+
+ logger *zap.Logger
+ bannedEndpointsMetric *metric.Gauge
+
+ mu sync.RWMutex
+ nowFn func() time.Time
+}
+
+func newCircuitBreaker(
+ ctx context.Context,
+ logger *zap.Logger,
+ uris []*fasthttp.URI,
+ banPeriod, reconnectInterval time.Duration,
+ bannedEndpointsMetric *metric.Gauge,
+) *circuitBreaker {
+ if banPeriod <= 0 || len(uris) == 1 {
+ logger.Info(
+ "circuit breaker disabled",
+ zap.Duration("ban_period", banPeriod),
+ zap.Int("endpoints_count", len(uris)),
+ )
+
+ return nil
+ }
+
+ cb := &circuitBreaker{
+ endpoints: make([]endpoint, 0, len(uris)),
+ activeEndpoints: make([]int, 0, len(uris)),
+ idxByURI: make(map[string]int, len(uris)),
+ banPeriod: banPeriod,
+ logger: logger,
+ bannedEndpointsMetric: bannedEndpointsMetric,
+ nowFn: xtime.GetInaccurateTime,
+ }
+
+ for i, uri := range uris {
+ cb.endpoints = append(cb.endpoints, endpoint{uri: uri})
+ cb.idxByURI[uri.String()] = i
+ cb.activeEndpoints = append(cb.activeEndpoints, i)
+ }
+
+ logger.Info(
+ "circuit breaker enabled",
+ zap.Duration("ban_period", banPeriod),
+ zap.Duration("reconnect_interval", reconnectInterval),
+ zap.Int("endpoints_count", len(uris)),
+ )
+
+ go cb.checkBannedEndpoints(ctx, reconnectInterval)
+
+ return cb
+}
+
+func (cb *circuitBreaker) updateBannedEndpointsMetric() {
+ if cb.bannedEndpointsMetric == nil {
+ return
+ }
+
+ cb.bannedEndpointsMetric.Set(float64(len(cb.endpoints) - len(cb.activeEndpoints)))
+}
+
+func (cb *circuitBreaker) getEndpoint() *fasthttp.URI {
+ cb.mu.RLock()
+ defer cb.mu.RUnlock()
+
+ if len(cb.activeEndpoints) == 0 {
+ return nil
+ }
+
+ idx := rand.Intn(len(cb.activeEndpoints))
+ return cb.endpoints[cb.activeEndpoints[idx]].uri
+}
+
+func (cb *circuitBreaker) banEndpoint(uri *fasthttp.URI) {
+ cb.mu.Lock()
+ defer cb.mu.Unlock()
+
+ idx := cb.idxByURI[uri.String()]
+ cb.endpoints[idx].banUntil = cb.nowFn().Add(cb.banPeriod)
+
+ for i, activeIdx := range cb.activeEndpoints {
+ if activeIdx == idx {
+ cb.activeEndpoints[i] = cb.activeEndpoints[len(cb.activeEndpoints)-1]
+ cb.activeEndpoints = cb.activeEndpoints[:len(cb.activeEndpoints)-1]
+ break
+ }
+ }
+
+ cb.logger.Info(
+ "endpoint banned",
+ zap.String("endpoint", uri.String()),
+ zap.Duration("ban_period", cb.banPeriod),
+ zap.Int("active_endpoints_count", len(cb.activeEndpoints)),
+ zap.Int("banned_endpoints_count", len(cb.endpoints)-len(cb.activeEndpoints)),
+ )
+
+ cb.updateBannedEndpointsMetric()
+}
+
+func (cb *circuitBreaker) restoreBannedEndpoints() {
+ cb.mu.RLock()
+ if len(cb.endpoints) == len(cb.activeEndpoints) {
+ cb.mu.RUnlock()
+ return
+ }
+ cb.mu.RUnlock()
+
+ cb.mu.Lock()
+ defer cb.mu.Unlock()
+
+ hasRestoredEndpoints := false
+ now := cb.nowFn()
+ for i := range cb.endpoints {
+ e := &cb.endpoints[i]
+ if !e.banUntil.IsZero() && now.After(e.banUntil) {
+ e.banUntil = time.Time{}
+ cb.activeEndpoints = append(cb.activeEndpoints, i)
+ hasRestoredEndpoints = true
+
+ cb.logger.Info(
+ "endpoint restored",
+ zap.String("endpoint", e.uri.String()),
+ zap.Int("active_endpoints_count", len(cb.activeEndpoints)),
+ zap.Int("banned_endpoints_count", len(cb.endpoints)-len(cb.activeEndpoints)),
+ )
+ }
+ }
+
+ if hasRestoredEndpoints {
+ cb.updateBannedEndpointsMetric()
+ }
+}
+
+func (cb *circuitBreaker) checkBannedEndpoints(ctx context.Context, reconnectInterval time.Duration) {
+ ticker := time.NewTicker(reconnectInterval)
+ defer ticker.Stop()
+
+ for {
+ select {
+ case <-ctx.Done():
+ return
+ case <-ticker.C:
+ cb.restoreBannedEndpoints()
+ }
+ }
+}
+
+func (cb *circuitBreaker) setNowFn(nowFn func() time.Time) {
+ cb.mu.Lock()
+ defer cb.mu.Unlock()
+ cb.nowFn = nowFn
+}
diff --git a/xhttp/circuit_breaker_test.go b/xhttp/circuit_breaker_test.go
new file mode 100644
index 000000000..2ccafc0cb
--- /dev/null
+++ b/xhttp/circuit_breaker_test.go
@@ -0,0 +1,261 @@
+package xhttp
+
+import (
+ "sync"
+ "testing"
+ "testing/synctest"
+ "time"
+
+ "github.com/stretchr/testify/require"
+ "go.uber.org/zap"
+)
+
+const (
+ opBanEndpoint = iota + 1
+ opSleep
+)
+
+var (
+ defaultEndpoints = []string{"http://localhost:19200", "http://localhost:19201", "http://localhost:19202"}
+ defaultWorkerCount = 50
+)
+
+type cbStep struct {
+ operation int
+ idxEp int
+ duration time.Duration
+}
+
+func TestNewCircuitBreaker(t *testing.T) {
+ cases := []struct {
+ name string
+ banPeriod time.Duration
+ endpoints []string
+ disabled bool
+ }{
+ {
+ name: "ban_period_zero",
+ banPeriod: 0,
+ endpoints: defaultEndpoints,
+ disabled: true,
+ },
+ {
+ name: "single_endpoint",
+ banPeriod: 2 * time.Second,
+ endpoints: defaultEndpoints[0:1],
+ disabled: true,
+ },
+ {
+ name: "two_and_more_endpoints",
+ banPeriod: 3 * time.Second,
+ endpoints: defaultEndpoints,
+ disabled: false,
+ },
+ }
+
+ for _, tt := range cases {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ uris, err := parseEndpoints(tt.endpoints)
+ require.NoError(t, err)
+
+ ctx := t.Context()
+ cb := newCircuitBreaker(ctx, zap.NewNop(), uris, tt.banPeriod, 5*time.Minute, nil)
+
+ if tt.disabled {
+ require.Nil(t, cb, "circuit breaker must be disabled with these parameters")
+ return
+ }
+
+ require.NotNil(t, cb)
+ require.Len(t, cb.activeEndpoints, len(tt.endpoints))
+ require.Equal(t, cb.banPeriod, tt.banPeriod)
+ for i := range cb.endpoints {
+ require.Zero(t, cb.endpoints[i].banUntil, "endpoint[%d]: banUntil should be zero after creation", i)
+ }
+ })
+ }
+}
+
+func TestCircuitBreakerScenarios(t *testing.T) {
+ cases := []struct {
+ name string
+ endpoints []string
+ banPeriod time.Duration
+ steps []cbStep
+ wantActive []int
+ wantBanUntil map[int]time.Duration
+ }{
+ {
+ name: "ban_removes_endpoint_from_active",
+ endpoints: defaultEndpoints,
+ banPeriod: 10 * time.Second,
+ steps: []cbStep{{operation: opBanEndpoint, idxEp: 1}},
+ wantActive: []int{0, 2},
+ wantBanUntil: map[int]time.Duration{
+ 1: 10 * time.Second,
+ },
+ },
+ {
+ name: "ban_all_endpoint",
+ endpoints: defaultEndpoints,
+ banPeriod: 10 * time.Second,
+ steps: []cbStep{
+ {operation: opBanEndpoint, idxEp: 0},
+ {operation: opBanEndpoint, idxEp: 1},
+ {operation: opBanEndpoint, idxEp: 2},
+ },
+ wantActive: nil,
+ },
+ {
+ name: "ban_refreshes_ban_until",
+ endpoints: defaultEndpoints,
+ banPeriod: 10 * time.Second,
+ steps: []cbStep{
+ {operation: opBanEndpoint, idxEp: 0},
+ {operation: opSleep, duration: 5 * time.Second},
+ {operation: opBanEndpoint, idxEp: 0},
+ },
+ wantActive: []int{1, 2},
+ wantBanUntil: map[int]time.Duration{
+ 0: 15 * time.Second,
+ },
+ },
+ {
+ name: "does_not_restore_before_ban_period",
+ endpoints: defaultEndpoints,
+ banPeriod: 40 * time.Second,
+ steps: []cbStep{
+ {operation: opBanEndpoint, idxEp: 0},
+ {operation: opSleep, duration: 31 * time.Second},
+ },
+ wantActive: []int{1, 2},
+ },
+ {
+ name: "restores_after_ban_period",
+ endpoints: defaultEndpoints,
+ banPeriod: 25 * time.Second,
+ steps: []cbStep{
+ {operation: opBanEndpoint, idxEp: 0},
+ {operation: opSleep, duration: 31 * time.Second},
+ },
+ wantActive: []int{0, 1, 2},
+ },
+ {
+ name: "partially_restores_expired_endpoints",
+ endpoints: defaultEndpoints,
+ banPeriod: 10 * time.Second,
+ steps: []cbStep{
+ {operation: opBanEndpoint, idxEp: 0},
+ {operation: opSleep, duration: 25 * time.Second},
+ {operation: opBanEndpoint, idxEp: 1},
+ {operation: opSleep, duration: 10 * time.Second},
+ },
+ wantActive: []int{0, 2},
+ },
+ }
+
+ for _, tt := range cases {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ synctest.Test(t, func(t *testing.T) {
+ ctx := t.Context()
+
+ uris, err := parseEndpoints(tt.endpoints)
+ require.NoError(t, err)
+
+ cb := newCircuitBreaker(ctx, zap.NewNop(), uris, tt.banPeriod, 30*time.Second, nil)
+ require.NotNil(t, cb)
+ cb.setNowFn(time.Now)
+
+ start := time.Now()
+ for _, s := range tt.steps {
+ switch s.operation {
+ case opBanEndpoint:
+ cb.banEndpoint(cb.endpoints[s.idxEp].uri)
+ case opSleep:
+ time.Sleep(s.duration)
+ }
+ }
+
+ cb.mu.RLock()
+ activeEp := append([]int{}, cb.activeEndpoints...)
+ banUntil := make([]time.Time, len(cb.endpoints))
+ for i := range cb.endpoints {
+ banUntil[i] = cb.endpoints[i].banUntil
+ }
+ cb.mu.RUnlock()
+
+ require.ElementsMatch(t, tt.wantActive, activeEp)
+
+ for idx, dur := range tt.wantBanUntil {
+ if dur == 0 {
+ require.Zero(t, banUntil[idx], "endpoint[%d]: banUntil should be zero", idx)
+ }
+ require.Equal(t, start.Add(dur), banUntil[idx])
+ }
+ })
+ })
+ }
+}
+
+func TestCircuitBreakerFullCycle(t *testing.T) {
+ t.Parallel()
+
+ synctest.Test(t, func(t *testing.T) {
+ uris, err := parseEndpoints(defaultEndpoints)
+ require.NoError(t, err)
+
+ ctx := t.Context()
+ cb := newCircuitBreaker(ctx, zap.NewNop(), uris, 10*time.Second, 3*time.Second, nil)
+ require.NotNil(t, cb)
+ cb.setNowFn(time.Now)
+
+ ep0, ep1, ep2 := cb.endpoints[0].uri.String(), cb.endpoints[1].uri.String(), cb.endpoints[2].uri.String()
+ require.ElementsMatch(t, []string{ep0, ep1, ep2}, pickedURIs(cb, defaultWorkerCount))
+
+ cb.banEndpoint(cb.endpoints[0].uri)
+ require.ElementsMatch(t, []string{ep1, ep2}, pickedURIs(cb, defaultWorkerCount))
+ time.Sleep(5 * time.Second)
+
+ cb.banEndpoint(cb.endpoints[1].uri)
+ require.ElementsMatch(t, []string{ep2}, pickedURIs(cb, defaultWorkerCount))
+ time.Sleep(8 * time.Second)
+
+ require.ElementsMatch(t, []string{ep0, ep2}, pickedURIs(cb, defaultWorkerCount))
+
+ time.Sleep(7 * time.Second)
+ require.ElementsMatch(t, []string{ep0, ep1, ep2}, pickedURIs(cb, defaultWorkerCount))
+ })
+}
+
+func pickedURIs(cb *circuitBreaker, workers int) []string {
+ var (
+ wg sync.WaitGroup
+ mu sync.Mutex
+ seen = make(map[string]struct{})
+ )
+
+ for range workers {
+ wg.Go(func() {
+ uri := cb.getEndpoint()
+ if uri == nil {
+ return
+ }
+
+ mu.Lock()
+ seen[uri.String()] = struct{}{}
+ mu.Unlock()
+ })
+ }
+ wg.Wait()
+
+ out := make([]string, 0, len(seen))
+ for k := range seen {
+ out = append(out, k)
+ }
+
+ return out
+}
diff --git a/xhttp/client.go b/xhttp/client.go
index f9ea6e6a9..560d68289 100644
--- a/xhttp/client.go
+++ b/xhttp/client.go
@@ -1,13 +1,16 @@
package xhttp
import (
+ "context"
"fmt"
"math/rand"
"net/http"
"time"
+ "github.com/ozontech/file.d/metric"
"github.com/ozontech/file.d/xtls"
"github.com/valyala/fasthttp"
+ "go.uber.org/zap"
)
const gzipContentEncoding = "gzip"
@@ -23,24 +26,29 @@ type ClientKeepAliveConfig struct {
}
type ClientConfig struct {
- Endpoints []string
- ConnectionTimeout time.Duration
- AuthHeader string
- CustomHeaders map[string]string
- GzipCompressionLevel string
- TLS *ClientTLSConfig
- KeepAlive *ClientKeepAliveConfig
+ Endpoints []string
+ ConnectionTimeout time.Duration
+ AuthHeader string
+ CustomHeaders map[string]string
+ GzipCompressionLevel string
+ TLS *ClientTLSConfig
+ KeepAlive *ClientKeepAliveConfig
+ BanPeriod time.Duration
+ ReconnectInterval time.Duration
+ Logger *zap.Logger
+ BannedEndpointsMetric *metric.Gauge
}
type Client struct {
client *fasthttp.Client
endpoints []*fasthttp.URI
+ cb *circuitBreaker
authHeader string
customHeaders map[string]string
gzipCompressionLevel int
}
-func NewClient(cfg *ClientConfig) (*Client, error) {
+func NewClient(ctx context.Context, cfg *ClientConfig) (*Client, error) {
client := &fasthttp.Client{
ReadTimeout: cfg.ConnectionTimeout,
WriteTimeout: cfg.ConnectionTimeout,
@@ -70,8 +78,16 @@ func NewClient(cfg *ClientConfig) (*Client, error) {
}
return &Client{
- client: client,
- endpoints: endpoints,
+ client: client,
+ endpoints: endpoints,
+ cb: newCircuitBreaker(
+ ctx,
+ cfg.Logger,
+ endpoints,
+ cfg.BanPeriod,
+ cfg.ReconnectInterval,
+ cfg.BannedEndpointsMetric,
+ ),
authHeader: cfg.AuthHeader,
customHeaders: cfg.CustomHeaders,
gzipCompressionLevel: parseGzipCompressionLevel(cfg.GzipCompressionLevel),
@@ -89,16 +105,15 @@ func (c *Client) DoTimeout(
resp := fasthttp.AcquireResponse()
defer fasthttp.ReleaseResponse(resp)
- var endpoint *fasthttp.URI
- if len(c.endpoints) == 1 {
- endpoint = c.endpoints[0]
- } else {
- endpoint = c.endpoints[rand.Int()%len(c.endpoints)]
+ endpoint := c.getEndpoint()
+ if endpoint == nil {
+ return 0, fmt.Errorf("no available endpoints")
}
c.prepareRequest(req, endpoint, method, contentType, body)
if err := c.client.DoTimeout(req, resp, timeout); err != nil {
+ c.banEndpoint(endpoint)
return 0, fmt.Errorf("can't send request to %s: %w", endpoint.String(), err)
}
@@ -106,6 +121,9 @@ func (c *Client) DoTimeout(
statusCode := resp.Header.StatusCode()
if !(http.StatusOK <= statusCode && statusCode <= http.StatusAccepted) {
+ if shouldBanEndpoint(statusCode) {
+ c.banEndpoint(endpoint)
+ }
return statusCode, fmt.Errorf("response status from %s isn't OK: status=%d, body=%s", endpoint.String(), statusCode, string(respContent))
}
@@ -168,3 +186,32 @@ func parseGzipCompressionLevel(level string) int {
return -1
}
}
+
+func (c *Client) getEndpoint() *fasthttp.URI {
+ if c.cb != nil {
+ return c.cb.getEndpoint()
+ }
+
+ if len(c.endpoints) == 0 {
+ return nil
+ }
+ return c.endpoints[rand.Intn(len(c.endpoints))]
+}
+
+func (c *Client) banEndpoint(endpoint *fasthttp.URI) {
+ if c.cb != nil {
+ c.cb.banEndpoint(endpoint)
+ }
+}
+
+func shouldBanEndpoint(statusCode int) bool {
+ switch statusCode {
+ case http.StatusBadGateway,
+ http.StatusServiceUnavailable,
+ http.StatusGatewayTimeout,
+ http.StatusTooManyRequests:
+ return true
+ default:
+ return false
+ }
+}