From b68e9bfdd458c2bc23d8380dff35c0a918a64c70 Mon Sep 17 00:00:00 2001 From: Francis Lavoie Date: Mon, 13 Apr 2026 04:23:23 -0400 Subject: [PATCH 01/17] reverseproxy: Optionally detach stream (websockets) from config lifecycle --- .../reverseproxy_upgrade_handlers_test.go | 130 +++++ .../integration/stream_reload_stress_test.go | 487 ++++++++++++++++++ modules/caddyhttp/encode/encode.go | 18 +- modules/caddyhttp/responsewriter.go | 80 +-- modules/caddyhttp/responsewriter_test.go | 93 ++++ modules/caddyhttp/reverseproxy/caddyfile.go | 14 + .../caddyhttp/reverseproxy/copyresponse.go | 2 +- modules/caddyhttp/reverseproxy/metrics.go | 79 +++ .../caddyhttp/reverseproxy/metrics_test.go | 67 +++ .../caddyhttp/reverseproxy/reverseproxy.go | 63 ++- modules/caddyhttp/reverseproxy/streaming.go | 379 ++++++++++---- .../caddyhttp/reverseproxy/streaming_test.go | 134 +++++ 12 files changed, 1388 insertions(+), 158 deletions(-) create mode 100644 caddytest/integration/reverseproxy_upgrade_handlers_test.go create mode 100644 caddytest/integration/stream_reload_stress_test.go create mode 100644 modules/caddyhttp/reverseproxy/metrics_test.go diff --git a/caddytest/integration/reverseproxy_upgrade_handlers_test.go b/caddytest/integration/reverseproxy_upgrade_handlers_test.go new file mode 100644 index 00000000000..dda93db0ea3 --- /dev/null +++ b/caddytest/integration/reverseproxy_upgrade_handlers_test.go @@ -0,0 +1,130 @@ +package integration + +import ( + "bufio" + "fmt" + "io" + "net" + "net/textproto" + "strings" + "testing" + "time" + + "github.com/caddyserver/caddy/v2/caddytest" +) + +func TestReverseProxyUpgradeWithEncode(t *testing.T) { + tester := caddytest.NewTester(t) + backend := newUpgradeEchoBackend(t) + defer backend.Close() + + tester.InitServer(fmt.Sprintf(` +{ + admin localhost:2999 + http_port 9080 + https_port 9443 + grace_period 1ns + skip_install_trust +} + +localhost:9080 { + route { + encode gzip + reverse_proxy %s + } +} +`, backend.addr), "caddyfile") + + client := newUpgradedStreamClientWithHeaders(t, map[string]string{ + "Accept-Encoding": "gzip", + }) + defer client.Close() + + if err := client.echo("encode-upgrade\n"); err != nil { + t.Fatalf("upgraded stream echo through encode failed: %v", err) + } +} + +func TestReverseProxyUpgradeWithInterceptHandleResponse(t *testing.T) { + tester := caddytest.NewTester(t) + backend := newUpgradeEchoBackend(t) + defer backend.Close() + + tester.InitServer(fmt.Sprintf(` +{ + admin localhost:2999 + http_port 9080 + https_port 9443 + grace_period 1ns + skip_install_trust +} + +localhost:9080 { + route { + intercept { + @upgrade status 101 + handle_response @upgrade { + respond "should-not-run" + } + } + reverse_proxy %s + } +} +`, backend.addr), "caddyfile") + + client := newUpgradedStreamClientWithHeaders(t, nil) + defer client.Close() + + if err := client.echo("intercept-upgrade\n"); err != nil { + t.Fatalf("upgraded stream echo through intercept failed: %v", err) + } +} + +func newUpgradedStreamClientWithHeaders(t *testing.T, extraHeaders map[string]string) *upgradedStreamClient { + t.Helper() + + conn, err := net.DialTimeout("tcp", "127.0.0.1:9080", 5*time.Second) + if err != nil { + t.Fatalf("dialing caddy: %v", err) + } + + requestLines := []string{ + "GET /upgrade HTTP/1.1", + "Host: localhost:9080", + "Connection: Upgrade", + "Upgrade: stress-stream", + } + for k, v := range extraHeaders { + requestLines = append(requestLines, k+": "+v) + } + requestLines = append(requestLines, "", "") + + if _, err := io.WriteString(conn, strings.Join(requestLines, "\r\n")); err != nil { + _ = conn.Close() + t.Fatalf("writing upgrade request: %v", err) + } + + reader := bufio.NewReader(conn) + tproto := textproto.NewReader(reader) + statusLine, err := tproto.ReadLine() + if err != nil { + _ = conn.Close() + t.Fatalf("reading upgrade status line: %v", err) + } + if !strings.Contains(statusLine, "101") { + _ = conn.Close() + t.Fatalf("unexpected upgrade status: %s", statusLine) + } + + headers, err := tproto.ReadMIMEHeader() + if err != nil { + _ = conn.Close() + t.Fatalf("reading upgrade headers: %v", err) + } + if !strings.EqualFold(headers.Get("Connection"), "Upgrade") { + _ = conn.Close() + t.Fatalf("unexpected upgrade response headers: %v", headers) + } + + return &upgradedStreamClient{conn: conn, reader: reader} +} diff --git a/caddytest/integration/stream_reload_stress_test.go b/caddytest/integration/stream_reload_stress_test.go new file mode 100644 index 00000000000..cd0b354caef --- /dev/null +++ b/caddytest/integration/stream_reload_stress_test.go @@ -0,0 +1,487 @@ +package integration + +import ( + "bufio" + "bytes" + "fmt" + "io" + "net" + "net/http" + "net/textproto" + "os" + "runtime" + "runtime/debug" + "runtime/pprof" + "strconv" + "strings" + "sync" + "testing" + "time" + + "github.com/caddyserver/caddy/v2/caddytest" +) + +// stressCloseDelay is the stream_close_delay used for the close_delay scenario. +// Long enough to outlast all test reloads; short enough to keep total test time reasonable. +const stressCloseDelay = 3 * time.Second + +func TestReverseProxyReloadStressUpgradedStreamsHeapProfiles(t *testing.T) { + tester := caddytest.NewTester(t).WithDefaultOverrides(caddytest.Config{ + LoadRequestTimeout: 30 * time.Second, + TestRequestTimeout: 30 * time.Second, + }) + + backend := newUpgradeEchoBackend(t) + defer backend.Close() + + // Three scenarios, each sequential so they don't share Caddy state: + // + // legacy – no delay, close on reload immediately (old default) + // close_delay – stream_close_delay, the old "keep-alive workaround" + // retain – stream_retain_on_reload, the new explicit retain flag + // + // Reloads are spread across time and interleaved with echo-checks so + // stream health is exercised at each reload boundary, not only at the end. + legacy := runReloadStress(t, tester, backend.addr, "legacy", false, 0) + closeDelay := runReloadStress(t, tester, backend.addr, "close_delay", false, stressCloseDelay) + retain := runReloadStress(t, tester, backend.addr, "retain", true, 0) + + if legacy.aliveAfterReloads != 0 { + t.Fatalf("legacy mode left %d upgraded streams alive after reloads", legacy.aliveAfterReloads) + } + if closeDelay.aliveBeforeDelayExpiry == 0 { + t.Fatalf("close_delay mode: all streams closed before delay expired (expected them alive)") + } + if closeDelay.aliveAfterReloads != 0 { + t.Fatalf("close_delay mode left %d upgraded streams alive after delay expiry", closeDelay.aliveAfterReloads) + } + if retain.aliveAfterReloads != retain.streamCount { + t.Fatalf("retain mode kept %d/%d upgraded streams alive after reloads", retain.aliveAfterReloads, retain.streamCount) + } + + t.Logf("legacy heap: before=%s mid=%s after=%s delta(before→after)=%s objects(before=%d after=%d) handler_frames(before=%d after=%d)", + formatBytes(legacy.beforeReload.HeapInuse), + formatBytes(legacy.midReload.HeapInuse), + formatBytes(legacy.afterReload.HeapInuse), + formatBytesDiff(legacy.beforeReload.HeapInuse, legacy.afterReload.HeapInuse), + legacy.beforeReload.HeapObjects, legacy.afterReload.HeapObjects, + legacy.beforeReload.handlerFrames, legacy.afterReload.handlerFrames, + ) + t.Logf("close_delay heap: before=%s mid=%s after=%s delta(before→after)=%s objects(before=%d after=%d) handler_frames(before=%d after=%d)", + formatBytes(closeDelay.beforeReload.HeapInuse), + formatBytes(closeDelay.midReload.HeapInuse), + formatBytes(closeDelay.afterReload.HeapInuse), + formatBytesDiff(closeDelay.beforeReload.HeapInuse, closeDelay.afterReload.HeapInuse), + closeDelay.beforeReload.HeapObjects, closeDelay.afterReload.HeapObjects, + closeDelay.beforeReload.handlerFrames, closeDelay.afterReload.handlerFrames, + ) + t.Logf("retain heap: before=%s mid=%s after=%s delta(before→after)=%s objects(before=%d after=%d) handler_frames(before=%d after=%d)", + formatBytes(retain.beforeReload.HeapInuse), + formatBytes(retain.midReload.HeapInuse), + formatBytes(retain.afterReload.HeapInuse), + formatBytesDiff(retain.beforeReload.HeapInuse, retain.afterReload.HeapInuse), + retain.beforeReload.HeapObjects, retain.afterReload.HeapObjects, + retain.beforeReload.handlerFrames, retain.afterReload.handlerFrames, + ) +} + +type stressRunResult struct { + streamCount int + aliveAfterReloads int + aliveBeforeDelayExpiry int // only meaningful for close_delay mode + beforeReload heapSnapshot + midReload heapSnapshot // after all reloads, before delay expiry clean-up + afterReload heapSnapshot // after all streams have been fully cleaned up +} + +type heapSnapshot struct { + HeapInuse uint64 + HeapObjects uint64 + handlerFrames int + profileBytes int +} + +// runReloadStress opens streamCount upgraded streams, then performs reloadCount +// config reloads spread over time. An echo check is performed every 6 reloads so +// stream health is exercised at each reload boundary rather than only at the end. +// closeDelay mirrors the stream_close_delay config option; pass 0 to disable. +func runReloadStress(t *testing.T, tester *caddytest.Tester, backendAddr, mode string, retain bool, closeDelay time.Duration) stressRunResult { + t.Helper() + + const echoEvery = 6 // perform an echo check every N reloads + + streamCount := envIntOrDefault(t, "CADDY_STRESS_STREAM_COUNT", 12) + reloadCount := envIntOrDefault(t, "CADDY_STRESS_RELOAD_COUNT", 24) + + tester.InitServer(reloadStressConfig(backendAddr, retain, closeDelay, 0), "caddyfile") + + clients := make([]*upgradedStreamClient, 0, streamCount) + for i := 0; i < streamCount; i++ { + client := newUpgradedStreamClient(t) + clients = append(clients, client) + if err := client.echo(fmt.Sprintf("%s-warmup-%02d\n", mode, i)); err != nil { + closeClients(clients) + t.Fatalf("warmup echo failed in %s mode: %v", mode, err) + } + } + defer closeClients(clients) + + before := captureHeapSnapshot(t) + + // Reloads are spread across time; between batches of echoEvery reloads we + // pause briefly and measure stream health so the snapshot reflects real-world + // reload cadence rather than a tight loop. + for i := 1; i <= reloadCount; i++ { + loadCaddyfileConfig(t, reloadStressConfig(backendAddr, retain, closeDelay, i)) + + // Small pause after each reload to let connection teardown propagate. + time.Sleep(50 * time.Millisecond) + + if i%echoEvery == 0 { + alive := countAliveStreams(clients) + t.Logf("%s mode: %d/%d streams alive after reload %d", mode, alive, streamCount, i) + + // In retain mode every stream must survive every reload (upstream unchanged). + if retain { + for j, client := range clients { + if err := client.echo(fmt.Sprintf("%s-mid-%02d-%02d\n", mode, i, j)); err != nil { + t.Fatalf("retain mode stream %d died at reload %d: %v", j, i, err) + } + } + } + } + } + + // mid snapshot: after all reloads but before any close_delay timer has fired + // (the delay is long enough to still be running at this point). + mid := captureHeapSnapshot(t) + + // For legacy mode: the reloads close streams immediately; wait for that to complete. + // For close_delay mode: streams are still alive here; wait for the delay to fire. + // For retain mode: streams survive indefinitely; no wait needed. + var aliveBeforeDelayExpiry int + aliveAfterReloads := countAliveStreams(clients) + switch { + case retain: + // nothing to wait for + case closeDelay > 0: + // streams should still be alive at this point (delay hasn't expired) + aliveBeforeDelayExpiry = aliveAfterReloads + t.Logf("%s mode: %d/%d streams alive before close_delay expires; waiting %v for cleanup", + mode, aliveBeforeDelayExpiry, streamCount, closeDelay) + time.Sleep(closeDelay + 200*time.Millisecond) + aliveAfterReloads = countAliveStreams(clients) + default: + deadline := time.Now().Add(2 * time.Second) + for aliveAfterReloads > 0 && time.Now().Before(deadline) { + time.Sleep(50 * time.Millisecond) + aliveAfterReloads = countAliveStreams(clients) + } + } + + after := captureHeapSnapshot(t) + t.Logf("%s mode heap profile size: before=%dB mid=%dB after=%dB objects(before=%d mid=%d after=%d)", + mode, + before.profileBytes, mid.profileBytes, after.profileBytes, + before.HeapObjects, mid.HeapObjects, after.HeapObjects, + ) + + return stressRunResult{ + streamCount: streamCount, + aliveAfterReloads: aliveAfterReloads, + aliveBeforeDelayExpiry: aliveBeforeDelayExpiry, + beforeReload: before, + midReload: mid, + afterReload: after, + } +} + +func envIntOrDefault(t *testing.T, key string, def int) int { + t.Helper() + raw := strings.TrimSpace(os.Getenv(key)) + if raw == "" { + return def + } + v, err := strconv.Atoi(raw) + if err != nil || v <= 0 { + t.Fatalf("invalid %s=%q: must be a positive integer", key, raw) + } + return v +} + +func loadCaddyfileConfig(t *testing.T, rawConfig string) { + t.Helper() + + client := &http.Client{Timeout: 30 * time.Second} + req, err := http.NewRequest(http.MethodPost, "http://localhost:2999/load", strings.NewReader(rawConfig)) + if err != nil { + t.Fatalf("creating load request: %v", err) + } + req.Header.Set("Content-Type", "text/caddyfile") + + resp, err := client.Do(req) + if err != nil { + t.Fatalf("loading config: %v", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("reading load response: %v", err) + } + if resp.StatusCode != http.StatusOK { + t.Fatalf("loading config failed: status=%d body=%s", resp.StatusCode, body) + } +} + +func reloadStressConfig(backendAddr string, retain bool, closeDelay time.Duration, revision int) string { + var directives string + if retain { + directives += "\n\t\tstream_retain_on_reload" + } + if closeDelay > 0 { + directives += fmt.Sprintf("\n\t\tstream_close_delay %s", closeDelay) + } + + return fmt.Sprintf(` +{ + admin localhost:2999 + http_port 9080 + https_port 9443 + grace_period 1ns + skip_install_trust +} + +localhost:9080 { + reverse_proxy %s { + header_up X-Reload-Revision %d%s + } +} +`, backendAddr, revision, directives) +} + +func captureHeapSnapshot(t *testing.T) heapSnapshot { + t.Helper() + + runtime.GC() + debug.FreeOSMemory() + + var mem runtime.MemStats + runtime.ReadMemStats(&mem) + + var buf bytes.Buffer + if err := pprof.Lookup("heap").WriteTo(&buf, 1); err != nil { + t.Fatalf("capturing heap profile: %v", err) + } + profile := buf.String() + + return heapSnapshot{ + HeapInuse: mem.HeapInuse, + HeapObjects: mem.HeapObjects, + handlerFrames: strings.Count(profile, "modules/caddyhttp/reverseproxy.(*Handler)"), + profileBytes: buf.Len(), + } +} + +func countAliveStreams(clients []*upgradedStreamClient) int { + alive := 0 + for index, client := range clients { + if err := client.echo(fmt.Sprintf("alive-check-%02d\n", index)); err == nil { + alive++ + } + } + return alive +} + +func closeClients(clients []*upgradedStreamClient) { + for _, client := range clients { + if client != nil { + _ = client.Close() + } + } +} + +func formatBytes(value uint64) string { + const unit = 1024 + if value < unit { + return fmt.Sprintf("%d B", value) + } + div, exp := uint64(unit), 0 + for n := value / unit; n >= unit; n /= unit { + div *= unit + exp++ + } + return fmt.Sprintf("%.1f %ciB", float64(value)/float64(div), "KMGTPE"[exp]) +} + +func formatBytesDiff(before, after uint64) string { + if after >= before { + return "+" + formatBytes(after-before) + } + return "-" + formatBytes(before-after) +} + +type upgradedStreamClient struct { + conn net.Conn + reader *bufio.Reader + mu sync.Mutex +} + +func newUpgradedStreamClient(t *testing.T) *upgradedStreamClient { + t.Helper() + + conn, err := net.DialTimeout("tcp", "127.0.0.1:9080", 5*time.Second) + if err != nil { + t.Fatalf("dialing caddy: %v", err) + } + + request := strings.Join([]string{ + "GET /upgrade HTTP/1.1", + "Host: localhost:9080", + "Connection: Upgrade", + "Upgrade: stress-stream", + "", + "", + }, "\r\n") + if _, err := io.WriteString(conn, request); err != nil { + _ = conn.Close() + t.Fatalf("writing upgrade request: %v", err) + } + + reader := bufio.NewReader(conn) + tproto := textproto.NewReader(reader) + statusLine, err := tproto.ReadLine() + if err != nil { + _ = conn.Close() + t.Fatalf("reading upgrade status line: %v", err) + } + if !strings.Contains(statusLine, "101") { + _ = conn.Close() + t.Fatalf("unexpected upgrade status: %s", statusLine) + } + + headers, err := tproto.ReadMIMEHeader() + if err != nil { + _ = conn.Close() + t.Fatalf("reading upgrade headers: %v", err) + } + if !strings.EqualFold(headers.Get("Connection"), "Upgrade") { + _ = conn.Close() + t.Fatalf("unexpected upgrade response headers: %v", headers) + } + + return &upgradedStreamClient{conn: conn, reader: reader} +} + +func (c *upgradedStreamClient) echo(payload string) error { + c.mu.Lock() + defer c.mu.Unlock() + + deadline := time.Now().Add(1 * time.Second) + if err := c.conn.SetWriteDeadline(deadline); err != nil { + return err + } + if _, err := io.WriteString(c.conn, payload); err != nil { + return err + } + if err := c.conn.SetReadDeadline(deadline); err != nil { + return err + } + + buf := make([]byte, len(payload)) + if _, err := io.ReadFull(c.reader, buf); err != nil { + return err + } + if string(buf) != payload { + return fmt.Errorf("unexpected echoed payload: got %q want %q", string(buf), payload) + } + return nil +} + +func (c *upgradedStreamClient) Close() error { + c.mu.Lock() + defer c.mu.Unlock() + return c.conn.Close() +} + +type upgradeEchoBackend struct { + addr string + ln net.Listener + mu sync.Mutex + conns map[net.Conn]struct{} + server *http.Server +} + +func newUpgradeEchoBackend(t *testing.T) *upgradeEchoBackend { + t.Helper() + + backend := &upgradeEchoBackend{conns: make(map[net.Conn]struct{})} + backend.server = &http.Server{ + Handler: http.HandlerFunc(backend.serveHTTP), + } + + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listening for backend: %v", err) + } + backend.ln = ln + backend.addr = ln.Addr().String() + + go func() { + _ = backend.server.Serve(ln) + }() + + return backend +} + +func (b *upgradeEchoBackend) serveHTTP(w http.ResponseWriter, r *http.Request) { + if !strings.EqualFold(r.Header.Get("Connection"), "Upgrade") || !strings.EqualFold(r.Header.Get("Upgrade"), "stress-stream") { + http.Error(w, "upgrade required", http.StatusUpgradeRequired) + return + } + + hijacker, ok := w.(http.Hijacker) + if !ok { + http.Error(w, "hijacking not supported", http.StatusInternalServerError) + return + } + + conn, rw, err := hijacker.Hijack() + if err != nil { + return + } + + b.trackConn(conn) + _, _ = rw.WriteString("HTTP/1.1 101 Switching Protocols\r\nConnection: Upgrade\r\nUpgrade: stress-stream\r\n\r\n") + _ = rw.Flush() + + go func() { + defer b.untrackConn(conn) + defer conn.Close() + _, _ = io.Copy(conn, conn) + }() +} + +func (b *upgradeEchoBackend) trackConn(conn net.Conn) { + b.mu.Lock() + b.conns[conn] = struct{}{} + b.mu.Unlock() +} + +func (b *upgradeEchoBackend) untrackConn(conn net.Conn) { + b.mu.Lock() + delete(b.conns, conn) + b.mu.Unlock() +} + +func (b *upgradeEchoBackend) Close() { + _ = b.server.Close() + _ = b.ln.Close() + + b.mu.Lock() + defer b.mu.Unlock() + for conn := range b.conns { + _ = conn.Close() + } + clear(b.conns) +} diff --git a/modules/caddyhttp/encode/encode.go b/modules/caddyhttp/encode/encode.go index ac995c37b32..ecf85495a39 100644 --- a/modules/caddyhttp/encode/encode.go +++ b/modules/caddyhttp/encode/encode.go @@ -405,6 +405,11 @@ func (rw *responseWriter) ReadFrom(r io.Reader) (int64, error) { // Close writes any remaining buffered response and // deallocates any active resources. func (rw *responseWriter) Close() error { + if caddyhttp.ResponseWriterHijacked(rw.ResponseWriter) { + rw.releaseEncoder() + return nil + } + // didn't write, probably head request if !rw.wroteHeader { cl, err := strconv.Atoi(rw.Header().Get("Content-Length")) @@ -422,13 +427,20 @@ func (rw *responseWriter) Close() error { var err error if rw.w != nil { err = rw.w.Close() - rw.w.Reset(nil) - rw.config.writerPools[rw.encodingName].Put(rw.w) - rw.w = nil + rw.releaseEncoder() } return err } +func (rw *responseWriter) releaseEncoder() { + if rw.w == nil { + return + } + rw.w.Reset(nil) + rw.config.writerPools[rw.encodingName].Put(rw.w) + rw.w = nil +} + // Unwrap returns the underlying ResponseWriter. func (rw *responseWriter) Unwrap() http.ResponseWriter { return rw.ResponseWriter diff --git a/modules/caddyhttp/responsewriter.go b/modules/caddyhttp/responsewriter.go index 904c30c0352..d5b43bf42de 100644 --- a/modules/caddyhttp/responsewriter.go +++ b/modules/caddyhttp/responsewriter.go @@ -70,6 +70,7 @@ type responseRecorder struct { size int wroteHeader bool stream bool + hijacked bool readSize *int } @@ -144,7 +145,8 @@ func NewResponseRecorder(w http.ResponseWriter, buf *bytes.Buffer, shouldBuffer // WriteHeader writes the headers with statusCode to the wrapped // ResponseWriter unless the response is to be buffered instead. -// 1xx responses are never buffered. +// 1xx responses are never buffered, except 101 which is treated +// as a final upgrade response. func (rr *responseRecorder) WriteHeader(statusCode int) { if rr.wroteHeader { return @@ -153,6 +155,12 @@ func (rr *responseRecorder) WriteHeader(statusCode int) { // save statusCode always, in case HTTP middleware upgrades websocket // connections by manually setting headers and writing status 101 rr.statusCode = statusCode + if statusCode == http.StatusSwitchingProtocols { + rr.stream = true + rr.wroteHeader = true + rr.ResponseWriterWrapper.WriteHeader(statusCode) + return + } // decide whether we should buffer the response if rr.shouldBuffer == nil { @@ -222,7 +230,14 @@ func (rr *responseRecorder) Buffered() bool { return !rr.stream } +func (rr *responseRecorder) Hijacked() bool { + return rr.hijacked +} + func (rr *responseRecorder) WriteResponse() error { + if rr.hijacked { + return nil + } if rr.statusCode == 0 { // could happen if no handlers actually wrote anything, // and this prevents a panic; status must be > 0 @@ -258,13 +273,16 @@ func (rr *responseRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) { if err != nil { return nil, nil, err } - // Per http documentation, returned bufio.Writer is empty, but bufio.Read maybe not - conn = &hijackedConn{conn, rr} + rr.hijacked = true + rr.stream = true + rr.wroteHeader = true + // Per http documentation, returned bufio.Writer is empty, but bufio.Read maybe not. + // Return the raw hijacked connection so upgraded stream traffic does not keep + // traversing the response recorder hot path. brw.Writer.Reset(conn) buffered := brw.Reader.Buffered() if buffered != 0 { - conn.(*hijackedConn).updateReadSize(buffered) data, _ := brw.Peek(buffered) brw.Reader.Reset(io.MultiReader(bytes.NewReader(data), conn)) // peek to make buffered data appear, as Reset will make it 0 @@ -275,40 +293,24 @@ func (rr *responseRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) { return conn, brw, nil } -// used to track the size of hijacked response writers -type hijackedConn struct { - net.Conn - rr *responseRecorder -} - -func (hc *hijackedConn) updateReadSize(n int) { - if hc.rr.readSize != nil { - *hc.rr.readSize += n +// ResponseWriterHijacked reports whether w or one of its wrapped response +// writers has been hijacked. +func ResponseWriterHijacked(w http.ResponseWriter) bool { + for w != nil { + if hijacked, ok := w.(interface{ Hijacked() bool }); ok && hijacked.Hijacked() { + return true + } + unwrapper, ok := w.(interface{ Unwrap() http.ResponseWriter }) + if !ok { + return false + } + next := unwrapper.Unwrap() + if next == w { + return false + } + w = next } -} - -func (hc *hijackedConn) Read(p []byte) (int, error) { - n, err := hc.Conn.Read(p) - hc.updateReadSize(n) - return n, err -} - -func (hc *hijackedConn) WriteTo(w io.Writer) (int64, error) { - n, err := io.Copy(w, hc.Conn) - hc.updateReadSize(int(n)) - return n, err -} - -func (hc *hijackedConn) Write(p []byte) (int, error) { - n, err := hc.Conn.Write(p) - hc.rr.size += n - return n, err -} - -func (hc *hijackedConn) ReadFrom(r io.Reader) (int64, error) { - n, err := io.Copy(hc.Conn, r) - hc.rr.size += int(n) - return n, err + return false } // ResponseRecorder is a http.ResponseWriter that records @@ -319,6 +321,7 @@ type ResponseRecorder interface { Status() int Buffer() *bytes.Buffer Buffered() bool + Hijacked() bool Size() int WriteResponse() error } @@ -338,7 +341,4 @@ var ( // see PR #5022 (25%-50% speedup) _ io.ReaderFrom = (*ResponseWriterWrapper)(nil) _ io.ReaderFrom = (*responseRecorder)(nil) - _ io.ReaderFrom = (*hijackedConn)(nil) - - _ io.WriterTo = (*hijackedConn)(nil) ) diff --git a/modules/caddyhttp/responsewriter_test.go b/modules/caddyhttp/responsewriter_test.go index c08ad26a472..72e416db1e4 100644 --- a/modules/caddyhttp/responsewriter_test.go +++ b/modules/caddyhttp/responsewriter_test.go @@ -1,11 +1,14 @@ package caddyhttp import ( + "bufio" "bytes" "io" + "net" "net/http" "strings" "testing" + "time" ) type responseWriterSpy interface { @@ -44,6 +47,50 @@ func (rf *readFromRespWriter) ReadFrom(r io.Reader) (int64, error) { func (rf *readFromRespWriter) CalledReadFrom() bool { return rf.called } +type hijackRespWriter struct { + baseRespWriter + header http.Header + status int + conn net.Conn +} + +func newHijackRespWriter() *hijackRespWriter { + return &hijackRespWriter{ + header: make(http.Header), + conn: stubConn{}, + } +} + +func (hrw *hijackRespWriter) Header() http.Header { + return hrw.header +} + +func (hrw *hijackRespWriter) WriteHeader(statusCode int) { + hrw.status = statusCode +} + +func (hrw *hijackRespWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + br := bufio.NewReader(hrw.conn) + bw := bufio.NewWriter(hrw.conn) + return hrw.conn, bufio.NewReadWriter(br, bw), nil +} + +type stubConn struct{} + +func (stubConn) Read(_ []byte) (int, error) { return 0, io.EOF } +func (stubConn) Write(p []byte) (int, error) { return len(p), nil } +func (stubConn) Close() error { return nil } +func (stubConn) LocalAddr() net.Addr { return stubAddr("local") } +func (stubConn) RemoteAddr() net.Addr { return stubAddr("remote") } +func (stubConn) SetDeadline(time.Time) error { return nil } +func (stubConn) SetReadDeadline(time.Time) error { return nil } +func (stubConn) SetWriteDeadline(time.Time) error { return nil } + +type stubAddr string + +func (a stubAddr) Network() string { return "tcp" } +func (a stubAddr) String() string { return string(a) } + func TestResponseWriterWrapperReadFrom(t *testing.T) { tests := map[string]struct { responseWriter responseWriterSpy @@ -169,3 +216,49 @@ func TestResponseRecorderReadFrom(t *testing.T) { }) } } + +func TestResponseRecorderSwitchingProtocolsIsHijackAware(t *testing.T) { + w := newHijackRespWriter() + var buf bytes.Buffer + + rr := NewResponseRecorder(w, &buf, func(status int, header http.Header) bool { + return true + }) + rr.WriteHeader(http.StatusSwitchingProtocols) + + if rr.Buffered() { + t.Fatal("101 switching protocols response should not remain buffered") + } + if rr.Status() != http.StatusSwitchingProtocols { + t.Fatalf("status = %d, want %d", rr.Status(), http.StatusSwitchingProtocols) + } + if w.status != http.StatusSwitchingProtocols { + t.Fatalf("underlying status = %d, want %d", w.status, http.StatusSwitchingProtocols) + } + + hj, ok := rr.(http.Hijacker) + if !ok { + t.Fatal("response recorder does not implement http.Hijacker") + } + conn, _, err := hj.Hijack() + if err != nil { + t.Fatalf("Hijack() error = %v", err) + } + defer conn.Close() + + if !rr.Hijacked() { + t.Fatal("response recorder should report hijacked state") + } + if !ResponseWriterHijacked(rr) { + t.Fatal("ResponseWriterHijacked() should report true after hijack") + } + if err := rr.WriteResponse(); err != nil { + t.Fatalf("WriteResponse() after hijack returned error: %v", err) + } + if rr.Size() != 0 { + t.Fatalf("size = %d, want 0 after hijack handshake", rr.Size()) + } + if got := w.Written(); got != "" { + t.Fatalf("unexpected buffered body write after hijack: %q", got) + } +} diff --git a/modules/caddyhttp/reverseproxy/caddyfile.go b/modules/caddyhttp/reverseproxy/caddyfile.go index 8716babe336..07277b4f133 100644 --- a/modules/caddyhttp/reverseproxy/caddyfile.go +++ b/modules/caddyhttp/reverseproxy/caddyfile.go @@ -99,6 +99,8 @@ func parseCaddyfile(h httpcaddyfile.Helper) (caddyhttp.MiddlewareHandler, error) // stream_buffer_size // stream_timeout // stream_close_delay +// stream_retain_on_reload +// stream_log_skip_handshake // verbose_logs // // # request manipulation @@ -703,6 +705,18 @@ func (h *Handler) UnmarshalCaddyfile(d *caddyfile.Dispenser) error { h.StreamCloseDelay = caddy.Duration(dur) } + case "stream_retain_on_reload": + if d.NextArg() { + return d.ArgErr() + } + h.StreamRetainOnReload = true + + case "stream_log_skip_handshake": + if d.NextArg() { + return d.ArgErr() + } + h.StreamLogSkipHandshake = true + case "trusted_proxies": for d.NextArg() { if d.Val() == "private_ranges" { diff --git a/modules/caddyhttp/reverseproxy/copyresponse.go b/modules/caddyhttp/reverseproxy/copyresponse.go index c1c9de92ba8..ec1720d31b4 100644 --- a/modules/caddyhttp/reverseproxy/copyresponse.go +++ b/modules/caddyhttp/reverseproxy/copyresponse.go @@ -80,7 +80,7 @@ func (h CopyResponseHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request hrc.isFinalized = true // write the response - return hrc.handler.finalizeResponse(rw, req, hrc.response, repl, hrc.start, hrc.logger) + return hrc.handler.finalizeResponse(rw, req, hrc.response, repl, hrc.start, hrc.logger, hrc.upstreamAddr) } // CopyResponseHeadersHandler is a special HTTP handler which may diff --git a/modules/caddyhttp/reverseproxy/metrics.go b/modules/caddyhttp/reverseproxy/metrics.go index 2488427304e..4b26d86419c 100644 --- a/modules/caddyhttp/reverseproxy/metrics.go +++ b/modules/caddyhttp/reverseproxy/metrics.go @@ -16,6 +16,10 @@ import ( var reverseProxyMetrics = struct { once sync.Once upstreamsHealthy *prometheus.GaugeVec + streamsActive *prometheus.GaugeVec + streamsTotal *prometheus.CounterVec + streamDuration *prometheus.HistogramVec + streamBytes *prometheus.CounterVec logger *zap.Logger }{} @@ -23,6 +27,8 @@ func initReverseProxyMetrics(handler *Handler, registry *prometheus.Registry) { const ns, sub = "caddy", "reverse_proxy" upstreamsLabels := []string{"upstream"} + streamResultLabels := []string{"upstream", "result"} + streamBytesLabels := []string{"upstream", "direction"} reverseProxyMetrics.once.Do(func() { reverseProxyMetrics.upstreamsHealthy = prometheus.NewGaugeVec(prometheus.GaugeOpts{ Namespace: ns, @@ -30,6 +36,31 @@ func initReverseProxyMetrics(handler *Handler, registry *prometheus.Registry) { Name: "upstreams_healthy", Help: "Health status of reverse proxy upstreams.", }, upstreamsLabels) + reverseProxyMetrics.streamsActive = prometheus.NewGaugeVec(prometheus.GaugeOpts{ + Namespace: ns, + Subsystem: sub, + Name: "streams_active", + Help: "Number of currently active upgraded reverse proxy streams.", + }, upstreamsLabels) + reverseProxyMetrics.streamsTotal = prometheus.NewCounterVec(prometheus.CounterOpts{ + Namespace: ns, + Subsystem: sub, + Name: "streams_total", + Help: "Total number of upgraded reverse proxy streams by close result.", + }, streamResultLabels) + reverseProxyMetrics.streamDuration = prometheus.NewHistogramVec(prometheus.HistogramOpts{ + Namespace: ns, + Subsystem: sub, + Name: "stream_duration_seconds", + Help: "Duration of upgraded reverse proxy streams by close result.", + Buckets: prometheus.DefBuckets, + }, streamResultLabels) + reverseProxyMetrics.streamBytes = prometheus.NewCounterVec(prometheus.CounterOpts{ + Namespace: ns, + Subsystem: sub, + Name: "stream_bytes_total", + Help: "Total bytes proxied across upgraded reverse proxy streams.", + }, streamBytesLabels) }) // duplicate registration could happen if multiple sites with reverse proxy are configured; so ignore the error because @@ -42,10 +73,58 @@ func initReverseProxyMetrics(handler *Handler, registry *prometheus.Registry) { }) { panic(err) } + if err := registry.Register(reverseProxyMetrics.streamsActive); err != nil && + !errors.Is(err, prometheus.AlreadyRegisteredError{ + ExistingCollector: reverseProxyMetrics.streamsActive, + NewCollector: reverseProxyMetrics.streamsActive, + }) { + panic(err) + } + if err := registry.Register(reverseProxyMetrics.streamsTotal); err != nil && + !errors.Is(err, prometheus.AlreadyRegisteredError{ + ExistingCollector: reverseProxyMetrics.streamsTotal, + NewCollector: reverseProxyMetrics.streamsTotal, + }) { + panic(err) + } + if err := registry.Register(reverseProxyMetrics.streamDuration); err != nil && + !errors.Is(err, prometheus.AlreadyRegisteredError{ + ExistingCollector: reverseProxyMetrics.streamDuration, + NewCollector: reverseProxyMetrics.streamDuration, + }) { + panic(err) + } + if err := registry.Register(reverseProxyMetrics.streamBytes); err != nil && + !errors.Is(err, prometheus.AlreadyRegisteredError{ + ExistingCollector: reverseProxyMetrics.streamBytes, + NewCollector: reverseProxyMetrics.streamBytes, + }) { + panic(err) + } reverseProxyMetrics.logger = handler.logger.Named("reverse_proxy.metrics") } +func trackActiveStream(upstream string) func(result string, duration time.Duration, toBackend, fromBackend int64) { + labels := prometheus.Labels{"upstream": upstream} + reverseProxyMetrics.streamsActive.With(labels).Inc() + + var once sync.Once + return func(result string, duration time.Duration, toBackend, fromBackend int64) { + once.Do(func() { + reverseProxyMetrics.streamsActive.With(labels).Dec() + reverseProxyMetrics.streamsTotal.WithLabelValues(upstream, result).Inc() + reverseProxyMetrics.streamDuration.WithLabelValues(upstream, result).Observe(duration.Seconds()) + if toBackend > 0 { + reverseProxyMetrics.streamBytes.WithLabelValues(upstream, "to_upstream").Add(float64(toBackend)) + } + if fromBackend > 0 { + reverseProxyMetrics.streamBytes.WithLabelValues(upstream, "from_upstream").Add(float64(fromBackend)) + } + }) + } +} + type metricsUpstreamsHealthyUpdater struct { handler *Handler } diff --git a/modules/caddyhttp/reverseproxy/metrics_test.go b/modules/caddyhttp/reverseproxy/metrics_test.go new file mode 100644 index 00000000000..edbe9ca8d76 --- /dev/null +++ b/modules/caddyhttp/reverseproxy/metrics_test.go @@ -0,0 +1,67 @@ +package reverseproxy + +import ( + "testing" + "time" + + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/testutil" +) + +func TestTrackActiveStreamRecordsLifecycleAndBytes(t *testing.T) { + const upstream = "127.0.0.1:7443" + + // Use fresh metric vectors for deterministic assertions in this unit test. + reverseProxyMetrics.streamsActive = prometheus.NewGaugeVec(prometheus.GaugeOpts{}, []string{"upstream"}) + reverseProxyMetrics.streamsTotal = prometheus.NewCounterVec(prometheus.CounterOpts{}, []string{"upstream", "result"}) + reverseProxyMetrics.streamDuration = prometheus.NewHistogramVec(prometheus.HistogramOpts{}, []string{"upstream", "result"}) + reverseProxyMetrics.streamBytes = prometheus.NewCounterVec(prometheus.CounterOpts{}, []string{"upstream", "direction"}) + + finish := trackActiveStream(upstream) + + if got := testutil.ToFloat64(reverseProxyMetrics.streamsActive.WithLabelValues(upstream)); got != 1 { + t.Fatalf("active streams = %v, want 1", got) + } + + finish("closed", 150*time.Millisecond, 1234, 4321) + + if got := testutil.ToFloat64(reverseProxyMetrics.streamsActive.WithLabelValues(upstream)); got != 0 { + t.Fatalf("active streams = %v, want 0", got) + } + if got := testutil.ToFloat64(reverseProxyMetrics.streamsTotal.WithLabelValues(upstream, "closed")); got != 1 { + t.Fatalf("streams_total closed = %v, want 1", got) + } + if got := testutil.ToFloat64(reverseProxyMetrics.streamBytes.WithLabelValues(upstream, "to_upstream")); got != 1234 { + t.Fatalf("bytes to_upstream = %v, want 1234", got) + } + if got := testutil.ToFloat64(reverseProxyMetrics.streamBytes.WithLabelValues(upstream, "from_upstream")); got != 4321 { + t.Fatalf("bytes from_upstream = %v, want 4321", got) + } + + // A second finish call should be ignored by the once guard. + finish("error", 1*time.Second, 111, 222) + if got := testutil.ToFloat64(reverseProxyMetrics.streamsTotal.WithLabelValues(upstream, "error")); got != 0 { + t.Fatalf("streams_total error = %v, want 0", got) + } +} + +func TestTrackActiveStreamDoesNotCountZeroBytes(t *testing.T) { + const upstream = "127.0.0.1:9000" + + reverseProxyMetrics.streamsActive = prometheus.NewGaugeVec(prometheus.GaugeOpts{}, []string{"upstream"}) + reverseProxyMetrics.streamsTotal = prometheus.NewCounterVec(prometheus.CounterOpts{}, []string{"upstream", "result"}) + reverseProxyMetrics.streamDuration = prometheus.NewHistogramVec(prometheus.HistogramOpts{}, []string{"upstream", "result"}) + reverseProxyMetrics.streamBytes = prometheus.NewCounterVec(prometheus.CounterOpts{}, []string{"upstream", "direction"}) + + trackActiveStream(upstream)("timeout", 250*time.Millisecond, 0, 0) + + if got := testutil.ToFloat64(reverseProxyMetrics.streamBytes.WithLabelValues(upstream, "to_upstream")); got != 0 { + t.Fatalf("bytes to_upstream = %v, want 0", got) + } + if got := testutil.ToFloat64(reverseProxyMetrics.streamBytes.WithLabelValues(upstream, "from_upstream")); got != 0 { + t.Fatalf("bytes from_upstream = %v, want 0", got) + } + if got := testutil.ToFloat64(reverseProxyMetrics.streamsTotal.WithLabelValues(upstream, "timeout")); got != 1 { + t.Fatalf("streams_total timeout = %v, want 1", got) + } +} diff --git a/modules/caddyhttp/reverseproxy/reverseproxy.go b/modules/caddyhttp/reverseproxy/reverseproxy.go index 52d2b1ab30f..adb47a9b8da 100644 --- a/modules/caddyhttp/reverseproxy/reverseproxy.go +++ b/modules/caddyhttp/reverseproxy/reverseproxy.go @@ -186,6 +186,18 @@ type Handler struct { // by the previous config closing. Default: no delay. StreamCloseDelay caddy.Duration `json:"stream_close_delay,omitempty"` + // If true, upgraded connections such as WebSockets are retained across + // config reloads when their upstream still exists in the new config. + // Connections using upstreams that are removed are closed during cleanup. + // By default this is false, preserving legacy behavior where upgraded + // connections are closed on reload (optionally delayed by stream_close_delay). + StreamRetainOnReload bool `json:"stream_retain_on_reload,omitempty"` + + // If true, suppresses the access log entry normally emitted when an + // upgraded stream handshake completes and the request unwinds. By default + // the handshake is still logged as a normal request with status 101. + StreamLogSkipHandshake bool `json:"stream_log_skip_handshake,omitempty"` + // If configured, rewrites the copy of the upstream request. // Allows changing the request method and URI (path and query). // Since the rewrite is applied to the copy, it does not persist @@ -240,10 +252,9 @@ type Handler struct { // Holds the handle_response Caddyfile tokens while adapting handleResponseSegments []*caddyfile.Dispenser - // Stores upgraded requests (hijacked connections) for proper cleanup - connections map[io.ReadWriteCloser]openConnection - connectionsCloseTimer *time.Timer - connectionsMu *sync.Mutex + // Tracks hijacked/upgraded connections (WebSocket etc.) so they can be + // closed when their upstream is removed from the config. + tunnel *tunnelState ctx caddy.Context logger *zap.Logger @@ -267,8 +278,7 @@ func (h *Handler) Provision(ctx caddy.Context) error { h.events = eventAppIface.(*caddyevents.App) h.ctx = ctx h.logger = ctx.Logger() - h.connections = make(map[io.ReadWriteCloser]openConnection) - h.connectionsMu = new(sync.Mutex) + h.tunnel = newTunnelState(h.logger, time.Duration(h.StreamCloseDelay)) // warn about unsafe buffering config if h.RequestBuffers == -1 || h.ResponseBuffers == -1 { @@ -439,13 +449,29 @@ func (h *Handler) Provision(ctx caddy.Context) error { // Cleanup cleans up the resources made by h. func (h *Handler) Cleanup() error { - err := h.cleanupConnections() + if !h.StreamRetainOnReload { + // Legacy behavior: close all upgraded connections on reload, either + // immediately or after StreamCloseDelay. + err := h.tunnel.cleanupConnections() + for _, upstream := range h.Upstreams { + _, _ = hosts.Delete(upstream.String()) + } + return err + } - // remove hosts from our config from the pool + var err error for _, upstream := range h.Upstreams { - _, _ = hosts.Delete(upstream.String()) + // hosts.Delete returns deleted=true when the ref count reaches zero, + // meaning no other active config references this upstream. In that + // case close any tunnels proxying to it; otherwise let them survive + // to their natural end since the upstream is still in use. + deleted, _ := hosts.Delete(upstream.String()) + if deleted { + if closeErr := h.tunnel.closeConnectionsForUpstream(upstream.String()); closeErr != nil && err == nil { + err = closeErr + } + } } - return err } @@ -1127,10 +1153,11 @@ func (h *Handler) reverseProxy(rw http.ResponseWriter, req *http.Request, origRe // we use the original request here, so that any routes from 'next' // see the original request rather than the proxy cloned request. hrc := &handleResponseContext{ - handler: h, - response: res, - start: start, - logger: logger, + handler: h, + response: res, + start: start, + logger: logger, + upstreamAddr: di.Upstream.String(), } ctx := origReq.Context() ctx = context.WithValue(ctx, proxyHandleResponseContextCtxKey, hrc) @@ -1160,7 +1187,7 @@ func (h *Handler) reverseProxy(rw http.ResponseWriter, req *http.Request, origRe } // copy the response body and headers back to the upstream client - return h.finalizeResponse(rw, req, res, repl, start, logger) + return h.finalizeResponse(rw, req, res, repl, start, logger, di.Upstream.String()) } // finalizeResponse prepares and copies the response. @@ -1171,11 +1198,12 @@ func (h *Handler) finalizeResponse( repl *caddy.Replacer, start time.Time, logger *zap.Logger, + upstreamAddr string, ) error { // deal with 101 Switching Protocols responses: (WebSocket, h2c, etc) if res.StatusCode == http.StatusSwitchingProtocols { var wg sync.WaitGroup - h.handleUpgradeResponse(logger, &wg, rw, req, res) + h.handleUpgradeResponse(logger, &wg, rw, req, res, upstreamAddr) wg.Wait() return nil } @@ -1797,6 +1825,9 @@ type handleResponseContext struct { // i.e. copied and closed, to make sure that it doesn't // happen twice. isFinalized bool + + // upstreamAddr is the selected upstream address for this request. + upstreamAddr string } // proxyHandleResponseContextCtxKey is the context key for the active proxy handler diff --git a/modules/caddyhttp/reverseproxy/streaming.go b/modules/caddyhttp/reverseproxy/streaming.go index e454ee65547..37fca60140f 100644 --- a/modules/caddyhttp/reverseproxy/streaming.go +++ b/modules/caddyhttp/reverseproxy/streaming.go @@ -26,6 +26,7 @@ import ( "io" weakrand "math/rand/v2" "mime" + "net" "net/http" "sync" "time" @@ -35,6 +36,7 @@ import ( "go.uber.org/zap/zapcore" "golang.org/x/net/http/httpguts" + "github.com/caddyserver/caddy/v2" "github.com/caddyserver/caddy/v2/modules/caddyhttp" ) @@ -57,7 +59,7 @@ func (rwc h2ReadWriteCloser) Write(p []byte) (n int, err error) { return n, nil } -func (h *Handler) handleUpgradeResponse(logger *zap.Logger, wg *sync.WaitGroup, rw http.ResponseWriter, req *http.Request, res *http.Response) { +func (h *Handler) handleUpgradeResponse(logger *zap.Logger, wg *sync.WaitGroup, rw http.ResponseWriter, req *http.Request, res *http.Response, upstreamAddr string) { reqUpType := upgradeType(req.Header) resUpType := upgradeType(res.Header) @@ -90,13 +92,22 @@ func (h *Handler) handleUpgradeResponse(logger *zap.Logger, wg *sync.WaitGroup, copyHeader(rw.Header(), res.Header) normalizeWebsocketHeaders(rw.Header()) + // Capture all h fields needed by the tunnel now, so that the Handler (h) + // is not referenced after this function returns (for HTTP/1.1 hijacked + // connections the tunnel runs in a detached goroutine). + tunnel := h.tunnel + bufferSize := h.StreamBufferSize + streamTimeout := time.Duration(h.StreamTimeout) + var ( conn io.ReadWriteCloser brw *bufio.ReadWriter + isH2 bool ) // websocket over http2 or http3 if extended connect is enabled, assuming backend doesn't support this, the request will be modified to http1.1 upgrade // TODO: once we can reliably detect backend support this, it can be removed for those backends if body, ok := caddyhttp.GetVar(req.Context(), "extended_connect_websocket_body").(io.ReadCloser); ok { + isH2 = true req.Body = body rw.Header().Del("Upgrade") rw.Header().Del("Connection") @@ -143,26 +154,24 @@ func (h *Handler) handleUpgradeResponse(logger *zap.Logger, wg *sync.WaitGroup, } } - // adopted from https://github.com/golang/go/commit/8bcf2834afdf6a1f7937390903a41518715ef6f5 - backConnCloseCh := make(chan struct{}) - go func() { - // Ensure that the cancellation of a request closes the backend. - // See issue https://golang.org/issue/35559. - select { - case <-req.Context().Done(): - case <-backConnCloseCh: - } - backConn.Close() - }() - defer close(backConnCloseCh) - - start := time.Now() - defer func() { - conn.Close() - if c := logger.Check(zapcore.DebugLevel, "connection closed"); c != nil { - c.Write(zap.Duration("duration", time.Since(start))) - } - }() + // For H2 extended connect: close backConn when the request context is + // cancelled (e.g. client disconnects). For HTTP/1.1 hijacked connections + // we skip this because req.Context() may be cancelled when ServeHTTP + // returns early, which would prematurely close the backend connection. + if isH2 { + // adopted from https://github.com/golang/go/commit/8bcf2834afdf6a1f7937390903a41518715ef6f5 + backConnCloseCh := make(chan struct{}) + go func() { + // Ensure that the cancellation of a request closes the backend. + // See issue https://golang.org/issue/35559. + select { + case <-req.Context().Done(): + case <-backConnCloseCh: + } + backConn.Close() + }() + defer close(backConnCloseCh) + } if err := brw.Flush(); err != nil { if c := logger.Check(zapcore.DebugLevel, "response flush"); c != nil { @@ -184,13 +193,11 @@ func (h *Handler) handleUpgradeResponse(logger *zap.Logger, wg *sync.WaitGroup, } } - // Ensure the hijacked client connection, and the new connection established - // with the backend, are both closed in the event of a server shutdown. This - // is done by registering them. We also try to gracefully close connections - // we recognize as websockets. - // We need to make sure the client connection messages (i.e. to upstream) - // are masked, so we need to know whether the connection is considered the - // server or the client side of the proxy. + // Register both connections with the tunnel tracker. We also try to + // gracefully close connections we recognize as websockets. We need to make + // sure the client connection messages (i.e. to upstream) are masked, so we + // need to know whether the connection is considered the server or the + // client side of the proxy. gracefulClose := func(conn io.ReadWriteCloser, isClient bool) func() error { if isWebsocket(req) { return func() error { @@ -199,43 +206,186 @@ func (h *Handler) handleUpgradeResponse(logger *zap.Logger, wg *sync.WaitGroup, } return nil } - deleteFrontConn := h.registerConnection(conn, gracefulClose(conn, false)) - deleteBackConn := h.registerConnection(backConn, gracefulClose(backConn, true)) - defer deleteFrontConn() + deleteFrontConn := tunnel.registerConnection(conn, gracefulClose(conn, false), upstreamAddr) + deleteBackConn := tunnel.registerConnection(backConn, gracefulClose(backConn, true), upstreamAddr) + if h.StreamLogSkipHandshake { + caddyhttp.SetVar(req.Context(), caddyhttp.LogSkipVar, true) + } + repl := req.Context().Value(caddy.ReplacerCtxKey).(*caddy.Replacer) + repl.Set("http.reverse_proxy.upgraded", true) + finishMetrics := trackActiveStream(upstreamAddr) + + start := time.Now() + + if isH2 { + h.handleH2UpgradeTunnel(logger, wg, conn, backConn, deleteFrontConn, deleteBackConn, bufferSize, streamTimeout, start, finishMetrics) + } else { + h.handleDetachedUpgradeTunnel(logger, conn, backConn, deleteFrontConn, deleteBackConn, bufferSize, streamTimeout, start, finishMetrics) + // Return immediately without touching wg. finalizeResponse's + // wg.Wait() returns at once since wg was never incremented. + } +} + +func (h *Handler) handleH2UpgradeTunnel( + logger *zap.Logger, + wg *sync.WaitGroup, + conn io.ReadWriteCloser, + backConn io.ReadWriteCloser, + deleteFrontConn func(), + deleteBackConn func(), + bufferSize int, + streamTimeout time.Duration, + start time.Time, + finishMetrics func(result string, duration time.Duration, toBackend, fromBackend int64), +) { + // H2 extended connect: ServeHTTP must block because rw and req.Body are + // only valid while the handler goroutine is running. Defers clean up + // when the select below fires and this function returns. defer deleteBackConn() + defer deleteFrontConn() + var ( + toBackend int64 + fromBackend int64 + result = "closed" + ) + // when a stream timeout is encountered, no error will be read from errc + // a buffer size of 2 will allow both the read and write goroutines to send the error and exit + // see: https://github.com/caddyserver/caddy/issues/7418 + errc := make(chan error, 2) spc := switchProtocolCopier{ user: conn, backend: backConn, wg: wg, - bufferSize: h.StreamBufferSize, + bufferSize: bufferSize, + sent: &toBackend, + received: &fromBackend, } + wg.Add(2) - // setup the timeout if requested var timeoutc <-chan time.Time - if h.StreamTimeout > 0 { - timer := time.NewTimer(time.Duration(h.StreamTimeout)) + if streamTimeout > 0 { + timer := time.NewTimer(streamTimeout) defer timer.Stop() timeoutc = timer.C } - // when a stream timeout is encountered, no error will be read from errc - // a buffer size of 2 will allow both the read and write goroutines to send the error and exit - // see: https://github.com/caddyserver/caddy/issues/7418 - errc := make(chan error, 2) - wg.Add(2) go spc.copyToBackend(errc) go spc.copyFromBackend(errc) select { case err := <-errc: + result = classifyStreamResult(err) if c := logger.Check(zapcore.DebugLevel, "streaming error"); c != nil { c.Write(zap.Error(err)) } - case time := <-timeoutc: + case t := <-timeoutc: + result = "timeout" if c := logger.Check(zapcore.DebugLevel, "stream timed out"); c != nil { - c.Write(zap.Time("timeout", time)) + c.Write(zap.Time("timeout", t)) + } + } + + // Close both ends to unblock the still-running copy goroutine, + // then wait for it so byte counts are final before metrics/logging. + conn.Close() + backConn.Close() + wg.Wait() + + finishMetrics(result, time.Since(start), toBackend, fromBackend) + if c := logger.Check(zapcore.DebugLevel, "connection closed"); c != nil { + c.Write( + zap.Duration("duration", time.Since(start)), + zap.Int64("bytes_to_backend", toBackend), + zap.Int64("bytes_from_backend", fromBackend), + ) + } +} + +func (h *Handler) handleDetachedUpgradeTunnel( + logger *zap.Logger, + conn io.ReadWriteCloser, + backConn io.ReadWriteCloser, + deleteFrontConn func(), + deleteBackConn func(), + bufferSize int, + streamTimeout time.Duration, + start time.Time, + finishMetrics func(result string, duration time.Duration, toBackend, fromBackend int64), +) { + // HTTP/1.1 hijacked connection: launch a detached goroutine so that + // ServeHTTP can return immediately, allowing the Handler to be GC'd + // after a config reload. The goroutine captures only tunnel (a small + // *tunnelState), logger, conn/backConn, and scalar config values. + go func() { + var ( + toBackend int64 + fromBackend int64 + result = "closed" + ) + defer deleteBackConn() + defer deleteFrontConn() + defer func() { + finishMetrics(result, time.Since(start), toBackend, fromBackend) + if c := logger.Check(zapcore.DebugLevel, "connection closed"); c != nil { + c.Write( + zap.Duration("duration", time.Since(start)), + zap.Int64("bytes_to_backend", toBackend), + zap.Int64("bytes_from_backend", fromBackend), + ) + } + }() + + var innerWg sync.WaitGroup + // when a stream timeout is encountered, no error will be read from errc + // a buffer size of 2 will allow both the read and write goroutines to send the error and exit + // see: https://github.com/caddyserver/caddy/issues/7418 + errc := make(chan error, 2) + spc := switchProtocolCopier{ + user: conn, + backend: backConn, + wg: &innerWg, + bufferSize: bufferSize, + sent: &toBackend, + received: &fromBackend, + } + innerWg.Add(2) + + var timeoutc <-chan time.Time + if streamTimeout > 0 { + timer := time.NewTimer(streamTimeout) + defer timer.Stop() + timeoutc = timer.C + } + + go spc.copyToBackend(errc) + go spc.copyFromBackend(errc) + select { + case err := <-errc: + result = classifyStreamResult(err) + if c := logger.Check(zapcore.DebugLevel, "streaming error"); c != nil { + c.Write(zap.Error(err)) + } + case t := <-timeoutc: + result = "timeout" + if c := logger.Check(zapcore.DebugLevel, "stream timed out"); c != nil { + c.Write(zap.Time("timeout", t)) + } } + + // Close both ends to unblock the still-running copy goroutine, + // then wait for it to finish so byte counts are accurate before + // the deferred log fires. + conn.Close() + backConn.Close() + innerWg.Wait() + }() +} + +func classifyStreamResult(err error) string { + if err == nil || errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) || errors.Is(err, context.Canceled) { + return "closed" } + return "error" } // flushInterval returns the p.FlushInterval value, conditionally @@ -375,75 +525,86 @@ func (h Handler) copyBuffer(dst io.Writer, src io.Reader, buf []byte, logger *za } } -// registerConnection holds onto conn so it can be closed in the event -// of a server shutdown. This is useful because hijacked connections or -// connections dialed to backends don't close when server is shut down. -// The caller should call the returned delete() function when the -// connection is done to remove it from memory. -func (h *Handler) registerConnection(conn io.ReadWriteCloser, gracefulClose func() error) (del func()) { - h.connectionsMu.Lock() - h.connections[conn] = openConnection{conn, gracefulClose} - h.connectionsMu.Unlock() +// openConnection maps an open connection to an optional function for graceful +// close and records which upstream address the connection is proxying to. +type openConnection struct { + conn io.ReadWriteCloser + gracefulClose func() error + upstream string +} + +// tunnelState tracks hijacked/upgraded connections for selective cleanup. +type tunnelState struct { + connections map[io.ReadWriteCloser]openConnection + closeTimer *time.Timer + closeDelay time.Duration + mu sync.Mutex + logger *zap.Logger +} + +func newTunnelState(logger *zap.Logger, closeDelay time.Duration) *tunnelState { + return &tunnelState{ + connections: make(map[io.ReadWriteCloser]openConnection), + closeDelay: closeDelay, + logger: logger, + } +} + +// registerConnection stores conn in the tracking map. The caller must invoke +// the returned del func when the connection is done. +func (ts *tunnelState) registerConnection(conn io.ReadWriteCloser, gracefulClose func() error, upstream string) (del func()) { + ts.mu.Lock() + ts.connections[conn] = openConnection{conn, gracefulClose, upstream} + ts.mu.Unlock() return func() { - h.connectionsMu.Lock() - delete(h.connections, conn) - // if there is no connection left before the connections close timer fires - if len(h.connections) == 0 && h.connectionsCloseTimer != nil { - // we release the timer that holds the reference to Handler - if (*h.connectionsCloseTimer).Stop() { - h.logger.Debug("stopped streaming connections close timer - all connections are already closed") + ts.mu.Lock() + delete(ts.connections, conn) + if len(ts.connections) == 0 && ts.closeTimer != nil { + if ts.closeTimer.Stop() { + ts.logger.Debug("stopped streaming connections close timer - all connections are already closed") } - h.connectionsCloseTimer = nil + ts.closeTimer = nil } - h.connectionsMu.Unlock() + ts.mu.Unlock() } } -// closeConnections immediately closes all hijacked connections (both to client and backend). -func (h *Handler) closeConnections() error { +// closeConnections closes all tracked connections. +func (ts *tunnelState) closeConnections() error { var err error - h.connectionsMu.Lock() - defer h.connectionsMu.Unlock() - - for _, oc := range h.connections { + ts.mu.Lock() + defer ts.mu.Unlock() + for _, oc := range ts.connections { if oc.gracefulClose != nil { - // this is potentially blocking while we have the lock on the connections - // map, but that should be OK since the server has in theory shut down - // and we are no longer using the connections map - gracefulErr := oc.gracefulClose() - if gracefulErr != nil && err == nil { + if gracefulErr := oc.gracefulClose(); gracefulErr != nil && err == nil { err = gracefulErr } } - closeErr := oc.conn.Close() - if closeErr != nil && err == nil { + if closeErr := oc.conn.Close(); closeErr != nil && err == nil { err = closeErr } } return err } -// cleanupConnections closes hijacked connections. -// Depending on the value of StreamCloseDelay it does that either immediately -// or sets up a timer that will do that later. -func (h *Handler) cleanupConnections() error { - if h.StreamCloseDelay == 0 { - return h.closeConnections() - } - - h.connectionsMu.Lock() - defer h.connectionsMu.Unlock() - // the handler is shut down, no new connection can appear, - // so we can skip setting up the timer when there are no connections - if len(h.connections) > 0 { - delay := time.Duration(h.StreamCloseDelay) - h.connectionsCloseTimer = time.AfterFunc(delay, func() { - if c := h.logger.Check(zapcore.DebugLevel, "closing streaming connections after delay"); c != nil { +// cleanupConnections closes upgraded connections. Depending on closeDelay it +// does that either immediately or after a timer. +func (ts *tunnelState) cleanupConnections() error { + if ts.closeDelay == 0 { + return ts.closeConnections() + } + + ts.mu.Lock() + defer ts.mu.Unlock() + if len(ts.connections) > 0 { + delay := ts.closeDelay + ts.closeTimer = time.AfterFunc(delay, func() { + if c := ts.logger.Check(zapcore.DebugLevel, "closing streaming connections after delay"); c != nil { c.Write(zap.Duration("delay", delay)) } - err := h.closeConnections() + err := ts.closeConnections() if err != nil { - if c := h.logger.Check(zapcore.ErrorLevel, "failed to closed connections after delay"); c != nil { + if c := ts.logger.Check(zapcore.ErrorLevel, "failed to close connections after delay"); c != nil { c.Write( zap.Error(err), zap.Duration("delay", delay), @@ -567,11 +728,26 @@ func isWebsocket(r *http.Request) bool { httpguts.HeaderValuesContainsToken(r.Header["Upgrade"], "websocket") } -// openConnection maps an open connection to -// an optional function for graceful close. -type openConnection struct { - conn io.ReadWriteCloser - gracefulClose func() error +// closeConnectionsForUpstream closes all tracked connections that were +// established to the given upstream address. +func (ts *tunnelState) closeConnectionsForUpstream(addr string) error { + var err error + ts.mu.Lock() + defer ts.mu.Unlock() + for _, oc := range ts.connections { + if oc.upstream != addr { + continue + } + if oc.gracefulClose != nil { + if gracefulErr := oc.gracefulClose(); gracefulErr != nil && err == nil { + err = gracefulErr + } + } + if closeErr := oc.conn.Close(); closeErr != nil && err == nil { + err = closeErr + } + } + return err } type maxLatencyWriter struct { @@ -642,16 +818,23 @@ type switchProtocolCopier struct { user, backend io.ReadWriteCloser wg *sync.WaitGroup bufferSize int + // sent and received accumulate byte counts for each direction. + // They are written before wg.Done() and read after wg.Wait(), so no + // additional synchronization is needed beyond the WaitGroup barrier. + sent *int64 // bytes copied to backend; must be non-nil + received *int64 // bytes copied from backend; must be non-nil } func (c switchProtocolCopier) copyFromBackend(errc chan<- error) { - _, err := io.CopyBuffer(c.user, c.backend, c.buffer()) + n, err := io.CopyBuffer(c.user, c.backend, c.buffer()) + *c.received = n errc <- err c.wg.Done() } func (c switchProtocolCopier) copyToBackend(errc chan<- error) { - _, err := io.CopyBuffer(c.backend, c.user, c.buffer()) + n, err := io.CopyBuffer(c.backend, c.user, c.buffer()) + *c.sent = n errc <- err c.wg.Done() } diff --git a/modules/caddyhttp/reverseproxy/streaming_test.go b/modules/caddyhttp/reverseproxy/streaming_test.go index ce0db65a06c..d2441739a66 100644 --- a/modules/caddyhttp/reverseproxy/streaming_test.go +++ b/modules/caddyhttp/reverseproxy/streaming_test.go @@ -7,8 +7,10 @@ import ( "strings" "sync" "testing" + "time" "github.com/caddyserver/caddy/v2" + "github.com/caddyserver/caddy/v2/caddyconfig/caddyfile" ) func TestHandlerCopyResponse(t *testing.T) { @@ -41,12 +43,15 @@ func TestSwitchProtocolCopierBufferSize(t *testing.T) { var wg sync.WaitGroup var errc = make(chan error, 1) var dst bytes.Buffer + var sent, received int64 copier := switchProtocolCopier{ user: nopReadWriteCloser{Reader: strings.NewReader("hello")}, backend: nopReadWriteCloser{Writer: &dst}, wg: &wg, bufferSize: 7, + sent: &sent, + received: &received, } buf := copier.buffer() @@ -80,3 +85,132 @@ type nopReadWriteCloser struct { } func (nopReadWriteCloser) Close() error { return nil } + +type trackingReadWriteCloser struct { + closed chan struct{} + one sync.Once +} + +func newTrackingReadWriteCloser() *trackingReadWriteCloser { + return &trackingReadWriteCloser{closed: make(chan struct{})} +} + +func (c *trackingReadWriteCloser) Read(_ []byte) (int, error) { return 0, io.EOF } +func (c *trackingReadWriteCloser) Write(p []byte) (int, error) { return len(p), nil } +func (c *trackingReadWriteCloser) Close() error { + c.one.Do(func() { + close(c.closed) + }) + return nil +} + +func (c *trackingReadWriteCloser) isClosed() bool { + select { + case <-c.closed: + return true + default: + return false + } +} + +func TestHandlerCleanupLegacyModeClosesAllConnections(t *testing.T) { + ts := newTunnelState(caddy.Log(), 0) + connA := newTrackingReadWriteCloser() + connB := newTrackingReadWriteCloser() + ts.registerConnection(connA, nil, "a") + ts.registerConnection(connB, nil, "b") + + h := &Handler{ + tunnel: ts, + StreamRetainOnReload: false, + } + + if err := h.Cleanup(); err != nil { + t.Fatalf("cleanup failed: %v", err) + } + if !connA.isClosed() || !connB.isClosed() { + t.Fatalf("legacy cleanup should close all upgraded connections") + } +} + +func TestHandlerCleanupLegacyModeHonorsDelay(t *testing.T) { + ts := newTunnelState(caddy.Log(), 40*time.Millisecond) + conn := newTrackingReadWriteCloser() + ts.registerConnection(conn, nil, "a") + + h := &Handler{ + tunnel: ts, + StreamRetainOnReload: false, + } + + if err := h.Cleanup(); err != nil { + t.Fatalf("cleanup failed: %v", err) + } + if conn.isClosed() { + t.Fatal("connection should not close immediately when stream_close_delay is set") + } + + select { + case <-conn.closed: + case <-time.After(500 * time.Millisecond): + t.Fatal("connection did not close after stream_close_delay elapsed") + } +} + +func TestHandlerCleanupRetainModeClosesOnlyRemovedUpstreams(t *testing.T) { + const upstreamA = "upstream-a" + const upstreamB = "upstream-b" + + // Simulate old+new configs both referencing upstreamA (refcount 2), + // while upstreamB is only referenced by the old config (refcount 1). + hosts.LoadOrStore(upstreamA, struct{}{}) + hosts.LoadOrStore(upstreamA, struct{}{}) + hosts.LoadOrStore(upstreamB, struct{}{}) + t.Cleanup(func() { + _, _ = hosts.Delete(upstreamA) + _, _ = hosts.Delete(upstreamA) + _, _ = hosts.Delete(upstreamB) + }) + + ts := newTunnelState(caddy.Log(), 0) + connA := newTrackingReadWriteCloser() + connB := newTrackingReadWriteCloser() + ts.registerConnection(connA, nil, upstreamA) + ts.registerConnection(connB, nil, upstreamB) + + h := &Handler{ + tunnel: ts, + StreamRetainOnReload: true, + Upstreams: UpstreamPool{ + &Upstream{Dial: upstreamA}, + &Upstream{Dial: upstreamB}, + }, + } + + if err := h.Cleanup(); err != nil { + t.Fatalf("cleanup failed: %v", err) + } + + if connA.isClosed() { + t.Fatal("connection for retained upstream should remain open") + } + if !connB.isClosed() { + t.Fatal("connection for removed upstream should be closed") + } +} + +func TestHandlerUnmarshalCaddyfileStreamLogSkipHandshake(t *testing.T) { + d := caddyfile.NewTestDispenser(` + reverse_proxy localhost:9000 { + stream_log_skip_handshake + } + `) + + var h Handler + if err := h.UnmarshalCaddyfile(d); err != nil { + t.Fatalf("UnmarshalCaddyfile() error = %v", err) + } + if !h.StreamLogSkipHandshake { + t.Fatal("expected stream_log_skip_handshake to enable StreamLogSkipHandshake") + } +} From daea7788ad49975e4cfcafabcebe4e67be8b56ee Mon Sep 17 00:00:00 2001 From: Francis Lavoie Date: Mon, 13 Apr 2026 05:03:05 -0400 Subject: [PATCH 02/17] lint --- caddytest/integration/stream_reload_stress_test.go | 10 +++++----- modules/caddyhttp/reverseproxy/streaming.go | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/caddytest/integration/stream_reload_stress_test.go b/caddytest/integration/stream_reload_stress_test.go index cd0b354caef..45473e2219d 100644 --- a/caddytest/integration/stream_reload_stress_test.go +++ b/caddytest/integration/stream_reload_stress_test.go @@ -86,12 +86,12 @@ func TestReverseProxyReloadStressUpgradedStreamsHeapProfiles(t *testing.T) { } type stressRunResult struct { - streamCount int - aliveAfterReloads int + streamCount int + aliveAfterReloads int aliveBeforeDelayExpiry int // only meaningful for close_delay mode - beforeReload heapSnapshot - midReload heapSnapshot // after all reloads, before delay expiry clean-up - afterReload heapSnapshot // after all streams have been fully cleaned up + beforeReload heapSnapshot + midReload heapSnapshot // after all reloads, before delay expiry clean-up + afterReload heapSnapshot // after all streams have been fully cleaned up } type heapSnapshot struct { diff --git a/modules/caddyhttp/reverseproxy/streaming.go b/modules/caddyhttp/reverseproxy/streaming.go index 37fca60140f..d86796e58fb 100644 --- a/modules/caddyhttp/reverseproxy/streaming.go +++ b/modules/caddyhttp/reverseproxy/streaming.go @@ -246,7 +246,7 @@ func (h *Handler) handleH2UpgradeTunnel( var ( toBackend int64 fromBackend int64 - result = "closed" + result string ) // when a stream timeout is encountered, no error will be read from errc From 307dfd0431b5432bcf7c06b2718e0ea0de966f8a Mon Sep 17 00:00:00 2001 From: Francis Lavoie Date: Mon, 13 Apr 2026 05:44:03 -0400 Subject: [PATCH 03/17] Improved logging facilities --- modules/caddyhttp/reverseproxy/caddyfile.go | 41 +++++++++- .../caddyhttp/reverseproxy/reverseproxy.go | 80 ++++++++++++++++++- modules/caddyhttp/reverseproxy/streaming.go | 46 ++++++++--- .../caddyhttp/reverseproxy/streaming_test.go | 21 ++++- 4 files changed, 164 insertions(+), 24 deletions(-) diff --git a/modules/caddyhttp/reverseproxy/caddyfile.go b/modules/caddyhttp/reverseproxy/caddyfile.go index 07277b4f133..c692267c10c 100644 --- a/modules/caddyhttp/reverseproxy/caddyfile.go +++ b/modules/caddyhttp/reverseproxy/caddyfile.go @@ -100,7 +100,11 @@ func parseCaddyfile(h httpcaddyfile.Helper) (caddyhttp.MiddlewareHandler, error) // stream_timeout // stream_close_delay // stream_retain_on_reload -// stream_log_skip_handshake +// stream_logs { +// level +// logger_name +// skip_handshake +// } // verbose_logs // // # request manipulation @@ -711,11 +715,42 @@ func (h *Handler) UnmarshalCaddyfile(d *caddyfile.Dispenser) error { } h.StreamRetainOnReload = true - case "stream_log_skip_handshake": + case "stream_logs": if d.NextArg() { return d.ArgErr() } - h.StreamLogSkipHandshake = true + if h.StreamLogs == nil { + h.StreamLogs = new(StreamLogs) + } + + nesting := d.Nesting() + for d.NextBlock(nesting) { + switch d.Val() { + case "level": + if !d.NextArg() { + return d.ArgErr() + } + h.StreamLogs.Level = d.Val() + if d.NextArg() { + return d.ArgErr() + } + case "logger_name": + if !d.NextArg() { + return d.ArgErr() + } + h.StreamLogs.LoggerName = d.Val() + if d.NextArg() { + return d.ArgErr() + } + case "skip_handshake": + if d.NextArg() { + return d.ArgErr() + } + h.StreamLogs.SkipHandshake = true + default: + return d.Errf("unrecognized stream_logs option: %s", d.Val()) + } + } case "trusted_proxies": for d.NextArg() { diff --git a/modules/caddyhttp/reverseproxy/reverseproxy.go b/modules/caddyhttp/reverseproxy/reverseproxy.go index adb47a9b8da..3ebded469df 100644 --- a/modules/caddyhttp/reverseproxy/reverseproxy.go +++ b/modules/caddyhttp/reverseproxy/reverseproxy.go @@ -193,10 +193,9 @@ type Handler struct { // connections are closed on reload (optionally delayed by stream_close_delay). StreamRetainOnReload bool `json:"stream_retain_on_reload,omitempty"` - // If true, suppresses the access log entry normally emitted when an - // upgraded stream handshake completes and the request unwinds. By default - // the handshake is still logged as a normal request with status 101. - StreamLogSkipHandshake bool `json:"stream_log_skip_handshake,omitempty"` + // Controls logging behavior for upgraded stream lifecycle events. + // If omitted, defaults are used (level=DEBUG, logger_name="http.handlers.reverse_proxy.stream"). + StreamLogs *StreamLogs `json:"stream_logs,omitempty"` // If configured, rewrites the copy of the upstream request. // Allows changing the request method and URI (path and query). @@ -259,8 +258,34 @@ type Handler struct { ctx caddy.Context logger *zap.Logger events *caddyevents.App + + streamLogLevel zapcore.Level + streamLogLoggerName string +} + +// StreamLogs controls logging for upgraded stream lifecycle events. +type StreamLogs struct { + // The minimum level at which stream lifecycle events are logged. + // Supported values are debug, info, warn, and error. Default: debug. + Level string `json:"level,omitempty"` + + // Logger name for stream lifecycle logs. Default: "http.handlers.reverse_proxy.stream". + // Special value "access" uses the access logger namespace and, if set, + // respects the first value in access_logger_names/log_name for the request. + LoggerName string `json:"logger_name,omitempty"` + + // If true, suppresses the access log entry normally emitted when an + // upgraded stream handshake completes and the request unwinds. + SkipHandshake bool `json:"skip_handshake,omitempty"` } +const ( + defaultStreamLogLevel = zapcore.DebugLevel + defaultStreamLoggerName = "http.handlers.reverse_proxy.stream" + streamLoggerNameUseAccess = "access" + defaultAccessLoggerBase = "http.log.access" +) + // CaddyModule returns the Caddy module information. func (Handler) CaddyModule() caddy.ModuleInfo { return caddy.ModuleInfo{ @@ -279,6 +304,20 @@ func (h *Handler) Provision(ctx caddy.Context) error { h.ctx = ctx h.logger = ctx.Logger() h.tunnel = newTunnelState(h.logger, time.Duration(h.StreamCloseDelay)) + h.streamLogLevel = defaultStreamLogLevel + h.streamLogLoggerName = defaultStreamLoggerName + if h.StreamLogs != nil { + if h.StreamLogs.Level != "" { + lvl, err := zapcore.ParseLevel(strings.ToLower(strings.TrimSpace(h.StreamLogs.Level))) + if err != nil { + return fmt.Errorf("invalid stream_logs.level %q: %w", h.StreamLogs.Level, err) + } + h.streamLogLevel = lvl + } + if name := strings.TrimSpace(h.StreamLogs.LoggerName); name != "" { + h.streamLogLoggerName = name + } + } // warn about unsafe buffering config if h.RequestBuffers == -1 || h.ResponseBuffers == -1 { @@ -447,6 +486,39 @@ func (h *Handler) Provision(ctx caddy.Context) error { return nil } +func (h Handler) streamLogsSkipHandshake() bool { + return h.StreamLogs != nil && h.StreamLogs.SkipHandshake +} + +func (h Handler) streamLoggerForRequest(req *http.Request) *zap.Logger { + name := strings.TrimSpace(h.streamLogLoggerName) + if name == "" { + name = defaultStreamLoggerName + } + + if name == streamLoggerNameUseAccess { + logger := caddy.Log().Named(defaultAccessLoggerBase) + names := caddyhttp.GetVar(req.Context(), caddyhttp.AccessLoggerNameVarKey) + namesSlice, ok := names.([]any) + if !ok { + return logger + } + for _, v := range namesSlice { + name, ok := v.(string) + if !ok { + continue + } + if name == "" { + return logger + } + return logger.Named(name) + } + return logger + } + + return caddy.Log().Named(name) +} + // Cleanup cleans up the resources made by h. func (h *Handler) Cleanup() error { if !h.StreamRetainOnReload { diff --git a/modules/caddyhttp/reverseproxy/streaming.go b/modules/caddyhttp/reverseproxy/streaming.go index d86796e58fb..9aece01e808 100644 --- a/modules/caddyhttp/reverseproxy/streaming.go +++ b/modules/caddyhttp/reverseproxy/streaming.go @@ -208,26 +208,31 @@ func (h *Handler) handleUpgradeResponse(logger *zap.Logger, wg *sync.WaitGroup, } deleteFrontConn := tunnel.registerConnection(conn, gracefulClose(conn, false), upstreamAddr) deleteBackConn := tunnel.registerConnection(backConn, gracefulClose(backConn, true), upstreamAddr) - if h.StreamLogSkipHandshake { + if h.streamLogsSkipHandshake() { caddyhttp.SetVar(req.Context(), caddyhttp.LogSkipVar, true) } repl := req.Context().Value(caddy.ReplacerCtxKey).(*caddy.Replacer) repl.Set("http.reverse_proxy.upgraded", true) + streamUUID, _ := repl.GetString("http.request.uuid") + streamFields := makeStreamLogFields(streamUUID) + streamLogger := h.streamLoggerForRequest(req) + streamLevel := h.streamLogLevel finishMetrics := trackActiveStream(upstreamAddr) start := time.Now() if isH2 { - h.handleH2UpgradeTunnel(logger, wg, conn, backConn, deleteFrontConn, deleteBackConn, bufferSize, streamTimeout, start, finishMetrics) + h.handleH2UpgradeTunnel(streamLogger, streamLevel, wg, conn, backConn, deleteFrontConn, deleteBackConn, bufferSize, streamTimeout, start, finishMetrics, streamFields) } else { - h.handleDetachedUpgradeTunnel(logger, conn, backConn, deleteFrontConn, deleteBackConn, bufferSize, streamTimeout, start, finishMetrics) + h.handleDetachedUpgradeTunnel(streamLogger, streamLevel, conn, backConn, deleteFrontConn, deleteBackConn, bufferSize, streamTimeout, start, finishMetrics, streamFields) // Return immediately without touching wg. finalizeResponse's // wg.Wait() returns at once since wg was never incremented. } } func (h *Handler) handleH2UpgradeTunnel( - logger *zap.Logger, + streamLogger *zap.Logger, + streamLevel zapcore.Level, wg *sync.WaitGroup, conn io.ReadWriteCloser, backConn io.ReadWriteCloser, @@ -237,6 +242,7 @@ func (h *Handler) handleH2UpgradeTunnel( streamTimeout time.Duration, start time.Time, finishMetrics func(result string, duration time.Duration, toBackend, fromBackend int64), + streamFields []zap.Field, ) { // H2 extended connect: ServeHTTP must block because rw and req.Body are // only valid while the handler goroutine is running. Defers clean up @@ -275,12 +281,12 @@ func (h *Handler) handleH2UpgradeTunnel( select { case err := <-errc: result = classifyStreamResult(err) - if c := logger.Check(zapcore.DebugLevel, "streaming error"); c != nil { + if c := streamLogger.Check(streamLevel, "streaming error"); c != nil { c.Write(zap.Error(err)) } case t := <-timeoutc: result = "timeout" - if c := logger.Check(zapcore.DebugLevel, "stream timed out"); c != nil { + if c := streamLogger.Check(streamLevel, "stream timed out"); c != nil { c.Write(zap.Time("timeout", t)) } } @@ -292,17 +298,20 @@ func (h *Handler) handleH2UpgradeTunnel( wg.Wait() finishMetrics(result, time.Since(start), toBackend, fromBackend) - if c := logger.Check(zapcore.DebugLevel, "connection closed"); c != nil { - c.Write( + if c := streamLogger.Check(streamLevel, "connection closed"); c != nil { + fields := append([]zap.Field{}, streamFields...) + fields = append(fields, zap.Duration("duration", time.Since(start)), zap.Int64("bytes_to_backend", toBackend), zap.Int64("bytes_from_backend", fromBackend), ) + c.Write(fields...) } } func (h *Handler) handleDetachedUpgradeTunnel( - logger *zap.Logger, + streamLogger *zap.Logger, + streamLevel zapcore.Level, conn io.ReadWriteCloser, backConn io.ReadWriteCloser, deleteFrontConn func(), @@ -311,6 +320,7 @@ func (h *Handler) handleDetachedUpgradeTunnel( streamTimeout time.Duration, start time.Time, finishMetrics func(result string, duration time.Duration, toBackend, fromBackend int64), + streamFields []zap.Field, ) { // HTTP/1.1 hijacked connection: launch a detached goroutine so that // ServeHTTP can return immediately, allowing the Handler to be GC'd @@ -326,12 +336,14 @@ func (h *Handler) handleDetachedUpgradeTunnel( defer deleteFrontConn() defer func() { finishMetrics(result, time.Since(start), toBackend, fromBackend) - if c := logger.Check(zapcore.DebugLevel, "connection closed"); c != nil { - c.Write( + if c := streamLogger.Check(streamLevel, "connection closed"); c != nil { + fields := append([]zap.Field{}, streamFields...) + fields = append(fields, zap.Duration("duration", time.Since(start)), zap.Int64("bytes_to_backend", toBackend), zap.Int64("bytes_from_backend", fromBackend), ) + c.Write(fields...) } }() @@ -362,12 +374,12 @@ func (h *Handler) handleDetachedUpgradeTunnel( select { case err := <-errc: result = classifyStreamResult(err) - if c := logger.Check(zapcore.DebugLevel, "streaming error"); c != nil { + if c := streamLogger.Check(streamLevel, "streaming error"); c != nil { c.Write(zap.Error(err)) } case t := <-timeoutc: result = "timeout" - if c := logger.Check(zapcore.DebugLevel, "stream timed out"); c != nil { + if c := streamLogger.Check(streamLevel, "stream timed out"); c != nil { c.Write(zap.Time("timeout", t)) } } @@ -388,6 +400,14 @@ func classifyStreamResult(err error) string { return "error" } +func makeStreamLogFields(streamUUID string) []zap.Field { + fields := make([]zap.Field, 0, 1) + if streamUUID != "" { + fields = append(fields, zap.String("uuid", streamUUID)) + } + return fields +} + // flushInterval returns the p.FlushInterval value, conditionally // overriding its value for a specific request/response. func (h Handler) flushInterval(req *http.Request, res *http.Response) time.Duration { diff --git a/modules/caddyhttp/reverseproxy/streaming_test.go b/modules/caddyhttp/reverseproxy/streaming_test.go index d2441739a66..e6a3ce3a1a6 100644 --- a/modules/caddyhttp/reverseproxy/streaming_test.go +++ b/modules/caddyhttp/reverseproxy/streaming_test.go @@ -199,10 +199,14 @@ func TestHandlerCleanupRetainModeClosesOnlyRemovedUpstreams(t *testing.T) { } } -func TestHandlerUnmarshalCaddyfileStreamLogSkipHandshake(t *testing.T) { +func TestHandlerUnmarshalCaddyfileStreamLogsBlock(t *testing.T) { d := caddyfile.NewTestDispenser(` reverse_proxy localhost:9000 { - stream_log_skip_handshake + stream_logs { + level info + logger_name access + skip_handshake + } } `) @@ -210,7 +214,16 @@ func TestHandlerUnmarshalCaddyfileStreamLogSkipHandshake(t *testing.T) { if err := h.UnmarshalCaddyfile(d); err != nil { t.Fatalf("UnmarshalCaddyfile() error = %v", err) } - if !h.StreamLogSkipHandshake { - t.Fatal("expected stream_log_skip_handshake to enable StreamLogSkipHandshake") + if h.StreamLogs == nil { + t.Fatal("expected stream_logs to be configured") + } + if h.StreamLogs.Level != "info" { + t.Fatalf("expected stream_logs.level=info, got %q", h.StreamLogs.Level) + } + if h.StreamLogs.LoggerName != "access" { + t.Fatalf("expected stream_logs.logger_name=access, got %q", h.StreamLogs.LoggerName) + } + if !h.StreamLogs.SkipHandshake { + t.Fatal("expected stream_logs.skip_handshake=true") } } From 7ef9ecd48a20c50733efb8d7414f11ff27fd5599 Mon Sep 17 00:00:00 2001 From: Francis Lavoie Date: Sat, 18 Apr 2026 14:16:20 -0400 Subject: [PATCH 04/17] Adjustments from Weidi's review --- .github/workflows/ci.yml | 4 +- .../reverseproxy_extended_connect_test.go | 328 ++++++++++++++++++ .../integration/stream_reload_stress_test.go | 29 +- modules/caddyhttp/encode/encode.go | 5 - .../reverseproxy/extended_connect_test.go | 146 ++++++++ modules/caddyhttp/reverseproxy/streaming.go | 39 +-- 6 files changed, 510 insertions(+), 41 deletions(-) create mode 100644 caddytest/integration/reverseproxy_extended_connect_test.go create mode 100644 modules/caddyhttp/reverseproxy/extended_connect_test.go diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 2c5052723c2..89135cbc2cf 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -132,6 +132,8 @@ jobs: - name: Run tests # id: step_test # continue-on-error: true + env: + GODEBUG: http2xconnect=1 run: | # (go test -v -coverprofile=cover-profile.out -race ./... 2>&1) > test-results/test-result.out go test -v -coverprofile="cover-profile.out" -short -race ./... @@ -191,7 +193,7 @@ jobs: retries=3 exit_code=0 while ((retries > 0)); do - CGO_ENABLED=0 go test -p 1 -v ./... + GODEBUG=http2xconnect=1 CGO_ENABLED=0 go test -p 1 -v ./... exit_code=$? if ((exit_code == 0)); then break diff --git a/caddytest/integration/reverseproxy_extended_connect_test.go b/caddytest/integration/reverseproxy_extended_connect_test.go new file mode 100644 index 00000000000..8822988be09 --- /dev/null +++ b/caddytest/integration/reverseproxy_extended_connect_test.go @@ -0,0 +1,328 @@ +package integration + +import ( + "bytes" + "crypto/tls" + "errors" + "fmt" + "io" + "net" + "net/http" + "strings" + "testing" + "time" + + "golang.org/x/net/http2" + "golang.org/x/net/http2/hpack" + + "github.com/caddyserver/caddy/v2/caddytest" +) + +var errExtendedConnectUnsupportedByPeer = errors.New("peer did not advertise RFC 8441 extended CONNECT support") + +func TestReverseProxyExtendedConnectOverH2(t *testing.T) { + tester := caddytest.NewTester(t) + backend := newWebsocketUpgradeEchoBackend(t) + defer backend.Close() + + tester.InitServer(fmt.Sprintf(` +{ + admin localhost:2999 + http_port 9080 + https_port 9443 + grace_period 1ns + skip_install_trust + servers :9443 { + protocols h2 + } +} + +https://localhost:9443 { + reverse_proxy %s +} +`, backend.addr), "caddyfile") + + const payload = "extended-connect-echo\n" + if err := assertExtendedConnectH2Echo("localhost:9443", payload); err != nil { + if errors.Is(err, errExtendedConnectUnsupportedByPeer) { + t.Skipf("skipping extended CONNECT integration test: %v", err) + } + t.Fatalf("extended connect h2 echo failed: %v", err) + } +} + +func assertExtendedConnectH2Echo(addr, payload string) error { + conn, err := tlsDialH2(addr) + if err != nil { + return fmt.Errorf("dialing h2 tls: %w", err) + } + defer conn.Close() + + if err := conn.SetDeadline(time.Now().Add(5 * time.Second)); err != nil { + return fmt.Errorf("setting deadline: %w", err) + } + + fr := http2.NewFramer(conn, conn) + + if _, err := conn.Write([]byte(http2.ClientPreface)); err != nil { + return fmt.Errorf("writing client preface: %w", err) + } + if err := fr.WriteSettings(http2.Setting{ID: http2.SettingEnableConnectProtocol, Val: 1}); err != nil { + return fmt.Errorf("writing client settings: %w", err) + } + + supported, err := waitForServerSettings(fr) + if err != nil { + return err + } + if !supported { + return errExtendedConnectUnsupportedByPeer + } + if err := waitForSettingsAck(fr); err != nil { + return err + } + + if err := writeExtendedConnectHeaders(fr, addr); err != nil { + return err + } + + status, err := readResponseStatus(fr, 1) + if err != nil { + return err + } + if status != "200" { + return fmt.Errorf("unexpected extended connect status: got=%s want=200", status) + } + + if err := fr.WriteData(1, false, []byte(payload)); err != nil { + return fmt.Errorf("writing stream data: %w", err) + } + + echo, err := readStreamData(fr, 1, len(payload)) + if err != nil { + return err + } + if echo != payload { + return fmt.Errorf("unexpected echoed payload: got=%q want=%q", echo, payload) + } + + _ = fr.WriteRSTStream(1, http2.ErrCodeNo) + return nil +} + +func tlsDialH2(addr string) (net.Conn, error) { + var lastErr error + for i := 0; i < 30; i++ { + dialer := &net.Dialer{Timeout: 2 * time.Second} + conn, err := tls.DialWithDialer(dialer, "tcp", addr, &tls.Config{ + ServerName: "localhost", + InsecureSkipVerify: true, + NextProtos: []string{"h2"}, + }) + if err == nil { + return conn, nil + } + lastErr = err + time.Sleep(100 * time.Millisecond) + } + return nil, lastErr +} + +func waitForServerSettings(fr *http2.Framer) (bool, error) { + for { + frame, err := fr.ReadFrame() + if err != nil { + return false, fmt.Errorf("reading frame before connect: %w", err) + } + settings, ok := frame.(*http2.SettingsFrame) + if !ok { + continue + } + if settings.IsAck() { + continue + } + + supported := false + if err := settings.ForeachSetting(func(s http2.Setting) error { + if s.ID == http2.SettingEnableConnectProtocol && s.Val == 1 { + supported = true + } + return nil + }); err != nil { + return false, fmt.Errorf("reading server settings: %w", err) + } + + if err := fr.WriteSettingsAck(); err != nil { + return false, fmt.Errorf("writing settings ack: %w", err) + } + return supported, nil + } +} + +func waitForSettingsAck(fr *http2.Framer) error { + for { + frame, err := fr.ReadFrame() + if err != nil { + return fmt.Errorf("reading settings ack: %w", err) + } + settings, ok := frame.(*http2.SettingsFrame) + if ok && settings.IsAck() { + return nil + } + } +} + +func writeExtendedConnectHeaders(fr *http2.Framer, addr string) error { + var hb bytes.Buffer + enc := hpack.NewEncoder(&hb) + for _, hf := range []hpack.HeaderField{ + {Name: ":method", Value: "CONNECT"}, + {Name: ":scheme", Value: "https"}, + {Name: ":authority", Value: addr}, + {Name: ":path", Value: "/upgrade"}, + {Name: ":protocol", Value: "websocket"}, + } { + if err := enc.WriteField(hf); err != nil { + return fmt.Errorf("encoding request headers: %w", err) + } + } + + if err := fr.WriteHeaders(http2.HeadersFrameParam{ + StreamID: 1, + BlockFragment: hb.Bytes(), + EndHeaders: true, + EndStream: false, + }); err != nil { + return fmt.Errorf("writing extended connect headers: %w", err) + } + return nil +} + +func readResponseStatus(fr *http2.Framer, streamID uint32) (string, error) { + var block bytes.Buffer + + for { + frame, err := fr.ReadFrame() + if err != nil { + return "", fmt.Errorf("reading response headers: %w", err) + } + if rst, ok := frame.(*http2.RSTStreamFrame); ok && rst.StreamID == streamID { + return "", fmt.Errorf("stream reset before response headers: %s", rst.ErrCode) + } + + h, ok := frame.(*http2.HeadersFrame) + if !ok || h.StreamID != streamID { + continue + } + + if _, err := block.Write(h.HeaderBlockFragment()); err != nil { + return "", fmt.Errorf("buffering response header fragment: %w", err) + } + for !h.HeadersEnded() { + next, err := fr.ReadFrame() + if err != nil { + return "", fmt.Errorf("reading continuation frame: %w", err) + } + c, ok := next.(*http2.ContinuationFrame) + if !ok || c.StreamID != streamID { + continue + } + if _, err := block.Write(c.HeaderBlockFragment()); err != nil { + return "", fmt.Errorf("buffering continuation fragment: %w", err) + } + if c.HeadersEnded() { + break + } + } + break + } + + var status string + dec := hpack.NewDecoder(4096, func(f hpack.HeaderField) { + if f.Name == ":status" { + status = f.Value + } + }) + if _, err := dec.Write(block.Bytes()); err != nil { + return "", fmt.Errorf("decoding response header block: %w", err) + } + if status == "" { + return "", fmt.Errorf("missing :status in response headers") + } + return status, nil +} + +func readStreamData(fr *http2.Framer, streamID uint32, n int) (string, error) { + buf := make([]byte, 0, n) + for len(buf) < n { + frame, err := fr.ReadFrame() + if err != nil { + return "", fmt.Errorf("reading stream data: %w", err) + } + d, ok := frame.(*http2.DataFrame) + if !ok || d.StreamID != streamID { + continue + } + buf = append(buf, d.Data()...) + } + return string(buf[:n]), nil +} + +type websocketUpgradeEchoBackend struct { + addr string + ln net.Listener + server *http.Server +} + +func newWebsocketUpgradeEchoBackend(t *testing.T) *websocketUpgradeEchoBackend { + t.Helper() + + backend := &websocketUpgradeEchoBackend{} + backend.server = &http.Server{ + Handler: http.HandlerFunc(backend.serveHTTP), + } + + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listening for websocket backend: %v", err) + } + backend.ln = ln + backend.addr = ln.Addr().String() + + go func() { + _ = backend.server.Serve(ln) + }() + + return backend +} + +func (b *websocketUpgradeEchoBackend) serveHTTP(w http.ResponseWriter, r *http.Request) { + if !strings.EqualFold(r.Header.Get("Connection"), "Upgrade") || !strings.EqualFold(r.Header.Get("Upgrade"), "websocket") { + http.Error(w, "upgrade required", http.StatusUpgradeRequired) + return + } + + hijacker, ok := w.(http.Hijacker) + if !ok { + http.Error(w, "hijacking not supported", http.StatusInternalServerError) + return + } + + conn, rw, err := hijacker.Hijack() + if err != nil { + return + } + + _, _ = rw.WriteString("HTTP/1.1 101 Switching Protocols\r\nConnection: Upgrade\r\nUpgrade: websocket\r\n\r\n") + _ = rw.Flush() + + go func() { + defer conn.Close() + _, _ = io.Copy(conn, conn) + }() +} + +func (b *websocketUpgradeEchoBackend) Close() { + _ = b.server.Close() + _ = b.ln.Close() +} diff --git a/caddytest/integration/stream_reload_stress_test.go b/caddytest/integration/stream_reload_stress_test.go index 45473e2219d..ff140f9a45c 100644 --- a/caddytest/integration/stream_reload_stress_test.go +++ b/caddytest/integration/stream_reload_stress_test.go @@ -21,9 +21,11 @@ import ( "github.com/caddyserver/caddy/v2/caddytest" ) -// stressCloseDelay is the stream_close_delay used for the close_delay scenario. -// Long enough to outlast all test reloads; short enough to keep total test time reasonable. -const stressCloseDelay = 3 * time.Second +const ( + defaultStressStreamCount = 1 + defaultStressReloadCount = 1 + defaultStressCloseDelay = 500 * time.Millisecond +) func TestReverseProxyReloadStressUpgradedStreamsHeapProfiles(t *testing.T) { tester := caddytest.NewTester(t).WithDefaultOverrides(caddytest.Config{ @@ -43,7 +45,7 @@ func TestReverseProxyReloadStressUpgradedStreamsHeapProfiles(t *testing.T) { // Reloads are spread across time and interleaved with echo-checks so // stream health is exercised at each reload boundary, not only at the end. legacy := runReloadStress(t, tester, backend.addr, "legacy", false, 0) - closeDelay := runReloadStress(t, tester, backend.addr, "close_delay", false, stressCloseDelay) + closeDelay := runReloadStress(t, tester, backend.addr, "close_delay", false, stressCloseDelay(t)) retain := runReloadStress(t, tester, backend.addr, "retain", true, 0) if legacy.aliveAfterReloads != 0 { @@ -110,8 +112,8 @@ func runReloadStress(t *testing.T, tester *caddytest.Tester, backendAddr, mode s const echoEvery = 6 // perform an echo check every N reloads - streamCount := envIntOrDefault(t, "CADDY_STRESS_STREAM_COUNT", 12) - reloadCount := envIntOrDefault(t, "CADDY_STRESS_RELOAD_COUNT", 24) + streamCount := envIntOrDefault(t, "CADDY_STRESS_STREAM_COUNT", defaultStressStreamCount) + reloadCount := envIntOrDefault(t, "CADDY_STRESS_RELOAD_COUNT", defaultStressReloadCount) tester.InitServer(reloadStressConfig(backendAddr, retain, closeDelay, 0), "caddyfile") @@ -209,6 +211,21 @@ func envIntOrDefault(t *testing.T, key string, def int) int { return v } +func stressCloseDelay(t *testing.T) time.Duration { + t.Helper() + + const key = "CADDY_STRESS_CLOSE_DELAY" + raw := strings.TrimSpace(os.Getenv(key)) + if raw == "" { + return defaultStressCloseDelay + } + v, err := time.ParseDuration(raw) + if err != nil || v <= 0 { + t.Fatalf("invalid %s=%q: must be a positive duration", key, raw) + } + return v +} + func loadCaddyfileConfig(t *testing.T, rawConfig string) { t.Helper() diff --git a/modules/caddyhttp/encode/encode.go b/modules/caddyhttp/encode/encode.go index ecf85495a39..0474768f0d7 100644 --- a/modules/caddyhttp/encode/encode.go +++ b/modules/caddyhttp/encode/encode.go @@ -405,11 +405,6 @@ func (rw *responseWriter) ReadFrom(r io.Reader) (int64, error) { // Close writes any remaining buffered response and // deallocates any active resources. func (rw *responseWriter) Close() error { - if caddyhttp.ResponseWriterHijacked(rw.ResponseWriter) { - rw.releaseEncoder() - return nil - } - // didn't write, probably head request if !rw.wroteHeader { cl, err := strconv.Atoi(rw.Header().Get("Content-Length")) diff --git a/modules/caddyhttp/reverseproxy/extended_connect_test.go b/modules/caddyhttp/reverseproxy/extended_connect_test.go new file mode 100644 index 00000000000..5cb27d807e3 --- /dev/null +++ b/modules/caddyhttp/reverseproxy/extended_connect_test.go @@ -0,0 +1,146 @@ +package reverseproxy + +import ( + "io" + "net/http" + "net/http/httptest" + "strings" + "sync" + "testing" + + "go.uber.org/zap" + + "github.com/caddyserver/caddy/v2/modules/caddyhttp" +) + +type extendedConnectCapture struct { + method string + headers http.Header + body []byte + extendedBodyPresent bool + extendedConnectBody []byte +} + +type extendedConnectCaptureTransport struct { + mu sync.Mutex + capture extendedConnectCapture +} + +func (tr *extendedConnectCaptureTransport) RoundTrip(req *http.Request) (*http.Response, error) { + body, err := io.ReadAll(req.Body) + if err != nil { + return nil, err + } + + c := extendedConnectCapture{ + method: req.Method, + headers: req.Header.Clone(), + body: body, + } + if rc, ok := caddyhttp.GetVar(req.Context(), "extended_connect_websocket_body").(io.ReadCloser); ok { + c.extendedBodyPresent = true + c.extendedConnectBody, err = io.ReadAll(rc) + if err != nil { + return nil, err + } + _ = rc.Close() + } + + tr.mu.Lock() + tr.capture = c + tr.mu.Unlock() + + return &http.Response{ + StatusCode: http.StatusOK, + Header: make(http.Header), + Body: io.NopCloser(strings.NewReader("ok")), + Request: req, + }, nil +} + +func (tr *extendedConnectCaptureTransport) Snapshot() extendedConnectCapture { + tr.mu.Lock() + defer tr.mu.Unlock() + return tr.capture +} + +func TestServeHTTPRewritesExtendedConnectWebsocketRequest(t *testing.T) { + tests := []struct { + name string + protoMajor int + proto string + headers map[string]string + }{ + { + name: "h2 extended connect", + protoMajor: 2, + proto: "HTTP/2.0", + headers: map[string]string{ + ":protocol": "websocket", + }, + }, + { + name: "h3 extended connect", + protoMajor: 3, + proto: "websocket", + headers: map[string]string{}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + const payload = "extended-connect-body" + + transport := new(extendedConnectCaptureTransport) + h := &Handler{ + logger: zap.NewNop(), + Transport: transport, + Upstreams: UpstreamPool{ + &Upstream{Host: new(Host), Dial: "127.0.0.1:8443"}, + }, + LoadBalancing: &LoadBalancing{ + SelectionPolicy: &RoundRobinSelection{}, + }, + } + + req := httptest.NewRequest(http.MethodConnect, "http://example.test/upgrade", strings.NewReader(payload)) + req.ProtoMajor = tc.protoMajor + req.Proto = tc.proto + for key, value := range tc.headers { + req.Header.Set(key, value) + } + req = prepareTestRequest(req) + + rr := httptest.NewRecorder() + err := h.ServeHTTP(rr, req, caddyhttp.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error { + return nil + })) + if err != nil { + t.Fatalf("ServeHTTP() error = %v", err) + } + + captured := transport.Snapshot() + if captured.method != http.MethodGet { + t.Fatalf("upstream method = %s, want %s", captured.method, http.MethodGet) + } + if got := captured.headers.Get("Upgrade"); !strings.EqualFold(got, "websocket") { + t.Fatalf("Upgrade header = %q, want websocket", got) + } + if got := captured.headers.Get("Connection"); !strings.EqualFold(got, "Upgrade") { + t.Fatalf("Connection header = %q, want Upgrade", got) + } + if got := captured.headers.Get(":protocol"); got != "" { + t.Fatalf(":protocol header should be removed, got %q", got) + } + if len(captured.body) != 0 { + t.Fatalf("upstream request body length = %d, want 0", len(captured.body)) + } + if !captured.extendedBodyPresent { + t.Fatal("extended_connect_websocket_body variable missing from request context") + } + if string(captured.extendedConnectBody) != payload { + t.Fatalf("extended_connect_websocket_body = %q, want %q", string(captured.extendedConnectBody), payload) + } + }) + } +} diff --git a/modules/caddyhttp/reverseproxy/streaming.go b/modules/caddyhttp/reverseproxy/streaming.go index 9aece01e808..407c4acf9f9 100644 --- a/modules/caddyhttp/reverseproxy/streaming.go +++ b/modules/caddyhttp/reverseproxy/streaming.go @@ -100,14 +100,14 @@ func (h *Handler) handleUpgradeResponse(logger *zap.Logger, wg *sync.WaitGroup, streamTimeout := time.Duration(h.StreamTimeout) var ( - conn io.ReadWriteCloser - brw *bufio.ReadWriter - isH2 bool + conn io.ReadWriteCloser + brw *bufio.ReadWriter + isExtendedConnect bool ) // websocket over http2 or http3 if extended connect is enabled, assuming backend doesn't support this, the request will be modified to http1.1 upgrade // TODO: once we can reliably detect backend support this, it can be removed for those backends if body, ok := caddyhttp.GetVar(req.Context(), "extended_connect_websocket_body").(io.ReadCloser); ok { - isH2 = true + isExtendedConnect = true req.Body = body rw.Header().Del("Upgrade") rw.Header().Del("Connection") @@ -115,13 +115,13 @@ func (h *Handler) handleUpgradeResponse(logger *zap.Logger, wg *sync.WaitGroup, rw.WriteHeader(http.StatusOK) if c := logger.Check(zap.DebugLevel, "upgrading connection"); c != nil { - c.Write(zap.Int("http_version", 2)) + c.Write(zap.Int("http_version", req.ProtoMajor)) } //nolint:bodyclose flushErr := http.NewResponseController(rw).Flush() if flushErr != nil { - if c := h.logger.Check(zap.ErrorLevel, "failed to flush http2 websocket response"); c != nil { + if c := h.logger.Check(zap.ErrorLevel, "failed to flush extended_connect websocket response"); c != nil { c.Write(zap.Error(flushErr)) } return @@ -154,25 +154,6 @@ func (h *Handler) handleUpgradeResponse(logger *zap.Logger, wg *sync.WaitGroup, } } - // For H2 extended connect: close backConn when the request context is - // cancelled (e.g. client disconnects). For HTTP/1.1 hijacked connections - // we skip this because req.Context() may be cancelled when ServeHTTP - // returns early, which would prematurely close the backend connection. - if isH2 { - // adopted from https://github.com/golang/go/commit/8bcf2834afdf6a1f7937390903a41518715ef6f5 - backConnCloseCh := make(chan struct{}) - go func() { - // Ensure that the cancellation of a request closes the backend. - // See issue https://golang.org/issue/35559. - select { - case <-req.Context().Done(): - case <-backConnCloseCh: - } - backConn.Close() - }() - defer close(backConnCloseCh) - } - if err := brw.Flush(); err != nil { if c := logger.Check(zapcore.DebugLevel, "response flush"); c != nil { c.Write(zap.Error(err)) @@ -221,8 +202,8 @@ func (h *Handler) handleUpgradeResponse(logger *zap.Logger, wg *sync.WaitGroup, start := time.Now() - if isH2 { - h.handleH2UpgradeTunnel(streamLogger, streamLevel, wg, conn, backConn, deleteFrontConn, deleteBackConn, bufferSize, streamTimeout, start, finishMetrics, streamFields) + if isExtendedConnect { + h.handleExtendedConnectUpgradeTunnel(streamLogger, streamLevel, wg, conn, backConn, deleteFrontConn, deleteBackConn, bufferSize, streamTimeout, start, finishMetrics, streamFields) } else { h.handleDetachedUpgradeTunnel(streamLogger, streamLevel, conn, backConn, deleteFrontConn, deleteBackConn, bufferSize, streamTimeout, start, finishMetrics, streamFields) // Return immediately without touching wg. finalizeResponse's @@ -230,7 +211,7 @@ func (h *Handler) handleUpgradeResponse(logger *zap.Logger, wg *sync.WaitGroup, } } -func (h *Handler) handleH2UpgradeTunnel( +func (h *Handler) handleExtendedConnectUpgradeTunnel( streamLogger *zap.Logger, streamLevel zapcore.Level, wg *sync.WaitGroup, @@ -244,7 +225,7 @@ func (h *Handler) handleH2UpgradeTunnel( finishMetrics func(result string, duration time.Duration, toBackend, fromBackend int64), streamFields []zap.Field, ) { - // H2 extended connect: ServeHTTP must block because rw and req.Body are + // Extended CONNECT: ServeHTTP must block because rw and req.Body are // only valid while the handler goroutine is running. Defers clean up // when the select below fires and this function returns. defer deleteBackConn() From b9b12025c62c64e712fa13ccb059557a9fcef756 Mon Sep 17 00:00:00 2001 From: WeidiDeng Date: Tue, 21 Apr 2026 10:06:30 +0800 Subject: [PATCH 05/17] record bytes read and written for response writers unless detached --- modules/caddyhttp/responsewriter.go | 90 ++++++++++++++++++------ modules/caddyhttp/responsewriter_test.go | 8 +-- 2 files changed, 74 insertions(+), 24 deletions(-) diff --git a/modules/caddyhttp/responsewriter.go b/modules/caddyhttp/responsewriter.go index d5b43bf42de..a477c7abe1f 100644 --- a/modules/caddyhttp/responsewriter.go +++ b/modules/caddyhttp/responsewriter.go @@ -21,6 +21,8 @@ import ( "io" "net" "net/http" + + "github.com/caddyserver/caddy/v2" ) // ResponseWriterWrapper wraps an underlying ResponseWriter and @@ -71,6 +73,7 @@ type responseRecorder struct { wroteHeader bool stream bool hijacked bool + detached bool readSize *int } @@ -155,12 +158,6 @@ func (rr *responseRecorder) WriteHeader(statusCode int) { // save statusCode always, in case HTTP middleware upgrades websocket // connections by manually setting headers and writing status 101 rr.statusCode = statusCode - if statusCode == http.StatusSwitchingProtocols { - rr.stream = true - rr.wroteHeader = true - rr.ResponseWriterWrapper.WriteHeader(statusCode) - return - } // decide whether we should buffer the response if rr.shouldBuffer == nil { @@ -169,12 +166,12 @@ func (rr *responseRecorder) WriteHeader(statusCode int) { rr.stream = !rr.shouldBuffer(rr.statusCode, rr.ResponseWriterWrapper.Header()) } - // 1xx responses aren't final; just informational - if statusCode < 100 || statusCode > 199 { + // 1xx responses except 101 aren't final; just informational + if statusCode < 100 || statusCode > 199 || statusCode == http.StatusSwitchingProtocols { rr.wroteHeader = true } - // if informational or not buffered, immediately write header + // if 1xx or not buffered, immediately write header if rr.stream || (100 <= statusCode && statusCode <= 199) { rr.ResponseWriterWrapper.WriteHeader(statusCode) } @@ -230,8 +227,12 @@ func (rr *responseRecorder) Buffered() bool { return !rr.stream } -func (rr *responseRecorder) Hijacked() bool { - return rr.hijacked +func (rr *responseRecorder) DetachAfterHijack(detached bool) bool { + if rr.hijacked { + return false + } + rr.detached = detached + return true } func (rr *responseRecorder) WriteResponse() error { @@ -268,6 +269,12 @@ func (rr *responseRecorder) setReadSize(size *int) { } func (rr *responseRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) { + if !rr.wroteHeader { + // hijacking without writing status code first works as long as subsequent writes follows http1.1 + // wire format, but it will show up with a status code of 0 in the access log and bytes written + // will include response headers. + caddy.Log().Debug("hijacking without writing status code first") + } //nolint:bodyclose conn, brw, err := http.NewResponseController(rr.ResponseWriterWrapper).Hijack() if err != nil { @@ -276,13 +283,16 @@ func (rr *responseRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) { rr.hijacked = true rr.stream = true rr.wroteHeader = true - // Per http documentation, returned bufio.Writer is empty, but bufio.Read maybe not. - // Return the raw hijacked connection so upgraded stream traffic does not keep - // traversing the response recorder hot path. + if rr.detached { + return conn, brw, nil + } + // Per http documentation, returned bufio.Writer is empty, but bufio.Read maybe not + conn = &hijackedConn{conn, rr} brw.Writer.Reset(conn) buffered := brw.Reader.Buffered() if buffered != 0 { + conn.(*hijackedConn).updateReadSize(buffered) data, _ := brw.Peek(buffered) brw.Reader.Reset(io.MultiReader(bytes.NewReader(data), conn)) // peek to make buffered data appear, as Reset will make it 0 @@ -293,12 +303,49 @@ func (rr *responseRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) { return conn, brw, nil } -// ResponseWriterHijacked reports whether w or one of its wrapped response -// writers has been hijacked. -func ResponseWriterHijacked(w http.ResponseWriter) bool { +// used to track the size of hijacked response writers +type hijackedConn struct { + net.Conn + rr *responseRecorder +} + +func (hc *hijackedConn) updateReadSize(n int) { + if hc.rr.readSize != nil { + *hc.rr.readSize += n + } +} + +func (hc *hijackedConn) Read(p []byte) (int, error) { + n, err := hc.Conn.Read(p) + hc.updateReadSize(n) + return n, err +} + +func (hc *hijackedConn) WriteTo(w io.Writer) (int64, error) { + n, err := io.Copy(w, hc.Conn) + hc.updateReadSize(int(n)) + return n, err +} + +func (hc *hijackedConn) Write(p []byte) (int, error) { + n, err := hc.Conn.Write(p) + hc.rr.size += n + return n, err +} + +func (hc *hijackedConn) ReadFrom(r io.Reader) (int64, error) { + n, err := io.Copy(hc.Conn, r) + hc.rr.size += int(n) + return n, err +} + +// DetachResponseWriterAfterHijack detaches w or one of its wrapped response +// writers when it's hijacked. Returns true if not already hijacked. +// When detached, bytes read or written stats will not be recorded for the hijacked connection, and it's safe to use the connection after http middleware returns. +func DetachResponseWriterAfterHijack(w http.ResponseWriter, detached bool) bool { for w != nil { - if hijacked, ok := w.(interface{ Hijacked() bool }); ok && hijacked.Hijacked() { - return true + if detacher, ok := w.(interface{ DetachAfterHijack(bool) bool }); ok { + return detacher.DetachAfterHijack(detached) } unwrapper, ok := w.(interface{ Unwrap() http.ResponseWriter }) if !ok { @@ -321,7 +368,7 @@ type ResponseRecorder interface { Status() int Buffer() *bytes.Buffer Buffered() bool - Hijacked() bool + DetachAfterHijack(bool) bool Size() int WriteResponse() error } @@ -341,4 +388,7 @@ var ( // see PR #5022 (25%-50% speedup) _ io.ReaderFrom = (*ResponseWriterWrapper)(nil) _ io.ReaderFrom = (*responseRecorder)(nil) + _ io.ReaderFrom = (*hijackedConn)(nil) + + _ io.WriterTo = (*hijackedConn)(nil) ) diff --git a/modules/caddyhttp/responsewriter_test.go b/modules/caddyhttp/responsewriter_test.go index 72e416db1e4..ec8c3b5ab38 100644 --- a/modules/caddyhttp/responsewriter_test.go +++ b/modules/caddyhttp/responsewriter_test.go @@ -246,11 +246,11 @@ func TestResponseRecorderSwitchingProtocolsIsHijackAware(t *testing.T) { } defer conn.Close() - if !rr.Hijacked() { - t.Fatal("response recorder should report hijacked state") + if rr.DetachAfterHijack(true) { + t.Fatal("response recorder should report hijacked state by returning false") } - if !ResponseWriterHijacked(rr) { - t.Fatal("ResponseWriterHijacked() should report true after hijack") + if DetachResponseWriterAfterHijack(rr, true) { + t.Fatal("DetachResponseWriterAfterHijack() should report false after hijack") } if err := rr.WriteResponse(); err != nil { t.Fatalf("WriteResponse() after hijack returned error: %v", err) From e7055d85a4868a62d5ca6d4a5949620654f91c2e Mon Sep 17 00:00:00 2001 From: WeidiDeng Date: Tue, 21 Apr 2026 10:07:13 +0800 Subject: [PATCH 06/17] simplify streaming handling --- modules/caddyhttp/encode/encode.go | 13 +- .../caddyhttp/reverseproxy/reverseproxy.go | 7 +- modules/caddyhttp/reverseproxy/streaming.go | 141 ++++-------------- 3 files changed, 33 insertions(+), 128 deletions(-) diff --git a/modules/caddyhttp/encode/encode.go b/modules/caddyhttp/encode/encode.go index 0474768f0d7..ac995c37b32 100644 --- a/modules/caddyhttp/encode/encode.go +++ b/modules/caddyhttp/encode/encode.go @@ -422,20 +422,13 @@ func (rw *responseWriter) Close() error { var err error if rw.w != nil { err = rw.w.Close() - rw.releaseEncoder() + rw.w.Reset(nil) + rw.config.writerPools[rw.encodingName].Put(rw.w) + rw.w = nil } return err } -func (rw *responseWriter) releaseEncoder() { - if rw.w == nil { - return - } - rw.w.Reset(nil) - rw.config.writerPools[rw.encodingName].Put(rw.w) - rw.w = nil -} - // Unwrap returns the underlying ResponseWriter. func (rw *responseWriter) Unwrap() http.ResponseWriter { return rw.ResponseWriter diff --git a/modules/caddyhttp/reverseproxy/reverseproxy.go b/modules/caddyhttp/reverseproxy/reverseproxy.go index 3ebded469df..e827d5bac94 100644 --- a/modules/caddyhttp/reverseproxy/reverseproxy.go +++ b/modules/caddyhttp/reverseproxy/reverseproxy.go @@ -191,6 +191,9 @@ type Handler struct { // Connections using upstreams that are removed are closed during cleanup. // By default this is false, preserving legacy behavior where upgraded // connections are closed on reload (optionally delayed by stream_close_delay). + // Only http1.1 websocket connections are affected, websockets for h2/h3 are not affected. + // If true, bytes transferred for http1.1 in the access logs will be zero but those stats + // can be found in the stream logs for http1/2/3 regardless if this is enabled. StreamRetainOnReload bool `json:"stream_retain_on_reload,omitempty"` // Controls logging behavior for upgraded stream lifecycle events. @@ -1274,9 +1277,7 @@ func (h *Handler) finalizeResponse( ) error { // deal with 101 Switching Protocols responses: (WebSocket, h2c, etc) if res.StatusCode == http.StatusSwitchingProtocols { - var wg sync.WaitGroup - h.handleUpgradeResponse(logger, &wg, rw, req, res, upstreamAddr) - wg.Wait() + h.handleUpgradeResponse(logger, rw, req, res, upstreamAddr) return nil } diff --git a/modules/caddyhttp/reverseproxy/streaming.go b/modules/caddyhttp/reverseproxy/streaming.go index 407c4acf9f9..3835653df1a 100644 --- a/modules/caddyhttp/reverseproxy/streaming.go +++ b/modules/caddyhttp/reverseproxy/streaming.go @@ -40,12 +40,12 @@ import ( "github.com/caddyserver/caddy/v2/modules/caddyhttp" ) -type h2ReadWriteCloser struct { +type extendedConnectReadWriteCloser struct { io.ReadCloser http.ResponseWriter } -func (rwc h2ReadWriteCloser) Write(p []byte) (n int, err error) { +func (rwc extendedConnectReadWriteCloser) Write(p []byte) (n int, err error) { n, err = rwc.ResponseWriter.Write(p) if err != nil { return 0, err @@ -59,7 +59,7 @@ func (rwc h2ReadWriteCloser) Write(p []byte) (n int, err error) { return n, nil } -func (h *Handler) handleUpgradeResponse(logger *zap.Logger, wg *sync.WaitGroup, rw http.ResponseWriter, req *http.Request, res *http.Response, upstreamAddr string) { +func (h *Handler) handleUpgradeResponse(logger *zap.Logger, rw http.ResponseWriter, req *http.Request, res *http.Response, upstreamAddr string) { reqUpType := upgradeType(req.Header) resUpType := upgradeType(res.Header) @@ -99,15 +99,25 @@ func (h *Handler) handleUpgradeResponse(logger *zap.Logger, wg *sync.WaitGroup, bufferSize := h.StreamBufferSize streamTimeout := time.Duration(h.StreamTimeout) + if h.StreamRetainOnReload { + // the return value should be true as it's not hijacked yet, but some middleware may wrap response writers incorrectly + if !caddyhttp.DetachResponseWriterAfterHijack(rw, true) { + if c := logger.Check(zap.DebugLevel, "detaching connection failed"); c != nil { + c.Write(zap.String("tip", "check if your response writers have an Unwrap method or if already hijacked")) + } + } + } + var ( - conn io.ReadWriteCloser - brw *bufio.ReadWriter - isExtendedConnect bool + conn io.ReadWriteCloser + brw *bufio.ReadWriter + detached = h.StreamRetainOnReload ) // websocket over http2 or http3 if extended connect is enabled, assuming backend doesn't support this, the request will be modified to http1.1 upgrade // TODO: once we can reliably detect backend support this, it can be removed for those backends if body, ok := caddyhttp.GetVar(req.Context(), "extended_connect_websocket_body").(io.ReadCloser); ok { - isExtendedConnect = true + // websocket over extended connect can't be detached. rw and req.Body are only valid while the handler goroutine is running + detached = false req.Body = body rw.Header().Del("Upgrade") rw.Header().Del("Connection") @@ -126,7 +136,7 @@ func (h *Handler) handleUpgradeResponse(logger *zap.Logger, wg *sync.WaitGroup, } return } - conn = h2ReadWriteCloser{req.Body, rw} + conn = extendedConnectReadWriteCloser{req.Body, rw} // bufio is not needed, use minimal buffer brw = bufio.NewReadWriter(bufio.NewReaderSize(conn, 1), bufio.NewWriterSize(conn, 1)) } else { @@ -202,35 +212,20 @@ func (h *Handler) handleUpgradeResponse(logger *zap.Logger, wg *sync.WaitGroup, start := time.Now() - if isExtendedConnect { - h.handleExtendedConnectUpgradeTunnel(streamLogger, streamLevel, wg, conn, backConn, deleteFrontConn, deleteBackConn, bufferSize, streamTimeout, start, finishMetrics, streamFields) + if !detached { + h.handleUpgradeTunnel(streamLogger, streamLevel, conn, backConn, deleteFrontConn, deleteBackConn, bufferSize, streamTimeout, start, finishMetrics, streamFields) } else { - h.handleDetachedUpgradeTunnel(streamLogger, streamLevel, conn, backConn, deleteFrontConn, deleteBackConn, bufferSize, streamTimeout, start, finishMetrics, streamFields) - // Return immediately without touching wg. finalizeResponse's - // wg.Wait() returns at once since wg was never incremented. + // start a new goroutine + go h.handleUpgradeTunnel(streamLogger, streamLevel, conn, backConn, deleteFrontConn, deleteBackConn, bufferSize, streamTimeout, start, finishMetrics, streamFields) } } -func (h *Handler) handleExtendedConnectUpgradeTunnel( - streamLogger *zap.Logger, - streamLevel zapcore.Level, - wg *sync.WaitGroup, - conn io.ReadWriteCloser, - backConn io.ReadWriteCloser, - deleteFrontConn func(), - deleteBackConn func(), - bufferSize int, - streamTimeout time.Duration, - start time.Time, - finishMetrics func(result string, duration time.Duration, toBackend, fromBackend int64), - streamFields []zap.Field, -) { - // Extended CONNECT: ServeHTTP must block because rw and req.Body are - // only valid while the handler goroutine is running. Defers clean up - // when the select below fires and this function returns. +// handleUpgradeTunnel returns when transfer is done. +func (h *Handler) handleUpgradeTunnel(streamLogger *zap.Logger, streamLevel zapcore.Level, conn io.ReadWriteCloser, backConn io.ReadWriteCloser, deleteFrontConn func(), deleteBackConn func(), bufferSize int, streamTimeout time.Duration, start time.Time, finishMetrics func(result string, duration time.Duration, toBackend int64, fromBackend int64), streamFields []zap.Field) { defer deleteBackConn() defer deleteFrontConn() var ( + wg sync.WaitGroup toBackend int64 fromBackend int64 result string @@ -243,7 +238,7 @@ func (h *Handler) handleExtendedConnectUpgradeTunnel( spc := switchProtocolCopier{ user: conn, backend: backConn, - wg: wg, + wg: &wg, bufferSize: bufferSize, sent: &toBackend, received: &fromBackend, @@ -290,90 +285,6 @@ func (h *Handler) handleExtendedConnectUpgradeTunnel( } } -func (h *Handler) handleDetachedUpgradeTunnel( - streamLogger *zap.Logger, - streamLevel zapcore.Level, - conn io.ReadWriteCloser, - backConn io.ReadWriteCloser, - deleteFrontConn func(), - deleteBackConn func(), - bufferSize int, - streamTimeout time.Duration, - start time.Time, - finishMetrics func(result string, duration time.Duration, toBackend, fromBackend int64), - streamFields []zap.Field, -) { - // HTTP/1.1 hijacked connection: launch a detached goroutine so that - // ServeHTTP can return immediately, allowing the Handler to be GC'd - // after a config reload. The goroutine captures only tunnel (a small - // *tunnelState), logger, conn/backConn, and scalar config values. - go func() { - var ( - toBackend int64 - fromBackend int64 - result = "closed" - ) - defer deleteBackConn() - defer deleteFrontConn() - defer func() { - finishMetrics(result, time.Since(start), toBackend, fromBackend) - if c := streamLogger.Check(streamLevel, "connection closed"); c != nil { - fields := append([]zap.Field{}, streamFields...) - fields = append(fields, - zap.Duration("duration", time.Since(start)), - zap.Int64("bytes_to_backend", toBackend), - zap.Int64("bytes_from_backend", fromBackend), - ) - c.Write(fields...) - } - }() - - var innerWg sync.WaitGroup - // when a stream timeout is encountered, no error will be read from errc - // a buffer size of 2 will allow both the read and write goroutines to send the error and exit - // see: https://github.com/caddyserver/caddy/issues/7418 - errc := make(chan error, 2) - spc := switchProtocolCopier{ - user: conn, - backend: backConn, - wg: &innerWg, - bufferSize: bufferSize, - sent: &toBackend, - received: &fromBackend, - } - innerWg.Add(2) - - var timeoutc <-chan time.Time - if streamTimeout > 0 { - timer := time.NewTimer(streamTimeout) - defer timer.Stop() - timeoutc = timer.C - } - - go spc.copyToBackend(errc) - go spc.copyFromBackend(errc) - select { - case err := <-errc: - result = classifyStreamResult(err) - if c := streamLogger.Check(streamLevel, "streaming error"); c != nil { - c.Write(zap.Error(err)) - } - case t := <-timeoutc: - result = "timeout" - if c := streamLogger.Check(streamLevel, "stream timed out"); c != nil { - c.Write(zap.Time("timeout", t)) - } - } - - // Close both ends to unblock the still-running copy goroutine, - // then wait for it to finish so byte counts are accurate before - // the deferred log fires. - conn.Close() - backConn.Close() - innerWg.Wait() - }() -} - func classifyStreamResult(err error) string { if err == nil || errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) || errors.Is(err, context.Canceled) { return "closed" From cee04ab28e8889f36dfb3f2a62bceb936147e273 Mon Sep 17 00:00:00 2001 From: WeidiDeng Date: Tue, 21 Apr 2026 11:46:28 +0800 Subject: [PATCH 07/17] correctly close detached streams --- .../caddyhttp/reverseproxy/reverseproxy.go | 47 ++++++++++++++----- modules/caddyhttp/reverseproxy/streaming.go | 36 ++++++++------ 2 files changed, 58 insertions(+), 25 deletions(-) diff --git a/modules/caddyhttp/reverseproxy/reverseproxy.go b/modules/caddyhttp/reverseproxy/reverseproxy.go index e827d5bac94..08aa4eda660 100644 --- a/modules/caddyhttp/reverseproxy/reverseproxy.go +++ b/modules/caddyhttp/reverseproxy/reverseproxy.go @@ -322,6 +322,10 @@ func (h *Handler) Provision(ctx caddy.Context) error { } } + if h.StreamRetainOnReload { + registerDetachedTunnelStates(h.tunnel) + } + // warn about unsafe buffering config if h.RequestBuffers == -1 || h.ResponseBuffers == -1 { h.logger.Warn("UNLIMITED BUFFERING: buffering is enabled without any cap on buffer size, which can result in OOM crashes") @@ -522,19 +526,40 @@ func (h Handler) streamLoggerForRequest(req *http.Request) *zap.Logger { return caddy.Log().Named(name) } -// Cleanup cleans up the resources made by h. -func (h *Handler) Cleanup() error { - if !h.StreamRetainOnReload { - // Legacy behavior: close all upgraded connections on reload, either - // immediately or after StreamCloseDelay. - err := h.tunnel.cleanupConnections() - for _, upstream := range h.Upstreams { - _, _ = hosts.Delete(upstream.String()) +var ( + detachedTunnelStates = make(map[*tunnelState]struct{}) + detachedTunnelStatesMu sync.Mutex +) + +func registerDetachedTunnelStates(ts *tunnelState) { + detachedTunnelStatesMu.Lock() + defer detachedTunnelStatesMu.Unlock() + detachedTunnelStates[ts] = struct{}{} +} + +func notifyDetachedTunnelStatesOfUpstreamRemoval(upstream string, self *tunnelState) error { + detachedTunnelStatesMu.Lock() + defer detachedTunnelStatesMu.Unlock() + + var err error + for tunnel := range detachedTunnelStates { + if closeErr := tunnel.closeConnectionsForUpstream(upstream); closeErr != nil && tunnel == self && err == nil { + err = closeErr } - return err } + return err +} - var err error +func unregisterDetachedTunnelStates(ts *tunnelState) { + detachedTunnelStatesMu.Lock() + defer detachedTunnelStatesMu.Unlock() + delete(detachedTunnelStates, ts) +} + +// Cleanup cleans up the resources made by h. +func (h *Handler) Cleanup() error { + // even if StreamRetainOnReload is true, extended connect websockets may still be running + err := h.tunnel.cleanupAttachedConnections() for _, upstream := range h.Upstreams { // hosts.Delete returns deleted=true when the ref count reaches zero, // meaning no other active config references this upstream. In that @@ -542,7 +567,7 @@ func (h *Handler) Cleanup() error { // to their natural end since the upstream is still in use. deleted, _ := hosts.Delete(upstream.String()) if deleted { - if closeErr := h.tunnel.closeConnectionsForUpstream(upstream.String()); closeErr != nil && err == nil { + if closeErr := notifyDetachedTunnelStatesOfUpstreamRemoval(upstream.String(), h.tunnel); closeErr != nil && err == nil { err = closeErr } } diff --git a/modules/caddyhttp/reverseproxy/streaming.go b/modules/caddyhttp/reverseproxy/streaming.go index 3835653df1a..ccb056d5828 100644 --- a/modules/caddyhttp/reverseproxy/streaming.go +++ b/modules/caddyhttp/reverseproxy/streaming.go @@ -197,8 +197,8 @@ func (h *Handler) handleUpgradeResponse(logger *zap.Logger, rw http.ResponseWrit } return nil } - deleteFrontConn := tunnel.registerConnection(conn, gracefulClose(conn, false), upstreamAddr) - deleteBackConn := tunnel.registerConnection(backConn, gracefulClose(backConn, true), upstreamAddr) + deleteFrontConn := tunnel.registerConnection(conn, gracefulClose(conn, false), detached, upstreamAddr) + deleteBackConn := tunnel.registerConnection(backConn, gracefulClose(backConn, true), detached, upstreamAddr) if h.streamLogsSkipHandshake() { caddyhttp.SetVar(req.Context(), caddyhttp.LogSkipVar, true) } @@ -442,6 +442,7 @@ func (h Handler) copyBuffer(dst io.Writer, src io.Reader, buf []byte, logger *za type openConnection struct { conn io.ReadWriteCloser gracefulClose func() error + detached bool upstream string } @@ -464,29 +465,36 @@ func newTunnelState(logger *zap.Logger, closeDelay time.Duration) *tunnelState { // registerConnection stores conn in the tracking map. The caller must invoke // the returned del func when the connection is done. -func (ts *tunnelState) registerConnection(conn io.ReadWriteCloser, gracefulClose func() error, upstream string) (del func()) { +func (ts *tunnelState) registerConnection(conn io.ReadWriteCloser, gracefulClose func() error, detached bool, upstream string) (del func()) { ts.mu.Lock() - ts.connections[conn] = openConnection{conn, gracefulClose, upstream} + ts.connections[conn] = openConnection{conn, gracefulClose, detached, upstream} ts.mu.Unlock() return func() { ts.mu.Lock() delete(ts.connections, conn) - if len(ts.connections) == 0 && ts.closeTimer != nil { - if ts.closeTimer.Stop() { - ts.logger.Debug("stopped streaming connections close timer - all connections are already closed") + if len(ts.connections) == 0 { + unregisterDetachedTunnelStates(ts) + if ts.closeTimer != nil { + if ts.closeTimer.Stop() { + ts.logger.Debug("stopped streaming connections close timer - all connections are already closed") + } + ts.closeTimer = nil } - ts.closeTimer = nil } ts.mu.Unlock() } } -// closeConnections closes all tracked connections. -func (ts *tunnelState) closeConnections() error { +// closeAttachedConnections closes all tracked attached connections. +func (ts *tunnelState) closeAttachedConnections() error { var err error ts.mu.Lock() defer ts.mu.Unlock() for _, oc := range ts.connections { + // detached connections are only closed when the upstream is gone from the config + if oc.detached { + continue + } if oc.gracefulClose != nil { if gracefulErr := oc.gracefulClose(); gracefulErr != nil && err == nil { err = gracefulErr @@ -499,11 +507,11 @@ func (ts *tunnelState) closeConnections() error { return err } -// cleanupConnections closes upgraded connections. Depending on closeDelay it +// cleanupAttachedConnections closes upgraded attached connections. Depending on closeDelay it // does that either immediately or after a timer. -func (ts *tunnelState) cleanupConnections() error { +func (ts *tunnelState) cleanupAttachedConnections() error { if ts.closeDelay == 0 { - return ts.closeConnections() + return ts.closeAttachedConnections() } ts.mu.Lock() @@ -514,7 +522,7 @@ func (ts *tunnelState) cleanupConnections() error { if c := ts.logger.Check(zapcore.DebugLevel, "closing streaming connections after delay"); c != nil { c.Write(zap.Duration("delay", delay)) } - err := ts.closeConnections() + err := ts.closeAttachedConnections() if err != nil { if c := ts.logger.Check(zapcore.ErrorLevel, "failed to close connections after delay"); c != nil { c.Write( From ccc76ac1f6c528e6d6bbef32dd90025635acb97f Mon Sep 17 00:00:00 2001 From: WeidiDeng Date: Tue, 21 Apr 2026 11:48:00 +0800 Subject: [PATCH 08/17] make handleUpgradeTunnel a standalone func --- modules/caddyhttp/reverseproxy/streaming.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/modules/caddyhttp/reverseproxy/streaming.go b/modules/caddyhttp/reverseproxy/streaming.go index ccb056d5828..d664e66d872 100644 --- a/modules/caddyhttp/reverseproxy/streaming.go +++ b/modules/caddyhttp/reverseproxy/streaming.go @@ -213,15 +213,15 @@ func (h *Handler) handleUpgradeResponse(logger *zap.Logger, rw http.ResponseWrit start := time.Now() if !detached { - h.handleUpgradeTunnel(streamLogger, streamLevel, conn, backConn, deleteFrontConn, deleteBackConn, bufferSize, streamTimeout, start, finishMetrics, streamFields) + handleUpgradeTunnel(streamLogger, streamLevel, conn, backConn, deleteFrontConn, deleteBackConn, bufferSize, streamTimeout, start, finishMetrics, streamFields) } else { // start a new goroutine - go h.handleUpgradeTunnel(streamLogger, streamLevel, conn, backConn, deleteFrontConn, deleteBackConn, bufferSize, streamTimeout, start, finishMetrics, streamFields) + go handleUpgradeTunnel(streamLogger, streamLevel, conn, backConn, deleteFrontConn, deleteBackConn, bufferSize, streamTimeout, start, finishMetrics, streamFields) } } // handleUpgradeTunnel returns when transfer is done. -func (h *Handler) handleUpgradeTunnel(streamLogger *zap.Logger, streamLevel zapcore.Level, conn io.ReadWriteCloser, backConn io.ReadWriteCloser, deleteFrontConn func(), deleteBackConn func(), bufferSize int, streamTimeout time.Duration, start time.Time, finishMetrics func(result string, duration time.Duration, toBackend int64, fromBackend int64), streamFields []zap.Field) { +func handleUpgradeTunnel(streamLogger *zap.Logger, streamLevel zapcore.Level, conn io.ReadWriteCloser, backConn io.ReadWriteCloser, deleteFrontConn func(), deleteBackConn func(), bufferSize int, streamTimeout time.Duration, start time.Time, finishMetrics func(result string, duration time.Duration, toBackend int64, fromBackend int64), streamFields []zap.Field) { defer deleteBackConn() defer deleteFrontConn() var ( From 6ba6cf5d13085c983e884b4c120024367e0662c1 Mon Sep 17 00:00:00 2001 From: WeidiDeng Date: Tue, 21 Apr 2026 11:54:36 +0800 Subject: [PATCH 09/17] fix tests --- modules/caddyhttp/reverseproxy/streaming_test.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/modules/caddyhttp/reverseproxy/streaming_test.go b/modules/caddyhttp/reverseproxy/streaming_test.go index e6a3ce3a1a6..a7e3504b7f7 100644 --- a/modules/caddyhttp/reverseproxy/streaming_test.go +++ b/modules/caddyhttp/reverseproxy/streaming_test.go @@ -117,8 +117,8 @@ func TestHandlerCleanupLegacyModeClosesAllConnections(t *testing.T) { ts := newTunnelState(caddy.Log(), 0) connA := newTrackingReadWriteCloser() connB := newTrackingReadWriteCloser() - ts.registerConnection(connA, nil, "a") - ts.registerConnection(connB, nil, "b") + ts.registerConnection(connA, nil, false, "a") + ts.registerConnection(connB, nil, false, "b") h := &Handler{ tunnel: ts, @@ -136,7 +136,7 @@ func TestHandlerCleanupLegacyModeClosesAllConnections(t *testing.T) { func TestHandlerCleanupLegacyModeHonorsDelay(t *testing.T) { ts := newTunnelState(caddy.Log(), 40*time.Millisecond) conn := newTrackingReadWriteCloser() - ts.registerConnection(conn, nil, "a") + ts.registerConnection(conn, nil, false, "a") h := &Handler{ tunnel: ts, @@ -175,8 +175,8 @@ func TestHandlerCleanupRetainModeClosesOnlyRemovedUpstreams(t *testing.T) { ts := newTunnelState(caddy.Log(), 0) connA := newTrackingReadWriteCloser() connB := newTrackingReadWriteCloser() - ts.registerConnection(connA, nil, upstreamA) - ts.registerConnection(connB, nil, upstreamB) + ts.registerConnection(connA, nil, true, upstreamA) + ts.registerConnection(connB, nil, true, upstreamB) h := &Handler{ tunnel: ts, From f970f397e2ad8fc0d037e094fbf7ff72558ceeeb Mon Sep 17 00:00:00 2001 From: WeidiDeng Date: Tue, 21 Apr 2026 14:55:46 +0800 Subject: [PATCH 10/17] fix tests --- modules/caddyhttp/responsewriter_test.go | 6 +++--- modules/caddyhttp/reverseproxy/streaming_test.go | 1 + 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/modules/caddyhttp/responsewriter_test.go b/modules/caddyhttp/responsewriter_test.go index ec8c3b5ab38..4111164815e 100644 --- a/modules/caddyhttp/responsewriter_test.go +++ b/modules/caddyhttp/responsewriter_test.go @@ -226,9 +226,6 @@ func TestResponseRecorderSwitchingProtocolsIsHijackAware(t *testing.T) { }) rr.WriteHeader(http.StatusSwitchingProtocols) - if rr.Buffered() { - t.Fatal("101 switching protocols response should not remain buffered") - } if rr.Status() != http.StatusSwitchingProtocols { t.Fatalf("status = %d, want %d", rr.Status(), http.StatusSwitchingProtocols) } @@ -246,6 +243,9 @@ func TestResponseRecorderSwitchingProtocolsIsHijackAware(t *testing.T) { } defer conn.Close() + if rr.Buffered() { + t.Fatal("hijacked response should not remain buffered") + } if rr.DetachAfterHijack(true) { t.Fatal("response recorder should report hijacked state by returning false") } diff --git a/modules/caddyhttp/reverseproxy/streaming_test.go b/modules/caddyhttp/reverseproxy/streaming_test.go index a7e3504b7f7..18acba3f474 100644 --- a/modules/caddyhttp/reverseproxy/streaming_test.go +++ b/modules/caddyhttp/reverseproxy/streaming_test.go @@ -173,6 +173,7 @@ func TestHandlerCleanupRetainModeClosesOnlyRemovedUpstreams(t *testing.T) { }) ts := newTunnelState(caddy.Log(), 0) + registerDetachedTunnelStates(ts) connA := newTrackingReadWriteCloser() connB := newTrackingReadWriteCloser() ts.registerConnection(connA, nil, true, upstreamA) From ed44e4d3f65964c5f7e9309c1d3bbacfd166cb34 Mon Sep 17 00:00:00 2001 From: WeidiDeng Date: Tue, 21 Apr 2026 14:55:56 +0800 Subject: [PATCH 11/17] change the log level if hijacking without writing a status code first --- modules/caddyhttp/responsewriter.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/caddyhttp/responsewriter.go b/modules/caddyhttp/responsewriter.go index a477c7abe1f..c593c616244 100644 --- a/modules/caddyhttp/responsewriter.go +++ b/modules/caddyhttp/responsewriter.go @@ -272,8 +272,8 @@ func (rr *responseRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) { if !rr.wroteHeader { // hijacking without writing status code first works as long as subsequent writes follows http1.1 // wire format, but it will show up with a status code of 0 in the access log and bytes written - // will include response headers. - caddy.Log().Debug("hijacking without writing status code first") + // will include response headers. Response headers won't be present in the log if not set on the response writer. + caddy.Log().Warn("hijacking without writing status code first") } //nolint:bodyclose conn, brw, err := http.NewResponseController(rr.ResponseWriterWrapper).Hijack() From 733aaba10203a6e478b031001d4012a7a37c09ce Mon Sep 17 00:00:00 2001 From: WeidiDeng Date: Tue, 21 Apr 2026 17:09:40 +0800 Subject: [PATCH 12/17] only clean up connections when stopped --- modules/caddyhttp/reverseproxy/streaming.go | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/modules/caddyhttp/reverseproxy/streaming.go b/modules/caddyhttp/reverseproxy/streaming.go index d664e66d872..753e1e6cb08 100644 --- a/modules/caddyhttp/reverseproxy/streaming.go +++ b/modules/caddyhttp/reverseproxy/streaming.go @@ -451,6 +451,7 @@ type tunnelState struct { connections map[io.ReadWriteCloser]openConnection closeTimer *time.Timer closeDelay time.Duration + stopped bool mu sync.Mutex logger *zap.Logger } @@ -472,7 +473,7 @@ func (ts *tunnelState) registerConnection(conn io.ReadWriteCloser, gracefulClose return func() { ts.mu.Lock() delete(ts.connections, conn) - if len(ts.connections) == 0 { + if len(ts.connections) == 0 && ts.stopped { unregisterDetachedTunnelStates(ts) if ts.closeTimer != nil { if ts.closeTimer.Stop() { @@ -490,6 +491,7 @@ func (ts *tunnelState) closeAttachedConnections() error { var err error ts.mu.Lock() defer ts.mu.Unlock() + ts.stopped = true for _, oc := range ts.connections { // detached connections are only closed when the upstream is gone from the config if oc.detached { @@ -654,6 +656,9 @@ func (ts *tunnelState) closeConnectionsForUpstream(addr string) error { var err error ts.mu.Lock() defer ts.mu.Unlock() + if !ts.stopped { + return nil + } for _, oc := range ts.connections { if oc.upstream != addr { continue From 1b8d60c4592524e7c4627ad9b9138eee9248820d Mon Sep 17 00:00:00 2001 From: Francis Lavoie Date: Tue, 21 Apr 2026 07:29:23 -0400 Subject: [PATCH 13/17] Move type and const down to the bottom --- .../caddyhttp/reverseproxy/reverseproxy.go | 46 +++++++++---------- 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/modules/caddyhttp/reverseproxy/reverseproxy.go b/modules/caddyhttp/reverseproxy/reverseproxy.go index 08aa4eda660..02bcdcb4807 100644 --- a/modules/caddyhttp/reverseproxy/reverseproxy.go +++ b/modules/caddyhttp/reverseproxy/reverseproxy.go @@ -266,29 +266,6 @@ type Handler struct { streamLogLoggerName string } -// StreamLogs controls logging for upgraded stream lifecycle events. -type StreamLogs struct { - // The minimum level at which stream lifecycle events are logged. - // Supported values are debug, info, warn, and error. Default: debug. - Level string `json:"level,omitempty"` - - // Logger name for stream lifecycle logs. Default: "http.handlers.reverse_proxy.stream". - // Special value "access" uses the access logger namespace and, if set, - // respects the first value in access_logger_names/log_name for the request. - LoggerName string `json:"logger_name,omitempty"` - - // If true, suppresses the access log entry normally emitted when an - // upgraded stream handshake completes and the request unwinds. - SkipHandshake bool `json:"skip_handshake,omitempty"` -} - -const ( - defaultStreamLogLevel = zapcore.DebugLevel - defaultStreamLoggerName = "http.handlers.reverse_proxy.stream" - streamLoggerNameUseAccess = "access" - defaultAccessLoggerBase = "http.log.access" -) - // CaddyModule returns the Caddy module information. func (Handler) CaddyModule() caddy.ModuleInfo { return caddy.ModuleInfo{ @@ -1891,6 +1868,22 @@ func (brc bodyReadCloser) Close() error { return nil } +// StreamLogs controls logging for upgraded stream lifecycle events. +type StreamLogs struct { + // The minimum level at which stream lifecycle events are logged. + // Supported values are debug, info, warn, and error. Default: debug. + Level string `json:"level,omitempty"` + + // Logger name for stream lifecycle logs. Default: "http.handlers.reverse_proxy.stream". + // Special value "access" uses the access logger namespace and, if set, + // respects the first value in access_logger_names/log_name for the request. + LoggerName string `json:"logger_name,omitempty"` + + // If true, suppresses the access log entry normally emitted when an + // upgraded stream handshake completes and the request unwinds. + SkipHandshake bool `json:"skip_handshake,omitempty"` +} + // bufPool is used for buffering requests and responses. var bufPool = sync.Pool{ New: func() any { @@ -1936,6 +1929,13 @@ const proxyHandleResponseContextCtxKey caddy.CtxKey = "reverse_proxy_handle_resp // errNoUpstream occurs when there are no upstream available. var errNoUpstream = fmt.Errorf("no upstreams available") +const ( + defaultStreamLogLevel = zapcore.DebugLevel + defaultStreamLoggerName = "http.handlers.reverse_proxy.stream" + streamLoggerNameUseAccess = "access" + defaultAccessLoggerBase = "http.log.access" +) + // Interface guards var ( _ caddy.Provisioner = (*Handler)(nil) From e3b1bf80f4f14184ed8a5c3a8084ca6a2150374c Mon Sep 17 00:00:00 2001 From: Francis Lavoie Date: Tue, 21 Apr 2026 08:12:57 -0400 Subject: [PATCH 14/17] Rename to tunnelTracker, reflow some comments --- modules/caddyhttp/responsewriter.go | 16 ++-- .../caddyhttp/reverseproxy/reverseproxy.go | 38 ++++---- modules/caddyhttp/reverseproxy/streaming.go | 87 ++++++++++++++----- .../caddyhttp/reverseproxy/streaming_test.go | 14 +-- 4 files changed, 102 insertions(+), 53 deletions(-) diff --git a/modules/caddyhttp/responsewriter.go b/modules/caddyhttp/responsewriter.go index c593c616244..d710160bd54 100644 --- a/modules/caddyhttp/responsewriter.go +++ b/modules/caddyhttp/responsewriter.go @@ -270,9 +270,11 @@ func (rr *responseRecorder) setReadSize(size *int) { func (rr *responseRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) { if !rr.wroteHeader { - // hijacking without writing status code first works as long as subsequent writes follows http1.1 - // wire format, but it will show up with a status code of 0 in the access log and bytes written - // will include response headers. Response headers won't be present in the log if not set on the response writer. + // hijacking without writing status code first works as long as + // subsequent writes follows http1.1 wire format, but it will + // show up with a status code of 0 in the access log and bytes + // written will include response headers. Response headers won't + // be present in the log if not set on the response writer. caddy.Log().Warn("hijacking without writing status code first") } //nolint:bodyclose @@ -339,9 +341,11 @@ func (hc *hijackedConn) ReadFrom(r io.Reader) (int64, error) { return n, err } -// DetachResponseWriterAfterHijack detaches w or one of its wrapped response -// writers when it's hijacked. Returns true if not already hijacked. -// When detached, bytes read or written stats will not be recorded for the hijacked connection, and it's safe to use the connection after http middleware returns. +// DetachResponseWriterAfterHijack detaches w or one of its wrapped +// response writers when it's hijacked. Returns true if not already +// hijacked. When detached, bytes read or written stats will not be +// recorded for the hijacked connection, and it's safe to use the +// connection after http middleware returns. func DetachResponseWriterAfterHijack(w http.ResponseWriter, detached bool) bool { for w != nil { if detacher, ok := w.(interface{ DetachAfterHijack(bool) bool }); ok { diff --git a/modules/caddyhttp/reverseproxy/reverseproxy.go b/modules/caddyhttp/reverseproxy/reverseproxy.go index 02bcdcb4807..e95bf5f8a90 100644 --- a/modules/caddyhttp/reverseproxy/reverseproxy.go +++ b/modules/caddyhttp/reverseproxy/reverseproxy.go @@ -256,7 +256,7 @@ type Handler struct { // Tracks hijacked/upgraded connections (WebSocket etc.) so they can be // closed when their upstream is removed from the config. - tunnel *tunnelState + tunnelTracker *tunnelTracker ctx caddy.Context logger *zap.Logger @@ -283,7 +283,7 @@ func (h *Handler) Provision(ctx caddy.Context) error { h.events = eventAppIface.(*caddyevents.App) h.ctx = ctx h.logger = ctx.Logger() - h.tunnel = newTunnelState(h.logger, time.Duration(h.StreamCloseDelay)) + h.tunnelTracker = newTunnelTracker(h.logger, time.Duration(h.StreamCloseDelay)) h.streamLogLevel = defaultStreamLogLevel h.streamLogLoggerName = defaultStreamLoggerName if h.StreamLogs != nil { @@ -300,7 +300,7 @@ func (h *Handler) Provision(ctx caddy.Context) error { } if h.StreamRetainOnReload { - registerDetachedTunnelStates(h.tunnel) + registerDetachedTunnelTrackers(h.tunnelTracker) } // warn about unsafe buffering config @@ -504,22 +504,22 @@ func (h Handler) streamLoggerForRequest(req *http.Request) *zap.Logger { } var ( - detachedTunnelStates = make(map[*tunnelState]struct{}) - detachedTunnelStatesMu sync.Mutex + detachedTunnelTrackers = make(map[*tunnelTracker]struct{}) + detachedTunnelTrackersMu sync.Mutex ) -func registerDetachedTunnelStates(ts *tunnelState) { - detachedTunnelStatesMu.Lock() - defer detachedTunnelStatesMu.Unlock() - detachedTunnelStates[ts] = struct{}{} +func registerDetachedTunnelTrackers(ts *tunnelTracker) { + detachedTunnelTrackersMu.Lock() + defer detachedTunnelTrackersMu.Unlock() + detachedTunnelTrackers[ts] = struct{}{} } -func notifyDetachedTunnelStatesOfUpstreamRemoval(upstream string, self *tunnelState) error { - detachedTunnelStatesMu.Lock() - defer detachedTunnelStatesMu.Unlock() +func notifyDetachedTunnelTrackersOfUpstreamRemoval(upstream string, self *tunnelTracker) error { + detachedTunnelTrackersMu.Lock() + defer detachedTunnelTrackersMu.Unlock() var err error - for tunnel := range detachedTunnelStates { + for tunnel := range detachedTunnelTrackers { if closeErr := tunnel.closeConnectionsForUpstream(upstream); closeErr != nil && tunnel == self && err == nil { err = closeErr } @@ -527,16 +527,16 @@ func notifyDetachedTunnelStatesOfUpstreamRemoval(upstream string, self *tunnelSt return err } -func unregisterDetachedTunnelStates(ts *tunnelState) { - detachedTunnelStatesMu.Lock() - defer detachedTunnelStatesMu.Unlock() - delete(detachedTunnelStates, ts) +func unregisterDetachedTunnelTrackers(ts *tunnelTracker) { + detachedTunnelTrackersMu.Lock() + defer detachedTunnelTrackersMu.Unlock() + delete(detachedTunnelTrackers, ts) } // Cleanup cleans up the resources made by h. func (h *Handler) Cleanup() error { // even if StreamRetainOnReload is true, extended connect websockets may still be running - err := h.tunnel.cleanupAttachedConnections() + err := h.tunnelTracker.cleanupAttachedConnections() for _, upstream := range h.Upstreams { // hosts.Delete returns deleted=true when the ref count reaches zero, // meaning no other active config references this upstream. In that @@ -544,7 +544,7 @@ func (h *Handler) Cleanup() error { // to their natural end since the upstream is still in use. deleted, _ := hosts.Delete(upstream.String()) if deleted { - if closeErr := notifyDetachedTunnelStatesOfUpstreamRemoval(upstream.String(), h.tunnel); closeErr != nil && err == nil { + if closeErr := notifyDetachedTunnelTrackersOfUpstreamRemoval(upstream.String(), h.tunnelTracker); closeErr != nil && err == nil { err = closeErr } } diff --git a/modules/caddyhttp/reverseproxy/streaming.go b/modules/caddyhttp/reverseproxy/streaming.go index 753e1e6cb08..c93a57e471b 100644 --- a/modules/caddyhttp/reverseproxy/streaming.go +++ b/modules/caddyhttp/reverseproxy/streaming.go @@ -95,12 +95,13 @@ func (h *Handler) handleUpgradeResponse(logger *zap.Logger, rw http.ResponseWrit // Capture all h fields needed by the tunnel now, so that the Handler (h) // is not referenced after this function returns (for HTTP/1.1 hijacked // connections the tunnel runs in a detached goroutine). - tunnel := h.tunnel + tunnel := h.tunnelTracker bufferSize := h.StreamBufferSize streamTimeout := time.Duration(h.StreamTimeout) if h.StreamRetainOnReload { - // the return value should be true as it's not hijacked yet, but some middleware may wrap response writers incorrectly + // the return value should be true as it's not hijacked yet, + // but some middleware may wrap response writers incorrectly if !caddyhttp.DetachResponseWriterAfterHijack(rw, true) { if c := logger.Check(zap.DebugLevel, "detaching connection failed"); c != nil { c.Write(zap.String("tip", "check if your response writers have an Unwrap method or if already hijacked")) @@ -113,10 +114,14 @@ func (h *Handler) handleUpgradeResponse(logger *zap.Logger, rw http.ResponseWrit brw *bufio.ReadWriter detached = h.StreamRetainOnReload ) - // websocket over http2 or http3 if extended connect is enabled, assuming backend doesn't support this, the request will be modified to http1.1 upgrade - // TODO: once we can reliably detect backend support this, it can be removed for those backends + // websocket over http2 or http3 if extended connect is enabled, + // assuming backend doesn't support this, the request will be + // modified to http1.1 upgrade + // TODO: once we can reliably detect backend support this, it can + // be removed for those backends if body, ok := caddyhttp.GetVar(req.Context(), "extended_connect_websocket_body").(io.ReadCloser); ok { - // websocket over extended connect can't be detached. rw and req.Body are only valid while the handler goroutine is running + // websocket over extended connect can't be detached. rw and req.Body + // are only valid while the handler goroutine is running detached = false req.Body = body rw.Header().Del("Upgrade") @@ -213,15 +218,51 @@ func (h *Handler) handleUpgradeResponse(logger *zap.Logger, rw http.ResponseWrit start := time.Now() if !detached { - handleUpgradeTunnel(streamLogger, streamLevel, conn, backConn, deleteFrontConn, deleteBackConn, bufferSize, streamTimeout, start, finishMetrics, streamFields) + handleUpgradeTunnel( + streamLogger, + streamLevel, + conn, + backConn, + deleteFrontConn, + deleteBackConn, + bufferSize, + streamTimeout, + start, + finishMetrics, + streamFields, + ) } else { // start a new goroutine - go handleUpgradeTunnel(streamLogger, streamLevel, conn, backConn, deleteFrontConn, deleteBackConn, bufferSize, streamTimeout, start, finishMetrics, streamFields) + go handleUpgradeTunnel( + streamLogger, + streamLevel, + conn, + backConn, + deleteFrontConn, + deleteBackConn, + bufferSize, + streamTimeout, + start, + finishMetrics, + streamFields, + ) } } // handleUpgradeTunnel returns when transfer is done. -func handleUpgradeTunnel(streamLogger *zap.Logger, streamLevel zapcore.Level, conn io.ReadWriteCloser, backConn io.ReadWriteCloser, deleteFrontConn func(), deleteBackConn func(), bufferSize int, streamTimeout time.Duration, start time.Time, finishMetrics func(result string, duration time.Duration, toBackend int64, fromBackend int64), streamFields []zap.Field) { +func handleUpgradeTunnel( + streamLogger *zap.Logger, + streamLevel zapcore.Level, + conn io.ReadWriteCloser, + backConn io.ReadWriteCloser, + deleteFrontConn func(), + deleteBackConn func(), + bufferSize int, + streamTimeout time.Duration, + start time.Time, + finishMetrics func(result string, duration time.Duration, toBackend int64, fromBackend int64), + streamFields []zap.Field, +) { defer deleteBackConn() defer deleteFrontConn() var ( @@ -232,7 +273,8 @@ func handleUpgradeTunnel(streamLogger *zap.Logger, streamLevel zapcore.Level, co ) // when a stream timeout is encountered, no error will be read from errc - // a buffer size of 2 will allow both the read and write goroutines to send the error and exit + // a buffer size of 2 will allow both the read and write goroutines to + // send the error and exit // see: https://github.com/caddyserver/caddy/issues/7418 errc := make(chan error, 2) spc := switchProtocolCopier{ @@ -286,7 +328,10 @@ func handleUpgradeTunnel(streamLogger *zap.Logger, streamLevel zapcore.Level, co } func classifyStreamResult(err error) string { - if err == nil || errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) || errors.Is(err, context.Canceled) { + if err == nil || + errors.Is(err, io.EOF) || + errors.Is(err, net.ErrClosed) || + errors.Is(err, context.Canceled) { return "closed" } return "error" @@ -446,8 +491,8 @@ type openConnection struct { upstream string } -// tunnelState tracks hijacked/upgraded connections for selective cleanup. -type tunnelState struct { +// tunnelTracker tracks hijacked/upgraded connections for selective cleanup. +type tunnelTracker struct { connections map[io.ReadWriteCloser]openConnection closeTimer *time.Timer closeDelay time.Duration @@ -456,8 +501,8 @@ type tunnelState struct { logger *zap.Logger } -func newTunnelState(logger *zap.Logger, closeDelay time.Duration) *tunnelState { - return &tunnelState{ +func newTunnelTracker(logger *zap.Logger, closeDelay time.Duration) *tunnelTracker { + return &tunnelTracker{ connections: make(map[io.ReadWriteCloser]openConnection), closeDelay: closeDelay, logger: logger, @@ -466,7 +511,7 @@ func newTunnelState(logger *zap.Logger, closeDelay time.Duration) *tunnelState { // registerConnection stores conn in the tracking map. The caller must invoke // the returned del func when the connection is done. -func (ts *tunnelState) registerConnection(conn io.ReadWriteCloser, gracefulClose func() error, detached bool, upstream string) (del func()) { +func (ts *tunnelTracker) registerConnection(conn io.ReadWriteCloser, gracefulClose func() error, detached bool, upstream string) (del func()) { ts.mu.Lock() ts.connections[conn] = openConnection{conn, gracefulClose, detached, upstream} ts.mu.Unlock() @@ -474,7 +519,7 @@ func (ts *tunnelState) registerConnection(conn io.ReadWriteCloser, gracefulClose ts.mu.Lock() delete(ts.connections, conn) if len(ts.connections) == 0 && ts.stopped { - unregisterDetachedTunnelStates(ts) + unregisterDetachedTunnelTrackers(ts) if ts.closeTimer != nil { if ts.closeTimer.Stop() { ts.logger.Debug("stopped streaming connections close timer - all connections are already closed") @@ -487,7 +532,7 @@ func (ts *tunnelState) registerConnection(conn io.ReadWriteCloser, gracefulClose } // closeAttachedConnections closes all tracked attached connections. -func (ts *tunnelState) closeAttachedConnections() error { +func (ts *tunnelTracker) closeAttachedConnections() error { var err error ts.mu.Lock() defer ts.mu.Unlock() @@ -509,9 +554,9 @@ func (ts *tunnelState) closeAttachedConnections() error { return err } -// cleanupAttachedConnections closes upgraded attached connections. Depending on closeDelay it -// does that either immediately or after a timer. -func (ts *tunnelState) cleanupAttachedConnections() error { +// cleanupAttachedConnections closes upgraded attached connections. +// Depending on closeDelay it does that either immediately or after a timer. +func (ts *tunnelTracker) cleanupAttachedConnections() error { if ts.closeDelay == 0 { return ts.closeAttachedConnections() } @@ -652,7 +697,7 @@ func isWebsocket(r *http.Request) bool { // closeConnectionsForUpstream closes all tracked connections that were // established to the given upstream address. -func (ts *tunnelState) closeConnectionsForUpstream(addr string) error { +func (ts *tunnelTracker) closeConnectionsForUpstream(addr string) error { var err error ts.mu.Lock() defer ts.mu.Unlock() diff --git a/modules/caddyhttp/reverseproxy/streaming_test.go b/modules/caddyhttp/reverseproxy/streaming_test.go index 18acba3f474..2de24a864cf 100644 --- a/modules/caddyhttp/reverseproxy/streaming_test.go +++ b/modules/caddyhttp/reverseproxy/streaming_test.go @@ -114,14 +114,14 @@ func (c *trackingReadWriteCloser) isClosed() bool { } func TestHandlerCleanupLegacyModeClosesAllConnections(t *testing.T) { - ts := newTunnelState(caddy.Log(), 0) + ts := newTunnelTracker(caddy.Log(), 0) connA := newTrackingReadWriteCloser() connB := newTrackingReadWriteCloser() ts.registerConnection(connA, nil, false, "a") ts.registerConnection(connB, nil, false, "b") h := &Handler{ - tunnel: ts, + tunnelTracker: ts, StreamRetainOnReload: false, } @@ -134,12 +134,12 @@ func TestHandlerCleanupLegacyModeClosesAllConnections(t *testing.T) { } func TestHandlerCleanupLegacyModeHonorsDelay(t *testing.T) { - ts := newTunnelState(caddy.Log(), 40*time.Millisecond) + ts := newTunnelTracker(caddy.Log(), 40*time.Millisecond) conn := newTrackingReadWriteCloser() ts.registerConnection(conn, nil, false, "a") h := &Handler{ - tunnel: ts, + tunnelTracker: ts, StreamRetainOnReload: false, } @@ -172,15 +172,15 @@ func TestHandlerCleanupRetainModeClosesOnlyRemovedUpstreams(t *testing.T) { _, _ = hosts.Delete(upstreamB) }) - ts := newTunnelState(caddy.Log(), 0) - registerDetachedTunnelStates(ts) + ts := newTunnelTracker(caddy.Log(), 0) + registerDetachedTunnelTrackers(ts) connA := newTrackingReadWriteCloser() connB := newTrackingReadWriteCloser() ts.registerConnection(connA, nil, true, upstreamA) ts.registerConnection(connB, nil, true, upstreamB) h := &Handler{ - tunnel: ts, + tunnelTracker: ts, StreamRetainOnReload: true, Upstreams: UpstreamPool{ &Upstream{Dial: upstreamA}, From 558ec222db7b3935210d001d85b6c2799a33f685 Mon Sep 17 00:00:00 2001 From: Francis Lavoie Date: Tue, 21 Apr 2026 08:17:49 -0400 Subject: [PATCH 15/17] Add note about capturing h --- modules/caddyhttp/reverseproxy/streaming.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/modules/caddyhttp/reverseproxy/streaming.go b/modules/caddyhttp/reverseproxy/streaming.go index c93a57e471b..3f6b40cd9f9 100644 --- a/modules/caddyhttp/reverseproxy/streaming.go +++ b/modules/caddyhttp/reverseproxy/streaming.go @@ -193,7 +193,8 @@ func (h *Handler) handleUpgradeResponse(logger *zap.Logger, rw http.ResponseWrit // gracefully close connections we recognize as websockets. We need to make // sure the client connection messages (i.e. to upstream) are masked, so we // need to know whether the connection is considered the server or the - // client side of the proxy. + // client side of the proxy. Note that gracefulClose must not capture h, + // since the tunnel may outlive the handler instance. gracefulClose := func(conn io.ReadWriteCloser, isClient bool) func() error { if isWebsocket(req) { return func() error { From 97f5fe007957e6c7eb2abd410ae0c30621fcadc9 Mon Sep 17 00:00:00 2001 From: Francis Lavoie Date: Tue, 21 Apr 2026 08:38:38 -0400 Subject: [PATCH 16/17] Rename to stream_detached --- .../integration/stream_reload_stress_test.go | 44 +++++++++---------- modules/caddyhttp/reverseproxy/caddyfile.go | 6 +-- .../caddyhttp/reverseproxy/reverseproxy.go | 24 +++++----- modules/caddyhttp/reverseproxy/streaming.go | 4 +- .../caddyhttp/reverseproxy/streaming_test.go | 16 +++---- 5 files changed, 48 insertions(+), 46 deletions(-) diff --git a/caddytest/integration/stream_reload_stress_test.go b/caddytest/integration/stream_reload_stress_test.go index ff140f9a45c..6ae6e9fa087 100644 --- a/caddytest/integration/stream_reload_stress_test.go +++ b/caddytest/integration/stream_reload_stress_test.go @@ -40,13 +40,13 @@ func TestReverseProxyReloadStressUpgradedStreamsHeapProfiles(t *testing.T) { // // legacy – no delay, close on reload immediately (old default) // close_delay – stream_close_delay, the old "keep-alive workaround" - // retain – stream_retain_on_reload, the new explicit retain flag + // detached – stream_detached, the new explicit detached flag // // Reloads are spread across time and interleaved with echo-checks so // stream health is exercised at each reload boundary, not only at the end. legacy := runReloadStress(t, tester, backend.addr, "legacy", false, 0) closeDelay := runReloadStress(t, tester, backend.addr, "close_delay", false, stressCloseDelay(t)) - retain := runReloadStress(t, tester, backend.addr, "retain", true, 0) + detached := runReloadStress(t, tester, backend.addr, "detached", true, 0) if legacy.aliveAfterReloads != 0 { t.Fatalf("legacy mode left %d upgraded streams alive after reloads", legacy.aliveAfterReloads) @@ -57,8 +57,8 @@ func TestReverseProxyReloadStressUpgradedStreamsHeapProfiles(t *testing.T) { if closeDelay.aliveAfterReloads != 0 { t.Fatalf("close_delay mode left %d upgraded streams alive after delay expiry", closeDelay.aliveAfterReloads) } - if retain.aliveAfterReloads != retain.streamCount { - t.Fatalf("retain mode kept %d/%d upgraded streams alive after reloads", retain.aliveAfterReloads, retain.streamCount) + if detached.aliveAfterReloads != detached.streamCount { + t.Fatalf("detached mode kept %d/%d upgraded streams alive after reloads", detached.aliveAfterReloads, detached.streamCount) } t.Logf("legacy heap: before=%s mid=%s after=%s delta(before→after)=%s objects(before=%d after=%d) handler_frames(before=%d after=%d)", @@ -77,13 +77,13 @@ func TestReverseProxyReloadStressUpgradedStreamsHeapProfiles(t *testing.T) { closeDelay.beforeReload.HeapObjects, closeDelay.afterReload.HeapObjects, closeDelay.beforeReload.handlerFrames, closeDelay.afterReload.handlerFrames, ) - t.Logf("retain heap: before=%s mid=%s after=%s delta(before→after)=%s objects(before=%d after=%d) handler_frames(before=%d after=%d)", - formatBytes(retain.beforeReload.HeapInuse), - formatBytes(retain.midReload.HeapInuse), - formatBytes(retain.afterReload.HeapInuse), - formatBytesDiff(retain.beforeReload.HeapInuse, retain.afterReload.HeapInuse), - retain.beforeReload.HeapObjects, retain.afterReload.HeapObjects, - retain.beforeReload.handlerFrames, retain.afterReload.handlerFrames, + t.Logf("detached heap: before=%s mid=%s after=%s delta(before→after)=%s objects(before=%d after=%d) handler_frames(before=%d after=%d)", + formatBytes(detached.beforeReload.HeapInuse), + formatBytes(detached.midReload.HeapInuse), + formatBytes(detached.afterReload.HeapInuse), + formatBytesDiff(detached.beforeReload.HeapInuse, detached.afterReload.HeapInuse), + detached.beforeReload.HeapObjects, detached.afterReload.HeapObjects, + detached.beforeReload.handlerFrames, detached.afterReload.handlerFrames, ) } @@ -107,7 +107,7 @@ type heapSnapshot struct { // config reloads spread over time. An echo check is performed every 6 reloads so // stream health is exercised at each reload boundary rather than only at the end. // closeDelay mirrors the stream_close_delay config option; pass 0 to disable. -func runReloadStress(t *testing.T, tester *caddytest.Tester, backendAddr, mode string, retain bool, closeDelay time.Duration) stressRunResult { +func runReloadStress(t *testing.T, tester *caddytest.Tester, backendAddr, mode string, detach bool, closeDelay time.Duration) stressRunResult { t.Helper() const echoEvery = 6 // perform an echo check every N reloads @@ -115,7 +115,7 @@ func runReloadStress(t *testing.T, tester *caddytest.Tester, backendAddr, mode s streamCount := envIntOrDefault(t, "CADDY_STRESS_STREAM_COUNT", defaultStressStreamCount) reloadCount := envIntOrDefault(t, "CADDY_STRESS_RELOAD_COUNT", defaultStressReloadCount) - tester.InitServer(reloadStressConfig(backendAddr, retain, closeDelay, 0), "caddyfile") + tester.InitServer(reloadStressConfig(backendAddr, detach, closeDelay, 0), "caddyfile") clients := make([]*upgradedStreamClient, 0, streamCount) for i := 0; i < streamCount; i++ { @@ -134,7 +134,7 @@ func runReloadStress(t *testing.T, tester *caddytest.Tester, backendAddr, mode s // pause briefly and measure stream health so the snapshot reflects real-world // reload cadence rather than a tight loop. for i := 1; i <= reloadCount; i++ { - loadCaddyfileConfig(t, reloadStressConfig(backendAddr, retain, closeDelay, i)) + loadCaddyfileConfig(t, reloadStressConfig(backendAddr, detach, closeDelay, i)) // Small pause after each reload to let connection teardown propagate. time.Sleep(50 * time.Millisecond) @@ -143,11 +143,11 @@ func runReloadStress(t *testing.T, tester *caddytest.Tester, backendAddr, mode s alive := countAliveStreams(clients) t.Logf("%s mode: %d/%d streams alive after reload %d", mode, alive, streamCount, i) - // In retain mode every stream must survive every reload (upstream unchanged). - if retain { + // In detached mode, every stream must survive every reload (upstream unchanged). + if detach { for j, client := range clients { if err := client.echo(fmt.Sprintf("%s-mid-%02d-%02d\n", mode, i, j)); err != nil { - t.Fatalf("retain mode stream %d died at reload %d: %v", j, i, err) + t.Fatalf("detached mode stream %d died at reload %d: %v", j, i, err) } } } @@ -160,11 +160,11 @@ func runReloadStress(t *testing.T, tester *caddytest.Tester, backendAddr, mode s // For legacy mode: the reloads close streams immediately; wait for that to complete. // For close_delay mode: streams are still alive here; wait for the delay to fire. - // For retain mode: streams survive indefinitely; no wait needed. + // For detached mode: streams survive indefinitely; no wait needed. var aliveBeforeDelayExpiry int aliveAfterReloads := countAliveStreams(clients) switch { - case retain: + case detach: // nothing to wait for case closeDelay > 0: // streams should still be alive at this point (delay hasn't expired) @@ -251,10 +251,10 @@ func loadCaddyfileConfig(t *testing.T, rawConfig string) { } } -func reloadStressConfig(backendAddr string, retain bool, closeDelay time.Duration, revision int) string { +func reloadStressConfig(backendAddr string, detach bool, closeDelay time.Duration, revision int) string { var directives string - if retain { - directives += "\n\t\tstream_retain_on_reload" + if detach { + directives += "\n\t\tstream_detached" } if closeDelay > 0 { directives += fmt.Sprintf("\n\t\tstream_close_delay %s", closeDelay) diff --git a/modules/caddyhttp/reverseproxy/caddyfile.go b/modules/caddyhttp/reverseproxy/caddyfile.go index c692267c10c..56eb3fd112d 100644 --- a/modules/caddyhttp/reverseproxy/caddyfile.go +++ b/modules/caddyhttp/reverseproxy/caddyfile.go @@ -99,7 +99,7 @@ func parseCaddyfile(h httpcaddyfile.Helper) (caddyhttp.MiddlewareHandler, error) // stream_buffer_size // stream_timeout // stream_close_delay -// stream_retain_on_reload +// stream_detached // stream_logs { // level // logger_name @@ -709,11 +709,11 @@ func (h *Handler) UnmarshalCaddyfile(d *caddyfile.Dispenser) error { h.StreamCloseDelay = caddy.Duration(dur) } - case "stream_retain_on_reload": + case "stream_detached": if d.NextArg() { return d.ArgErr() } - h.StreamRetainOnReload = true + h.StreamDetached = true case "stream_logs": if d.NextArg() { diff --git a/modules/caddyhttp/reverseproxy/reverseproxy.go b/modules/caddyhttp/reverseproxy/reverseproxy.go index e95bf5f8a90..61f31b7657e 100644 --- a/modules/caddyhttp/reverseproxy/reverseproxy.go +++ b/modules/caddyhttp/reverseproxy/reverseproxy.go @@ -186,15 +186,17 @@ type Handler struct { // by the previous config closing. Default: no delay. StreamCloseDelay caddy.Duration `json:"stream_close_delay,omitempty"` - // If true, upgraded connections such as WebSockets are retained across - // config reloads when their upstream still exists in the new config. - // Connections using upstreams that are removed are closed during cleanup. - // By default this is false, preserving legacy behavior where upgraded - // connections are closed on reload (optionally delayed by stream_close_delay). - // Only http1.1 websocket connections are affected, websockets for h2/h3 are not affected. - // If true, bytes transferred for http1.1 in the access logs will be zero but those stats - // can be found in the stream logs for http1/2/3 regardless if this is enabled. - StreamRetainOnReload bool `json:"stream_retain_on_reload,omitempty"` + // If true, upgraded connections such as WebSockets are detached from + // the handler and retained across config reloads when their upstream + // still exists in the new config. Connections using upstreams that are + // removed are closed during cleanup. By default this is false, preserving + // legacy behavior where upgraded connections are closed on reload + // (optionally delayed by stream_close_delay). + // Only http1.1 websocket connections are affected, websockets for h2/h3 + // are not affected. If true, bytes transferred for http1.1 in the access + // logs will be zero but those stats can be found in the stream logs for + // http1/2/3 regardless if this is enabled. + StreamDetached bool `json:"stream_detached,omitempty"` // Controls logging behavior for upgraded stream lifecycle events. // If omitted, defaults are used (level=DEBUG, logger_name="http.handlers.reverse_proxy.stream"). @@ -299,7 +301,7 @@ func (h *Handler) Provision(ctx caddy.Context) error { } } - if h.StreamRetainOnReload { + if h.StreamDetached { registerDetachedTunnelTrackers(h.tunnelTracker) } @@ -535,7 +537,7 @@ func unregisterDetachedTunnelTrackers(ts *tunnelTracker) { // Cleanup cleans up the resources made by h. func (h *Handler) Cleanup() error { - // even if StreamRetainOnReload is true, extended connect websockets may still be running + // even if StreamDetached is true, extended connect websockets may still be running err := h.tunnelTracker.cleanupAttachedConnections() for _, upstream := range h.Upstreams { // hosts.Delete returns deleted=true when the ref count reaches zero, diff --git a/modules/caddyhttp/reverseproxy/streaming.go b/modules/caddyhttp/reverseproxy/streaming.go index 3f6b40cd9f9..7cb7ff7da23 100644 --- a/modules/caddyhttp/reverseproxy/streaming.go +++ b/modules/caddyhttp/reverseproxy/streaming.go @@ -99,7 +99,7 @@ func (h *Handler) handleUpgradeResponse(logger *zap.Logger, rw http.ResponseWrit bufferSize := h.StreamBufferSize streamTimeout := time.Duration(h.StreamTimeout) - if h.StreamRetainOnReload { + if h.StreamDetached { // the return value should be true as it's not hijacked yet, // but some middleware may wrap response writers incorrectly if !caddyhttp.DetachResponseWriterAfterHijack(rw, true) { @@ -112,7 +112,7 @@ func (h *Handler) handleUpgradeResponse(logger *zap.Logger, rw http.ResponseWrit var ( conn io.ReadWriteCloser brw *bufio.ReadWriter - detached = h.StreamRetainOnReload + detached = h.StreamDetached ) // websocket over http2 or http3 if extended connect is enabled, // assuming backend doesn't support this, the request will be diff --git a/modules/caddyhttp/reverseproxy/streaming_test.go b/modules/caddyhttp/reverseproxy/streaming_test.go index 2de24a864cf..7dc5e476cf3 100644 --- a/modules/caddyhttp/reverseproxy/streaming_test.go +++ b/modules/caddyhttp/reverseproxy/streaming_test.go @@ -121,8 +121,8 @@ func TestHandlerCleanupLegacyModeClosesAllConnections(t *testing.T) { ts.registerConnection(connB, nil, false, "b") h := &Handler{ - tunnelTracker: ts, - StreamRetainOnReload: false, + tunnelTracker: ts, + StreamDetached: false, } if err := h.Cleanup(); err != nil { @@ -139,8 +139,8 @@ func TestHandlerCleanupLegacyModeHonorsDelay(t *testing.T) { ts.registerConnection(conn, nil, false, "a") h := &Handler{ - tunnelTracker: ts, - StreamRetainOnReload: false, + tunnelTracker: ts, + StreamDetached: false, } if err := h.Cleanup(); err != nil { @@ -157,7 +157,7 @@ func TestHandlerCleanupLegacyModeHonorsDelay(t *testing.T) { } } -func TestHandlerCleanupRetainModeClosesOnlyRemovedUpstreams(t *testing.T) { +func TestHandlerCleanupDetachedModeClosesOnlyRemovedUpstreams(t *testing.T) { const upstreamA = "upstream-a" const upstreamB = "upstream-b" @@ -180,8 +180,8 @@ func TestHandlerCleanupRetainModeClosesOnlyRemovedUpstreams(t *testing.T) { ts.registerConnection(connB, nil, true, upstreamB) h := &Handler{ - tunnelTracker: ts, - StreamRetainOnReload: true, + tunnelTracker: ts, + StreamDetached: true, Upstreams: UpstreamPool{ &Upstream{Dial: upstreamA}, &Upstream{Dial: upstreamB}, @@ -193,7 +193,7 @@ func TestHandlerCleanupRetainModeClosesOnlyRemovedUpstreams(t *testing.T) { } if connA.isClosed() { - t.Fatal("connection for retained upstream should remain open") + t.Fatal("connection for detached upstream should remain open") } if !connB.isClosed() { t.Fatal("connection for removed upstream should be closed") From eeb13f1ca8a5e801804343310d1cf73e1e9848fc Mon Sep 17 00:00:00 2001 From: Francis Lavoie Date: Sat, 25 Apr 2026 05:42:43 -0400 Subject: [PATCH 17/17] More comments --- modules/caddyhttp/reverseproxy/streaming.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/modules/caddyhttp/reverseproxy/streaming.go b/modules/caddyhttp/reverseproxy/streaming.go index 7cb7ff7da23..a50e615e423 100644 --- a/modules/caddyhttp/reverseproxy/streaming.go +++ b/modules/caddyhttp/reverseproxy/streaming.go @@ -485,6 +485,8 @@ func (h Handler) copyBuffer(dst io.Writer, src io.Reader, buf []byte, logger *za // openConnection maps an open connection to an optional function for graceful // close and records which upstream address the connection is proxying to. +// Also tracks whether the connection is detached, which means it should only be +// closed when the upstream is removed from the config, not on every reload. type openConnection struct { conn io.ReadWriteCloser gracefulClose func() error @@ -493,6 +495,9 @@ type openConnection struct { } // tunnelTracker tracks hijacked/upgraded connections for selective cleanup. +// This exists to detach the lifecycle of streaming connections from the proxy +// Handler and config, since we typically want them to survive past config reloads. +// It also allows for selective connection cleanup based on their attachment status. type tunnelTracker struct { connections map[io.ReadWriteCloser]openConnection closeTimer *time.Timer