diff --git a/protocol/invocation/rpcinvocation.go b/protocol/invocation/rpcinvocation.go index 8219c8c221..242489c4ea 100644 --- a/protocol/invocation/rpcinvocation.go +++ b/protocol/invocation/rpcinvocation.go @@ -267,10 +267,11 @@ func (r *RPCInvocation) MergeAttachmentFromContext(ctx context.Context) { return } for k, v := range header { + key := strings.ToLower(k) if len(v) == 1 { - r.SetAttachment(k, v[0]) + r.SetAttachment(key, v[0]) } else { - r.SetAttachment(k, v) + r.SetAttachment(key, v) } } } diff --git a/protocol/invocation/rpcinvocation_test.go b/protocol/invocation/rpcinvocation_test.go index c20ac483d1..2b1373d5b9 100644 --- a/protocol/invocation/rpcinvocation_test.go +++ b/protocol/invocation/rpcinvocation_test.go @@ -371,13 +371,9 @@ func TestRPCInvocation_GetAttachmentAsContext(t *testing.T) { header := triple_protocol.ExtractFromOutgoingContext(ctx) assert.NotNil(t, header) - // Verify that string attachments are in the header - // NewOutgoingContext stores keys as lowercase, so check both ways - assert.Contains(t, header, "key1") - assert.Equal(t, []string{"value1"}, header["key1"]) //nolint:staticcheck - - assert.Contains(t, header, "key2") - assert.Equal(t, []string{"value2", "value3"}, header["key2"]) //nolint:staticcheck + // Verify that string attachments are in the header. + assert.Equal(t, []string{"value1"}, header.Values("key1")) + assert.Equal(t, []string{"value2", "value3"}, header.Values("key2")) // key3 (int) should not be in the header since it's not a string assert.NotContains(t, header, "key3") diff --git a/protocol/triple/triple_protocol/header.go b/protocol/triple/triple_protocol/header.go index adc9f7ecb7..79757716c0 100644 --- a/protocol/triple/triple_protocol/header.go +++ b/protocol/triple/triple_protocol/header.go @@ -19,7 +19,6 @@ import ( "encoding/base64" "fmt" "net/http" - "strings" ) // EncodeBinaryHeader base64-encodes the data. It always emits unpadded values. @@ -101,10 +100,13 @@ func newIncomingContext(ctx context.Context, data http.Header) context.Context { extraData, ok := ctx.Value(extraDataKey{}).(map[string]http.Header) if !ok { extraData = map[string]http.Header{} + } else { + extraData = cloneExtraData(extraData) } for key, vals := range data { - header[strings.ToLower(key)] = vals + // Context headers use canonical keys so http.Header.Get/Values work as expected. + header[http.CanonicalHeaderKey(key)] = append([]string(nil), vals...) } extraData[headerIncomingKey] = header @@ -128,7 +130,8 @@ func NewOutgoingContext(ctx context.Context, data http.Header) context.Context { var header = http.Header{} for key, vals := range data { - header[strings.ToLower(key)] = append([]string(nil), vals...) + // Context headers use canonical keys so http.Header.Get/Values work as expected. + header[http.CanonicalHeaderKey(key)] = append([]string(nil), vals...) } extraData, ok := ctx.Value(extraDataKey{}).(map[string]http.Header) @@ -182,8 +185,8 @@ func AppendToOutgoingContext(ctx context.Context, kv ...string) context.Context extraData[headerOutgoingKey] = header } for i := 0; i < len(kv); i += 2 { - // todo(DMwangnima): think about lowering - header.Add(strings.ToLower(kv[i]), kv[i+1]) + key := http.CanonicalHeaderKey(kv[i]) + header[key] = append(header[key], kv[i+1]) } return ctx } diff --git a/protocol/triple/triple_protocol/header_test.go b/protocol/triple/triple_protocol/header_test.go index 3b73b3570f..febd96cc70 100644 --- a/protocol/triple/triple_protocol/header_test.go +++ b/protocol/triple/triple_protocol/header_test.go @@ -16,6 +16,8 @@ package triple_protocol import ( "bytes" + "context" + "fmt" "net/http" "testing" "testing/quick" @@ -58,3 +60,38 @@ func TestHeaderMerge(t *testing.T) { } assert.Equal(t, header, expect) } + +func TestNewIncomingContextClonesHeaders(t *testing.T) { + baseCtx := NewOutgoingContext(context.Background(), http.Header{ + "Request-Id": []string{"outgoing"}, + }) + inputValues := []string{"incoming"} + input := http.Header{ + "request-id": inputValues, + } + + ctx := newIncomingContext(baseCtx, input) + incoming, ok := FromIncomingContext(ctx) + assert.True(t, ok) + incoming.Values("Request-Id")[0] = "changed" + incoming.Add("Another", "value") + + assert.Equal(t, []string{"incoming"}, inputValues) + outgoing := ExtractFromOutgoingContext(baseCtx) + assert.Equal(t, []string{"outgoing"}, outgoing.Values("Request-Id")) +} + +func ExampleNewOutgoingContext() { + ctx := NewOutgoingContext(context.Background(), http.Header{ + "hello": []string{"triple"}, + }) + ctx = AppendToOutgoingContext(ctx, "hello", "dubbo", "hey", "hessian") + + headers := ExtractFromOutgoingContext(ctx) + fmt.Println(headers.Values("hello")) + fmt.Println(headers.Get("hey")) + + // Output: + // [triple dubbo] + // hessian +} diff --git a/protocol/triple/triple_protocol/protocol_grpc.go b/protocol/triple/triple_protocol/protocol_grpc.go index a663d3221d..44f5bd4efb 100644 --- a/protocol/triple/triple_protocol/protocol_grpc.go +++ b/protocol/triple/triple_protocol/protocol_grpc.go @@ -476,6 +476,7 @@ func (hc *grpcHandlerConn) ExportableHeader() http.Header { res := make(http.Header) hdr := hc.request.Header for key, vals := range hdr { + // Exported attachments stay lowercase to match Dubbo/gRPC metadata keys. key = strings.ToLower(key) if IsReservedHeader(key) && !IsWhitelistedHeader(key) { continue @@ -492,6 +493,7 @@ func (hc *grpcHandlerConn) ExportableHeader() http.Header { // reserved by gRPC protocol. Any other headers are classified as the // user-specified metadata. func IsReservedHeader(hdr string) bool { + hdr = strings.ToLower(hdr) if hdr != "" && hdr[0] == ':' { return true } diff --git a/protocol/triple/triple_protocol/protocol_grpc_test.go b/protocol/triple/triple_protocol/protocol_grpc_test.go index 6a454831ab..f0aea1ac82 100644 --- a/protocol/triple/triple_protocol/protocol_grpc_test.go +++ b/protocol/triple/triple_protocol/protocol_grpc_test.go @@ -124,6 +124,16 @@ func TestGRPCHandlerSender(t *testing.T) { }) } +func TestIsReservedHeaderCanonicalizesInput(t *testing.T) { + t.Parallel() + + for _, header := range []string{"Content-Type", "Grpc-Status", "TE"} { + if !IsReservedHeader(header) { + t.Fatalf("expected %q to be reserved", header) + } + } +} + func testGRPCHandlerConnMetadata(t *testing.T, conn handlerConnCloser) { // Closing the sender shouldn't unpredictably mutate user-visible headers or // trailers. diff --git a/protocol/triple/triple_protocol/triple.go b/protocol/triple/triple_protocol/triple.go index 3ee8f0ce43..99b5bf642c 100644 --- a/protocol/triple/triple_protocol/triple.go +++ b/protocol/triple/triple_protocol/triple.go @@ -230,6 +230,13 @@ func (r *Response) Any() any { // Header returns the HTTP headers for this response. Headers beginning with // "Triple-" and "Grpc-" are reserved for use by the Triple and gRPC // protocols: applications may read them but shouldn't write them. +// +// Unary clients can pass a response wrapper to the generated client method, +// then read response headers after the call returns: +// +// response := NewResponse(&greet.GreetResponse{}) +// err := client.Greet(ctx, NewRequest(&greet.GreetRequest{}), response) +// values := response.Header().Values("hello") func (r *Response) Header() http.Header { if r.header == nil { r.header = make(http.Header) @@ -244,6 +251,13 @@ func (r *Response) Header() http.Header { // Trailers beginning with "Triple-" and "Grpc-" are reserved for use by the // Triple and gRPC protocols: applications may read them but shouldn't write // them. +// +// Unary clients can read trailers from the same response wrapper after the +// generated client method returns: +// +// response := NewResponse(&greet.GreetResponse{}) +// err := client.Greet(ctx, NewRequest(&greet.GreetRequest{}), response) +// values := response.Trailer().Values("end") func (r *Response) Trailer() http.Header { if r.trailer == nil { r.trailer = make(http.Header)