diff --git a/modules/caddyhttp/reverseproxy/active_health_test.go b/modules/caddyhttp/reverseproxy/active_health_test.go
new file mode 100644
index 00000000000..089f96aec86
--- /dev/null
+++ b/modules/caddyhttp/reverseproxy/active_health_test.go
@@ -0,0 +1,201 @@
+// Copyright 2015 Matthew Holt and The Caddy Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package reverseproxy
+
+import (
+	"testing"
+)
+
+// newTestUpstream returns an Upstream backed by a fresh Host and its own
+// zeroed ActiveHealthStats, so each test starts with both counters at zero.
+func newTestUpstream() *Upstream {
+	return &Upstream{
+		Host:              new(Host),
+		activeHealthStats: &ActiveHealthStats{},
+	}
+}
+
+// mustCountHealth fails the test immediately when a health counter update
+// returns an error, so a broken fixture (for example unprovisioned stats)
+// cannot be mistaken for a failed assertion further down.
+func mustCountHealth(t *testing.T, err error) {
+	t.Helper()
+	if err != nil {
+		t.Fatalf("unexpected error updating health counters: %v", err)
+	}
+}
+
+// TestConsecutiveCounterResetOnPass verifies that a health check pass
+// resets the consecutive failure counter to zero. Without this,
+// non-consecutive failures could accumulate and incorrectly trip the threshold.
+func TestConsecutiveCounterResetOnPass(t *testing.T) {
+	upstream := newTestUpstream()
+
+	// Simulate: fail, fail, then pass
+	mustCountHealth(t, upstream.countHealthFail(1))
+	mustCountHealth(t, upstream.countHealthFail(1))
+	if upstream.activeHealthFails() != 2 {
+		t.Fatalf("expected 2 fails, got %d", upstream.activeHealthFails())
+	}
+
+	// A pass should reset the fail counter
+	mustCountHealth(t, upstream.countHealthPass(1))
+	if upstream.activeHealthFails() != 0 {
+		t.Errorf("expected fail counter to reset to 0 after a pass, got %d", upstream.activeHealthFails())
+	}
+	if upstream.activeHealthPasses() != 1 {
+		t.Errorf("expected 1 pass, got %d", upstream.activeHealthPasses())
+	}
+}
+
+// TestConsecutiveCounterResetOnFail verifies that a health check failure
+// resets the consecutive pass counter to zero.
+func TestConsecutiveCounterResetOnFail(t *testing.T) {
+	upstream := newTestUpstream()
+
+	// Simulate: pass, pass, then fail
+	mustCountHealth(t, upstream.countHealthPass(1))
+	mustCountHealth(t, upstream.countHealthPass(1))
+	if upstream.activeHealthPasses() != 2 {
+		t.Fatalf("expected 2 passes, got %d", upstream.activeHealthPasses())
+	}
+
+	// A fail should reset the pass counter
+	mustCountHealth(t, upstream.countHealthFail(1))
+	if upstream.activeHealthPasses() != 0 {
+		t.Errorf("expected pass counter to reset to 0 after a fail, got %d", upstream.activeHealthPasses())
+	}
+	if upstream.activeHealthFails() != 1 {
+		t.Errorf("expected 1 fail, got %d", upstream.activeHealthFails())
+	}
+}
+
+// TestNonConsecutiveFailuresDoNotTripThreshold is a regression test:
+// interleaved pass/fail results must NOT accumulate toward the threshold.
+// Before the fix, fail-pass-fail-pass-fail would reach Fails=3 even
+// though there were zero consecutive failures.
+func TestNonConsecutiveFailuresDoNotTripThreshold(t *testing.T) {
+	upstream := newTestUpstream()
+
+	// Interleave: fail, pass, fail, pass, fail
+	mustCountHealth(t, upstream.countHealthFail(1))
+	mustCountHealth(t, upstream.countHealthPass(1))
+	mustCountHealth(t, upstream.countHealthFail(1))
+	mustCountHealth(t, upstream.countHealthPass(1))
+	mustCountHealth(t, upstream.countHealthFail(1))
+
+	// With correct consecutive tracking, we should have only 1 consecutive fail
+	if upstream.activeHealthFails() != 1 {
+		t.Errorf("expected 1 consecutive fail, got %d", upstream.activeHealthFails())
+	}
+}
+
+// TestConsecutiveFailuresDoTripThreshold verifies that truly consecutive
+// failures correctly accumulate and trip the threshold.
+func TestConsecutiveFailuresDoTripThreshold(t *testing.T) {
+	upstream := newTestUpstream()
+
+	const failThreshold = 3
+
+	mustCountHealth(t, upstream.countHealthFail(1))
+	mustCountHealth(t, upstream.countHealthFail(1))
+	mustCountHealth(t, upstream.countHealthFail(1))
+
+	if upstream.activeHealthFails() != 3 {
+		t.Errorf("expected 3 consecutive fails, got %d", upstream.activeHealthFails())
+	}
+	if upstream.activeHealthFails() < failThreshold {
+		t.Error("3 consecutive failures should trip threshold of 3")
+	}
+	// Pass counter should be 0 (reset by the first fail)
+	if upstream.activeHealthPasses() != 0 {
+		t.Errorf("expected 0 passes after consecutive fails, got %d", upstream.activeHealthPasses())
+	}
+}
+
+// TestInitiallyUnhealthy verifies that when InitiallyUnhealthy is true
+// and there are no prior health check passes, the upstream starts unhealthy.
+func TestInitiallyUnhealthy(t *testing.T) {
+	upstream := &Upstream{
+		Dial:              "10.4.0.1:80",
+		Host:              new(Host),
+		activeHealthStats: &ActiveHealthStats{},
+	}
+
+	// Simulate what Provision does when InitiallyUnhealthy=true and
+	// passes=0 (fresh host, no prior health checks)
+	passes := 1 // default Passes threshold
+	upstream.setHealthy(upstream.activeHealthPasses() >= passes)
+
+	if upstream.healthy() {
+		t.Error("upstream should be unhealthy when InitiallyUnhealthy=true and no passes recorded")
+	}
+}
+
+// TestInitiallyUnhealthyWithPriorPasses verifies that when InitiallyUnhealthy
+// is true but the host already has enough passes (e.g., across a reload),
+// it starts healthy.
+func TestInitiallyUnhealthyWithPriorPasses(t *testing.T) {
+	stats := &ActiveHealthStats{}
+	upstream := &Upstream{
+		Dial:              "10.4.0.2:80",
+		Host:              new(Host),
+		activeHealthStats: stats,
+	}
+	mustCountHealth(t, upstream.countHealthPass(1)) // simulate a prior health check pass
+
+	passes := 1
+	upstream.setHealthy(upstream.activeHealthPasses() >= passes)
+
+	if !upstream.healthy() {
+		t.Error("upstream should be healthy when it has enough prior passes, even with InitiallyUnhealthy=true")
+	}
+}
+
+// TestInitiallyHealthyDefault verifies the default behavior: upstreams
+// start healthy unless they have accumulated enough failures.
+func TestInitiallyHealthyDefault(t *testing.T) {
+	upstream := &Upstream{
+		Dial:              "10.4.0.3:80",
+		Host:              new(Host),
+		activeHealthStats: &ActiveHealthStats{},
+	}
+
+	// Default behavior: healthy unless fails >= threshold
+	fails := 1
+	upstream.setHealthy(upstream.activeHealthFails() < fails)
+
+	if !upstream.healthy() {
+		t.Error("upstream should be healthy by default when no failures recorded")
+	}
+}
+
+// TestInitiallyHealthyDefaultWithPriorFails verifies that an upstream
+// with prior failures (e.g., from before a reload) starts unhealthy.
+func TestInitiallyHealthyDefaultWithPriorFails(t *testing.T) { + upstream := &Upstream{ + Dial: "10.4.0.4:80", + Host: new(Host), + activeHealthStats: &ActiveHealthStats{}, + } + upstream.countHealthFail(1) // simulate a prior failure + + fails := 1 + upstream.setHealthy(upstream.activeHealthFails() < fails) + + if upstream.healthy() { + t.Error("upstream should be unhealthy when it has prior failures >= threshold") + } +} diff --git a/modules/caddyhttp/reverseproxy/caddyfile.go b/modules/caddyhttp/reverseproxy/caddyfile.go index 777bc06ac67..c697ef7112d 100644 --- a/modules/caddyhttp/reverseproxy/caddyfile.go +++ b/modules/caddyhttp/reverseproxy/caddyfile.go @@ -514,6 +514,18 @@ func (h *Handler) UnmarshalCaddyfile(d *caddyfile.Dispenser) error { } h.HealthChecks.Active.FollowRedirects = true + case "health_initially_unhealthy": + if d.NextArg() { + return d.ArgErr() + } + if h.HealthChecks == nil { + h.HealthChecks = new(HealthChecks) + } + if h.HealthChecks.Active == nil { + h.HealthChecks.Active = new(ActiveHealthChecks) + } + h.HealthChecks.Active.InitiallyUnhealthy = true + case "health_passes": if !d.NextArg() { return d.ArgErr() diff --git a/modules/caddyhttp/reverseproxy/healthchecks.go b/modules/caddyhttp/reverseproxy/healthchecks.go index 73604f916fa..623b58ad674 100644 --- a/modules/caddyhttp/reverseproxy/healthchecks.go +++ b/modules/caddyhttp/reverseproxy/healthchecks.go @@ -16,6 +16,7 @@ package reverseproxy import ( "context" + "errors" "fmt" "io" "net" @@ -127,7 +128,10 @@ type ActiveHealthChecks struct { // body of a healthy backend. ExpectBody string `json:"expect_body,omitempty"` - uri *url.URL + // Whether backends are initially considered unhealthy. 
+ InitiallyUnhealthy bool `json:"initially_unhealthy,omitempty"` + + uri url.URL httpClient *http.Client bodyRegexp *regexp.Regexp logger *zap.Logger @@ -163,15 +167,16 @@ func (a *ActiveHealthChecks) Provision(ctx caddy.Context, h *Handler) error { if a.Path != "" { a.logger.Warn("the 'path' option is deprecated, please use 'uri' instead!") + a.uri.Path = a.Path } - // parse the URI string (supports path and query) + // parse the URI string (supports path and query) and takes precedence over the deprecated Path field if a.URI != "" { parsedURI, err := url.Parse(a.URI) if err != nil { return err } - a.uri = parsedURI + a.uri = *parsedURI } a.httpClient = &http.Client{ @@ -185,7 +190,22 @@ func (a *ActiveHealthChecks) Provision(ctx caddy.Context, h *Handler) error { }, } + if a.Passes < 1 { + a.Passes = 1 + } + + if a.Fails < 1 { + a.Fails = 1 + } + for _, upstream := range h.Upstreams { + upstream.provisionActiveHealthStats(a.uri.String()) + if a.InitiallyUnhealthy { + upstream.setHealthy(upstream.activeHealthPasses() >= a.Passes) + } else { + upstream.setHealthy(upstream.activeHealthFails() < a.Fails) + } + // if there's an alternative upstream for health-check provided in the config, // then use it, otherwise use the upstream's dial address. if upstream is used, // then the port is ignored. 
@@ -210,14 +230,6 @@ func (a *ActiveHealthChecks) Provision(ctx caddy.Context, h *Handler) error { } } - if a.Passes < 1 { - a.Passes = 1 - } - - if a.Fails < 1 { - a.Fails = 1 - } - return nil } @@ -391,8 +403,10 @@ func (h *Handler) doActiveHealthCheckForAllHosts() { func (h *Handler) doActiveHealthCheck(dialInfo DialInfo, hostAddr string, networkAddr string, upstream *Upstream) error { // create the URL for the request that acts as a health check u := &url.URL{ - Scheme: "http", - Host: hostAddr, + Scheme: "http", + Host: hostAddr, + Path: h.HealthChecks.Active.uri.Path, + RawQuery: h.HealthChecks.Active.uri.RawQuery, } // split the host and port if possible, override the port if configured @@ -415,15 +429,6 @@ func (h *Handler) doActiveHealthCheck(dialInfo DialInfo, hostAddr string, networ hcsot.OverrideHealthCheckScheme(u, port) } - // if we have a provisioned uri, use that, otherwise use - // the deprecated Path option - if h.HealthChecks.Active.uri != nil { - u.Path = h.HealthChecks.Active.uri.Path - u.RawQuery = h.HealthChecks.Active.uri.RawQuery - } else { - u.Path = h.HealthChecks.Active.Path - } - // replacer used for both body and headers. Only globals (env vars, system info, etc.) 
are available repl := caddy.NewReplacer() @@ -463,7 +468,7 @@ func (h *Handler) doActiveHealthCheck(dialInfo DialInfo, hostAddr string, networ markUnhealthy := func() { // increment failures and then check if it has reached the threshold to mark unhealthy - err := upstream.Host.countHealthFail(1) + err := upstream.countHealthFail(1) if err != nil { if c := h.HealthChecks.Active.logger.Check(zapcore.ErrorLevel, "could not count active health failure"); c != nil { c.Write( @@ -473,11 +478,10 @@ func (h *Handler) doActiveHealthCheck(dialInfo DialInfo, hostAddr string, networ } return } - if upstream.Host.activeHealthFails() >= h.HealthChecks.Active.Fails { + if upstream.activeHealthFails() >= h.HealthChecks.Active.Fails { // dispatch an event that the host newly became unhealthy if upstream.setHealthy(false) { h.events.Emit(h.ctx, "unhealthy", map[string]any{"host": hostAddr}) - upstream.Host.resetHealth() } } } @@ -494,13 +498,12 @@ func (h *Handler) doActiveHealthCheck(dialInfo DialInfo, hostAddr string, networ } return } - if upstream.Host.activeHealthPasses() >= h.HealthChecks.Active.Passes { + if upstream.activeHealthPasses() >= h.HealthChecks.Active.Passes { if upstream.setHealthy(true) { if c := h.HealthChecks.Active.logger.Check(zapcore.InfoLevel, "host is up"); c != nil { c.Write(zap.String("host", hostAddr)) } h.events.Emit(h.ctx, "healthy", map[string]any{"host": hostAddr}) - upstream.Host.resetHealth() } } } @@ -508,6 +511,11 @@ func (h *Handler) doActiveHealthCheck(dialInfo DialInfo, hostAddr string, networ // do the request, being careful to tame the response body resp, err := h.HealthChecks.Active.httpClient.Do(req) //nolint:gosec // no SSRF if err != nil { + if errors.Is(err, context.Canceled) { + // context was canceled, so don't count this as a failure + return nil + } + if c := h.HealthChecks.Active.logger.Check(zapcore.InfoLevel, "HTTP request failed"); c != nil { c.Write( zap.String("host", hostAddr), diff --git 
a/modules/caddyhttp/reverseproxy/hosts.go b/modules/caddyhttp/reverseproxy/hosts.go index a5406e04ef6..541b65a9b79 100644 --- a/modules/caddyhttp/reverseproxy/hosts.go +++ b/modules/caddyhttp/reverseproxy/hosts.go @@ -58,6 +58,7 @@ type Upstream struct { // HeaderAffinity string // IPAffinity string + activeHealthStats *ActiveHealthStats activeHealthCheckPort int activeHealthCheckUpstream string healthCheckPolicy *PassiveHealthChecks @@ -134,6 +135,37 @@ func (u *Upstream) fillHost() { u.Host = host } +func (u *Upstream) provisionActiveHealthStats(key string) { + u.Host.activeHealthStatsMu.Lock() + defer u.Host.activeHealthStatsMu.Unlock() + + if u.Host.activeHealthStats == nil { + u.Host.activeHealthStats = make(map[string]*ActiveHealthStats) + } + + stats, ok := u.Host.activeHealthStats[key] + if !ok { + stats = &ActiveHealthStats{key: key} + u.Host.activeHealthStats[key] = stats + } + + stats.refs++ + u.activeHealthStats = stats +} + +func (u *Upstream) releaseActiveHealthStats() { + if u.activeHealthStats != nil { + u.Host.activeHealthStatsMu.Lock() + defer u.Host.activeHealthStatsMu.Unlock() + + u.activeHealthStats.refs-- + if u.activeHealthStats.refs <= 0 { + delete(u.Host.activeHealthStats, u.activeHealthStats.key) + } + u.activeHealthStats = nil + } +} + // fillDynamicHost is like fillHost, but stores the host in the separate // dynamicHosts map rather than the reference-counted UsagePool. Dynamic // hosts are not reference-counted; instead, they are retained as long as @@ -171,13 +203,22 @@ func (u *Upstream) fillDynamicHost() { }) } +// ActiveHealthStats holds the health check stats for an active health check URI. +type ActiveHealthStats struct { + key string + refs int64 // synchronized via Host.activeHealthStatsMu + consecutivePasses atomic.Int64 + consecutiveFails atomic.Int64 +} + // Host is the basic, in-memory representation of the state of a remote host. // Its fields are accessed atomically and Host values must not be copied. 
type Host struct { - numRequests atomic.Int64 // atomic.Int64 is automatically aligned for us (see https://golang.org/pkg/sync/atomic/#pkg-note-BUG) - fails atomic.Int64 - activePasses atomic.Int64 - activeFails atomic.Int64 + numRequests atomic.Int64 // atomic.Int64 is automatically aligned for us (see https://golang.org/pkg/sync/atomic/#pkg-note-BUG) + fails atomic.Int64 + + activeHealthStatsMu sync.Mutex // protects activeHealthStats and refs in ActiveHealthStats + activeHealthStats map[string]*ActiveHealthStats // keyed by active health check URI } // NumRequests returns the number of active requests to the upstream. @@ -191,13 +232,19 @@ func (h *Host) Fails() int { } // activeHealthPasses returns the number of consecutive active health check passes with the upstream. -func (h *Host) activeHealthPasses() int { - return int(h.activePasses.Load()) +func (u *Upstream) activeHealthPasses() int { + if u.activeHealthStats == nil { + return 0 + } + return int(u.activeHealthStats.consecutivePasses.Load()) } // activeHealthFails returns the number of consecutive active health check failures with the upstream. -func (h *Host) activeHealthFails() int { - return int(h.activeFails.Load()) +func (u *Upstream) activeHealthFails() int { + if u.activeHealthStats == nil { + return 0 + } + return int(u.activeHealthStats.consecutiveFails.Load()) } // countRequest mutates the active request count by @@ -222,30 +269,34 @@ func (h *Host) countFail(delta int) error { // countHealthPass mutates the recent passes count by // delta. It returns an error if the adjustment fails. 
-func (h *Host) countHealthPass(delta int) error { - result := h.activePasses.Add(int64(delta)) +func (u *Upstream) countHealthPass(delta int) error { + if u.activeHealthStats == nil { + return fmt.Errorf("active health stats not provisioned for upstream %s", u.String()) + } + + result := u.activeHealthStats.consecutivePasses.Add(int64(delta)) if result < 0 { return fmt.Errorf("count below 0: %d", result) } + u.activeHealthStats.consecutiveFails.Store(0) return nil } // countHealthFail mutates the recent failures count by // delta. It returns an error if the adjustment fails. -func (h *Host) countHealthFail(delta int) error { - result := h.activeFails.Add(int64(delta)) +func (u *Upstream) countHealthFail(delta int) error { + if u.activeHealthStats == nil { + return fmt.Errorf("active health stats not provisioned for upstream %s", u.String()) + } + + result := u.activeHealthStats.consecutiveFails.Add(int64(delta)) if result < 0 { return fmt.Errorf("count below 0: %d", result) } + u.activeHealthStats.consecutivePasses.Store(0) return nil } -// resetHealth resets the health check counters. -func (h *Host) resetHealth() { - h.activePasses.Store(0) - h.activeFails.Store(0) -} - // healthy returns true if the upstream is not actively marked as unhealthy. // (This returns the status only from the "active" health checks.) func (u *Upstream) healthy() bool { diff --git a/modules/caddyhttp/reverseproxy/reverseproxy.go b/modules/caddyhttp/reverseproxy/reverseproxy.go index 2169d17173f..09134bac1c4 100644 --- a/modules/caddyhttp/reverseproxy/reverseproxy.go +++ b/modules/caddyhttp/reverseproxy/reverseproxy.go @@ -437,6 +437,7 @@ func (h *Handler) Cleanup() error { // remove hosts from our config from the pool for _, upstream := range h.Upstreams { + upstream.releaseActiveHealthStats() _, _ = hosts.Delete(upstream.String()) }