Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions protocol/invocation/rpcinvocation.go
Original file line number Diff line number Diff line change
Expand Up @@ -267,10 +267,11 @@ func (r *RPCInvocation) MergeAttachmentFromContext(ctx context.Context) {
return
}
for k, v := range header {
key := strings.ToLower(k)
if len(v) == 1 {
r.SetAttachment(k, v[0])
r.SetAttachment(key, v[0])
} else {
r.SetAttachment(k, v)
r.SetAttachment(key, v)
}
}
}
Expand Down
10 changes: 3 additions & 7 deletions protocol/invocation/rpcinvocation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -371,13 +371,9 @@ func TestRPCInvocation_GetAttachmentAsContext(t *testing.T) {
header := triple_protocol.ExtractFromOutgoingContext(ctx)
assert.NotNil(t, header)

// Verify that string attachments are in the header
// NewOutgoingContext stores keys as lowercase, so check both ways
assert.Contains(t, header, "key1")
assert.Equal(t, []string{"value1"}, header["key1"]) //nolint:staticcheck

assert.Contains(t, header, "key2")
assert.Equal(t, []string{"value2", "value3"}, header["key2"]) //nolint:staticcheck
// Verify that string attachments are in the header.
assert.Equal(t, []string{"value1"}, header.Values("key1"))
assert.Equal(t, []string{"value2", "value3"}, header.Values("key2"))

// key3 (int) should not be in the header since it's not a string
assert.NotContains(t, header, "key3")
Expand Down
13 changes: 8 additions & 5 deletions protocol/triple/triple_protocol/header.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ import (
"encoding/base64"
"fmt"
"net/http"
"strings"
)

// EncodeBinaryHeader base64-encodes the data. It always emits unpadded values.
Expand Down Expand Up @@ -101,10 +100,13 @@ func newIncomingContext(ctx context.Context, data http.Header) context.Context {
extraData, ok := ctx.Value(extraDataKey{}).(map[string]http.Header)
if !ok {
extraData = map[string]http.Header{}
} else {
extraData = cloneExtraData(extraData)
}

for key, vals := range data {
header[strings.ToLower(key)] = vals
// Context headers use canonical keys so http.Header.Get/Values work as expected.
header[http.CanonicalHeaderKey(key)] = append([]string(nil), vals...)
}
Comment thread
leno23 marked this conversation as resolved.

extraData[headerIncomingKey] = header
Expand All @@ -128,7 +130,8 @@ func NewOutgoingContext(ctx context.Context, data http.Header) context.Context {
var header = http.Header{}

for key, vals := range data {
header[strings.ToLower(key)] = append([]string(nil), vals...)
// Context headers use canonical keys so http.Header.Get/Values work as expected.
header[http.CanonicalHeaderKey(key)] = append([]string(nil), vals...)
}

extraData, ok := ctx.Value(extraDataKey{}).(map[string]http.Header)
Expand Down Expand Up @@ -182,8 +185,8 @@ func AppendToOutgoingContext(ctx context.Context, kv ...string) context.Context
extraData[headerOutgoingKey] = header
}
for i := 0; i < len(kv); i += 2 {
// todo(DMwangnima): think about lowering
header.Add(strings.ToLower(kv[i]), kv[i+1])
key := http.CanonicalHeaderKey(kv[i])
header[key] = append(header[key], kv[i+1])
}
return ctx
}
Expand Down
37 changes: 37 additions & 0 deletions protocol/triple/triple_protocol/header_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ package triple_protocol

import (
"bytes"
"context"
"fmt"
"net/http"
"testing"
"testing/quick"
Expand Down Expand Up @@ -58,3 +60,38 @@ func TestHeaderMerge(t *testing.T) {
}
assert.Equal(t, header, expect)
}

func TestNewIncomingContextClonesHeaders(t *testing.T) {
baseCtx := NewOutgoingContext(context.Background(), http.Header{
"Request-Id": []string{"outgoing"},
})
inputValues := []string{"incoming"}
input := http.Header{
"request-id": inputValues,
}

ctx := newIncomingContext(baseCtx, input)
incoming, ok := FromIncomingContext(ctx)
assert.True(t, ok)
incoming.Values("Request-Id")[0] = "changed"
incoming.Add("Another", "value")

assert.Equal(t, []string{"incoming"}, inputValues)
outgoing := ExtractFromOutgoingContext(baseCtx)
assert.Equal(t, []string{"outgoing"}, outgoing.Values("Request-Id"))
}

func ExampleNewOutgoingContext() {
ctx := NewOutgoingContext(context.Background(), http.Header{
"hello": []string{"triple"},
})
ctx = AppendToOutgoingContext(ctx, "hello", "dubbo", "hey", "hessian")

headers := ExtractFromOutgoingContext(ctx)
fmt.Println(headers.Values("hello"))
fmt.Println(headers.Get("hey"))

// Output:
// [triple dubbo]
// hessian
}
2 changes: 2 additions & 0 deletions protocol/triple/triple_protocol/protocol_grpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -476,6 +476,7 @@ func (hc *grpcHandlerConn) ExportableHeader() http.Header {
res := make(http.Header)
hdr := hc.request.Header
for key, vals := range hdr {
// Exported attachments stay lowercase to match Dubbo/gRPC metadata keys.
key = strings.ToLower(key)
if IsReservedHeader(key) && !IsWhitelistedHeader(key) {
continue
Expand All @@ -492,6 +493,7 @@ func (hc *grpcHandlerConn) ExportableHeader() http.Header {
// reserved by gRPC protocol. Any other headers are classified as the
// user-specified metadata.
func IsReservedHeader(hdr string) bool {
hdr = strings.ToLower(hdr)
if hdr != "" && hdr[0] == ':' {
return true
}
Expand Down
10 changes: 10 additions & 0 deletions protocol/triple/triple_protocol/protocol_grpc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,16 @@ func TestGRPCHandlerSender(t *testing.T) {
})
}

func TestIsReservedHeaderCanonicalizesInput(t *testing.T) {
t.Parallel()

for _, header := range []string{"Content-Type", "Grpc-Status", "TE"} {
if !IsReservedHeader(header) {
t.Fatalf("expected %q to be reserved", header)
}
}
}

func testGRPCHandlerConnMetadata(t *testing.T, conn handlerConnCloser) {
// Closing the sender shouldn't unpredictably mutate user-visible headers or
// trailers.
Expand Down
14 changes: 14 additions & 0 deletions protocol/triple/triple_protocol/triple.go
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,13 @@ func (r *Response) Any() any {
// Header returns the HTTP headers for this response. Headers beginning with
// "Triple-" and "Grpc-" are reserved for use by the Triple and gRPC
// protocols: applications may read them but shouldn't write them.
//
// Unary clients can pass a response wrapper to the generated client method,
// then read response headers after the call returns:
//
// response := NewResponse(&greet.GreetResponse{})
// err := client.Greet(ctx, NewRequest(&greet.GreetRequest{}), response)
// values := response.Header().Values("hello")
func (r *Response) Header() http.Header {
if r.header == nil {
r.header = make(http.Header)
Expand All @@ -244,6 +251,13 @@ func (r *Response) Header() http.Header {
// Trailers beginning with "Triple-" and "Grpc-" are reserved for use by the
// Triple and gRPC protocols: applications may read them but shouldn't write
// them.
//
// Unary clients can read trailers from the same response wrapper after the
// generated client method returns:
//
// response := NewResponse(&greet.GreetResponse{})
// err := client.Greet(ctx, NewRequest(&greet.GreetRequest{}), response)
// values := response.Trailer().Values("end")
func (r *Response) Trailer() http.Header {
if r.trailer == nil {
r.trailer = make(http.Header)
Expand Down
Loading