diff --git a/src/StreamJsonRpc/MessagePackFormatter.cs b/src/StreamJsonRpc/MessagePackFormatter.cs index e722b08c..ef805fad 100644 --- a/src/StreamJsonRpc/MessagePackFormatter.cs +++ b/src/StreamJsonRpc/MessagePackFormatter.cs @@ -168,6 +168,17 @@ private interface IJsonRpcMessagePackRetention set => base.MultiplexingStream = value; } + /// + /// Gets a value indicating whether the W3C traceparent property + /// should be serialized as a string instead of a more compact binary format. + /// + /// The default value is . + public bool TraceParentAsW3CString { get; init; } + + private IMessagePackFormatter TraceParentFormatter => this.TraceParentAsW3CString + ? TraceParentAsStringFormatter.Instance + : TraceParentAsBinaryFormatter.Instance; + /// /// Sets the to use for serialization of user data. /// @@ -367,7 +378,7 @@ private IFormatterResolver CreateTopLevelMessageResolver() new JsonRpcResultFormatter(this), new JsonRpcErrorFormatter(this), new JsonRpcErrorDetailFormatter(this), - new TraceParentFormatter(), + new TraceParentDelegatingFormatter(this), }; var resolvers = new IFormatterResolver[] { @@ -1527,7 +1538,7 @@ public Protocol.JsonRpcRequest Deserialize(ref MessagePackReader reader, Message } else if (TraceParentPropertyName.TryRead(stringKey)) { - TraceParent traceParent = options.Resolver.GetFormatterWithVerify().Deserialize(ref reader, options); + TraceParent traceParent = this.formatter.TraceParentFormatter.Deserialize(ref reader, options); result.TraceParent = traceParent.ToString(); } else if (TraceStatePropertyName.TryRead(stringKey)) @@ -1620,7 +1631,7 @@ public void Serialize(ref MessagePackWriter writer, Protocol.JsonRpcRequest valu if (value.TraceParent?.Length > 0) { TraceParentPropertyName.Write(ref writer); - options.Resolver.GetFormatterWithVerify().Serialize(ref writer, new TraceParent(value.TraceParent), options); + this.formatter.TraceParentFormatter.Serialize(ref writer, new TraceParent(value.TraceParent), options); if (value.TraceState?.Length > 0) { @@ -1933,8 +1944,23 @@ public EventArgs Deserialize(ref MessagePackReader reader, MessagePackSerializer } } - private class TraceParentFormatter : IMessagePackFormatter + private class TraceParentDelegatingFormatter(MessagePackFormatter formatter) : IMessagePackFormatter + { + public TraceParent Deserialize(ref MessagePackReader reader, MessagePackSerializerOptions options) + { + return formatter.TraceParentFormatter.Deserialize(ref reader, options); + } + + public void Serialize(ref MessagePackWriter writer, TraceParent value, MessagePackSerializerOptions options) + { + formatter.TraceParentFormatter.Serialize(ref writer, value, options); + } + } + + private class TraceParentAsBinaryFormatter : IMessagePackFormatter { + internal static readonly TraceParentAsBinaryFormatter Instance = new(); + public unsafe TraceParent Deserialize(ref MessagePackReader reader, MessagePackSerializerOptions options) { if (reader.ReadArrayHeader() != 2) @@ -1983,6 +2009,44 @@ public unsafe void Serialize(ref MessagePackWriter writer, TraceParent value, Me } } + private class TraceParentAsStringFormatter : IMessagePackFormatter + { + internal static readonly TraceParentAsStringFormatter Instance = new(); + + public TraceParent Deserialize(ref MessagePackReader reader, MessagePackSerializerOptions options) + { + ReadOnlySequence utf8Sequence = reader.ReadStringSequence() ?? throw new MessagePackSerializationException("Unexpected null value."); + if (utf8Sequence.Length != TraceParent.Length) + { + throw new MessagePackSerializationException("Unexpected length for traceparent string."); + } + + Span utf8Bytes = stackalloc byte[TraceParent.Length]; + utf8Sequence.CopyTo(utf8Bytes); + + Span chars = stackalloc char[TraceParent.Length]; + if (!Encoding.UTF8.TryGetChars(utf8Bytes, chars, out int charsWritten)) + { + throw new MessagePackSerializationException("Invalid UTF-8 in traceparent string."); + } + + return new TraceParent(chars); + } + + public void Serialize(ref MessagePackWriter writer, TraceParent value, MessagePackSerializerOptions options) + { + if (value.Version != 0) + { + throw new NotSupportedException("traceparent version " + value.Version + " is not supported."); + } + + Span chars = stackalloc char[TraceParent.Length]; + value.WriteTo(chars); + + writer.Write(chars.ToString()); + } + } + private class TopLevelPropertyBag : TopLevelPropertyBagBase { private readonly MessagePackSerializerOptions serializerOptions; diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.TraceParentConverter.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.TraceParentConverter.cs index 9d9c1908..28ea9a80 100644 --- a/src/StreamJsonRpc/NerdbankMessagePackFormatter.TraceParentConverter.cs +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatter.TraceParentConverter.cs @@ -2,6 +2,7 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. using System.Buffers; +using System.Text; using System.Text.Json.Nodes; using Nerdbank.MessagePack; using PolyType; @@ -14,8 +15,10 @@ namespace StreamJsonRpc; /// public partial class NerdbankMessagePackFormatter { - internal class TraceParentConverter : MessagePackConverter + internal class TraceParentAsBinaryConverter : MessagePackConverter { + internal static readonly TraceParentAsBinaryConverter Instance = new(); + public unsafe override TraceParent Read(ref MessagePackReader reader, SerializationContext context) { context.DepthStep(); @@ -78,4 +81,47 @@ public unsafe override void Write(ref MessagePackWriter writer, in TraceParent v public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) => null; } + + internal class TraceParentAsStringConverter : MessagePackConverter + { + internal static readonly TraceParentAsStringConverter Instance = new(); + + public override TraceParent Read(ref MessagePackReader reader, SerializationContext context) + { + context.DepthStep(); + + ReadOnlySequence utf8Sequence = reader.ReadStringSequence() ?? throw new MessagePackSerializationException("Unexpected null value."); + if (utf8Sequence.Length != TraceParent.Length) + { + throw new MessagePackSerializationException("Unexpected length for traceparent string."); + } + + Span utf8Bytes = stackalloc byte[TraceParent.Length]; + utf8Sequence.CopyTo(utf8Bytes); + + Span chars = stackalloc char[TraceParent.Length]; + if (!Encoding.UTF8.TryGetChars(utf8Bytes, chars, out int charsWritten)) + { + throw new MessagePackSerializationException("Invalid UTF-8 in traceparent string."); + } + + return new TraceParent(chars); + } + + public override void Write(ref MessagePackWriter writer, in TraceParent value, SerializationContext context) + { + if (value.Version != 0) + { + throw new NotSupportedException("traceparent version " + value.Version + " is not supported."); + } + + context.DepthStep(); + + Span chars = stackalloc char[TraceParent.Length]; + value.WriteTo(chars); + writer.Write(chars); + } + + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) => null; + } } diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.cs index f1e474b9..c8960c13 100644 --- a/src/StreamJsonRpc/NerdbankMessagePackFormatter.cs +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatter.cs @@ -167,6 +167,15 @@ public MessagePackSerializer UserDataSerializer } } + /// + /// Gets a value indicating whether the W3C traceparent property + /// should be serialized as a string instead of a more compact binary format. + /// + /// The default value is . + public bool TraceParentAsW3CString { get; init; } + + private MessagePackConverter TraceParentConverter => this.TraceParentAsW3CString ? TraceParentAsStringConverter.Instance : TraceParentAsBinaryConverter.Instance; + /// public JsonRpcMessage Deserialize(ReadOnlySequence contentBuffer) { @@ -470,7 +479,7 @@ internal class JsonRpcRequestConverter : MessagePackConverter(null).Read(ref reader, context); + TraceParent traceParent = formatter.TraceParentConverter.Read(ref reader, context); result.TraceParent = traceParent.ToString(); } else if (TraceStatePropertyName.TryRead(ref reader)) @@ -577,7 +586,7 @@ public override void Write(ref MessagePackWriter writer, in Protocol.JsonRpcRequ if (value.TraceParent?.Length > 0) { writer.Write(TraceParentPropertyName); - context.GetConverter(Witness.GeneratedTypeShapeProvider).Write(ref writer, new TraceParent(value.TraceParent), context); + formatter.TraceParentConverter.Write(ref writer, new TraceParent(value.TraceParent), context); if (value.TraceState?.Length > 0) { diff --git a/src/StreamJsonRpc/PolyfillMethods.cs b/src/StreamJsonRpc/PolyfillMethods.cs deleted file mode 100644 index 18ccf5ba..00000000 --- a/src/StreamJsonRpc/PolyfillMethods.cs +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. See LICENSE file in the project root for full license information. - -namespace StreamJsonRpc; - -internal static class PolyfillMethods -{ -#if NETSTANDARD2_0 - internal static void Deconstruct(this KeyValuePair pair, out TKey key, out TValue value) - => (key, value) = (pair.Key, pair.Value); -#endif -} diff --git a/src/StreamJsonRpc/Polyfills.cs b/src/StreamJsonRpc/Polyfills.cs index 31072d76..192b2807 100644 --- a/src/StreamJsonRpc/Polyfills.cs +++ b/src/StreamJsonRpc/Polyfills.cs @@ -7,6 +7,11 @@ namespace StreamJsonRpc; internal static class Polyfills { +#if NETSTANDARD2_0 + internal static void Deconstruct(this KeyValuePair pair, out TKey key, out TValue value) + => (key, value) = (pair.Key, pair.Value); +#endif + #if !(NETSTANDARD2_1_OR_GREATER || NET) internal static unsafe string GetString(this Encoding encoding, ReadOnlySpan utf8Bytes) { @@ -16,4 +21,26 @@ internal static unsafe string GetString(this Encoding encoding, ReadOnlySpan utf8Bytes, Span chars, out int charsWritten) + { + fixed (byte* pBytes = utf8Bytes) + { + fixed (char* pChars = chars) + { + try + { + charsWritten = encoding.GetChars(pBytes, utf8Bytes.Length, pChars, chars.Length); + return true; + } + catch (ArgumentException) + { + charsWritten = 0; + return false; + } + } + } + } +#endif } diff --git a/src/StreamJsonRpc/Protocol/TraceParent.cs b/src/StreamJsonRpc/Protocol/TraceParent.cs index 5089ae46..1ea77b05 100644 --- a/src/StreamJsonRpc/Protocol/TraceParent.cs +++ b/src/StreamJsonRpc/Protocol/TraceParent.cs @@ -2,11 +2,9 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. using System.Diagnostics; -using Nerdbank.MessagePack; namespace StreamJsonRpc.Protocol; -[MessagePackConverter(typeof(NerdbankMessagePackFormatter.TraceParentConverter))] internal unsafe struct TraceParent { internal const int VersionByteCount = 1; @@ -14,6 +12,14 @@ internal unsafe struct TraceParent internal const int TraceIdByteCount = 16; internal const int FlagsByteCount = 1; + /// + /// The number of characters in a serialized traceparent value. + /// + /// + /// When calculating the number of characters required, double each 'byte' we have to encode since we're using hex. + /// + internal const int Length = (VersionByteCount * 2) + 1 + (TraceIdByteCount * 2) + 1 + (ParentIdByteCount * 2) + 1 + (FlagsByteCount * 2); + internal byte Version; internal fixed byte TraceId[TraceIdByteCount]; @@ -23,50 +29,55 @@ internal unsafe struct TraceParent internal TraceFlags Flags; internal TraceParent(string? traceparent) + : this(traceparent is null ? default : traceparent.AsSpan()) { - if (traceparent is null) + } + + internal TraceParent(ReadOnlySpan traceparent) + { + if (traceparent is []) { this.Version = 0; this.Flags = TraceFlags.None; return; } - ReadOnlySpan traceparentChars = traceparent.AsSpan(); - // Decode version - ReadOnlySpan slice = Consume(ref traceparentChars, VersionByteCount * 2); + ReadOnlySpan slice = Consume(ref traceparent, VersionByteCount * 2); fixed (byte* pVersion = &this.Version) { Hex.Decode(slice, new Span(pVersion, 1)); } - ConsumeHyphen(ref traceparentChars); + ConsumeHyphen(ref traceparent); // Decode traceid - slice = Consume(ref traceparentChars, TraceIdByteCount * 2); + slice = Consume(ref traceparent, TraceIdByteCount * 2); fixed (byte* pTraceId = this.TraceId) { Hex.Decode(slice, new Span(pTraceId, TraceIdByteCount)); } - ConsumeHyphen(ref traceparentChars); + ConsumeHyphen(ref traceparent); // Decode parentid - slice = Consume(ref traceparentChars, ParentIdByteCount * 2); + slice = Consume(ref traceparent, ParentIdByteCount * 2); fixed (byte* pParentId = this.ParentId) { Hex.Decode(slice, new Span(pParentId, ParentIdByteCount)); } - ConsumeHyphen(ref traceparentChars); + ConsumeHyphen(ref traceparent); // Decode flags - slice = Consume(ref traceparentChars, FlagsByteCount * 2); + slice = Consume(ref traceparent, FlagsByteCount * 2); fixed (TraceFlags* pFlags = &this.Flags) { Hex.Decode(slice, new Span(pFlags, 1)); } + Requires.Argument(traceparent is [], nameof(traceparent), "Expected traceparent to be fully consumed."); + static void ConsumeHyphen(ref ReadOnlySpan value) { if (value[0] != '-') @@ -112,9 +123,22 @@ internal Guid TraceIdGuid public override string ToString() { - // When calculating the number of characters required, double each 'byte' we have to encode since we're using hex. - Span traceparent = stackalloc char[(VersionByteCount * 2) + 1 + (TraceIdByteCount * 2) + 1 + (ParentIdByteCount * 2) + 1 + (FlagsByteCount * 2)]; - Span traceParentRemaining = traceparent; + Span chars = stackalloc char[Length]; + this.WriteTo(chars); + return chars.ToString(); + } + + /// + /// Serializes the value as a string. + /// + /// The span to write to. This must be at least in length. + /// The number of characters written to . Always equal to . + /// Thrown if is shorter than . + internal int WriteTo(Span destination) + { + Requires.Argument(destination.Length >= Length, nameof(destination), $"Destination must be at least {Length} characters in length."); + + Span traceParentRemaining = destination; fixed (byte* pVersion = &this.Version) { @@ -142,9 +166,9 @@ public override string ToString() Hex.Encode(new ReadOnlySpan(pFlags, 1), ref traceParentRemaining); } - Debug.Assert(traceParentRemaining.Length == 0, "Characters were not initialized."); + Debug.Assert(traceParentRemaining is [], "Characters were not initialized."); - return traceparent.ToString(); + return Length; static void AddHyphen(ref Span value) { diff --git a/test/StreamJsonRpc.Tests/MessagePackFormatterTests.cs b/test/StreamJsonRpc.Tests/MessagePackFormatterTests.cs index b7f932a4..13ca3786 100644 --- a/test/StreamJsonRpc.Tests/MessagePackFormatterTests.cs +++ b/test/StreamJsonRpc.Tests/MessagePackFormatterTests.cs @@ -335,6 +335,48 @@ public void CanDeserializeWithExtraProperty_JsonRpcError() Assert.Equal(dynamic.error.code, (int?)request.Error?.Code); } + [Theory] + [InlineData(false, MessagePackType.Array)] + [InlineData(true, MessagePackType.String)] + public void TraceParentAsW3CStringControlsSerializationFormat( + bool traceParentAsW3CString, + MessagePackType expectedMessagePackType) + { + const string TraceParent = "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01"; + MessagePackFormatter formatter = new() { TraceParentAsW3CString = traceParentAsW3CString }; + var request = new JsonRpcRequest + { + Method = "something", + ArgumentsList = Array.Empty(), + TraceParent = TraceParent, + }; + + var sequence = new Sequence(); + formatter.Serialize(sequence, request); + + this.Logger.WriteLine(MessagePackSerializer.ConvertToJson(sequence, cancellationToken: this.TimeoutToken)); + + var reader = new MessagePackReader(sequence.AsReadOnlySequence); + int propertyCount = reader.ReadMapHeader(); + bool foundTraceParent = false; + for (int i = 0; i < propertyCount; i++) + { + string? propertyName = reader.ReadString(); + Assert.NotNull(propertyName); + if (propertyName == "traceparent") + { + Assert.Equal(expectedMessagePackType, reader.NextMessagePackType); + foundTraceParent = true; + } + + reader.Skip(); + } + + Assert.True(foundTraceParent); + var actual = Assert.IsAssignableFrom(formatter.Deserialize(sequence.AsReadOnlySequence)); + Assert.Equal(TraceParent, actual.TraceParent); + } + [Fact] public void StringsInUserDataAreInterned() { diff --git a/test/StreamJsonRpc.Tests/NerdbankMessagePackFormatterTests.cs b/test/StreamJsonRpc.Tests/NerdbankMessagePackFormatterTests.cs index 93368f67..36904b57 100644 --- a/test/StreamJsonRpc.Tests/NerdbankMessagePackFormatterTests.cs +++ b/test/StreamJsonRpc.Tests/NerdbankMessagePackFormatterTests.cs @@ -326,6 +326,53 @@ public void CanDeserializeWithExtraProperty_JsonRpcError() Assert.Equal(dynamic.error.code, (int?)request.Error?.Code); } + [Theory] + [InlineData(false, MessagePackType.Array)] + [InlineData(true, MessagePackType.String)] + public void TraceParentAsW3CStringControlsSerializationFormat( + bool traceParentAsW3CString, + MessagePackType expectedMessagePackType) + { + const string TraceParent = "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01"; + NerdbankMessagePackFormatter formatter = new() + { + TraceParentAsW3CString = traceParentAsW3CString, + TypeShapeProvider = Witness.GeneratedTypeShapeProvider, + }; + var request = new JsonRpcRequest + { + Method = "something", + ArgumentsList = Array.Empty(), + TraceParent = TraceParent, + }; + + var sequence = new Sequence(); + formatter.Serialize(sequence, request); + + this.Logger.WriteLine(formatter.UserDataSerializer.ConvertToJson(sequence, new MessagePackSerializer.JsonOptions { Indentation = " " })); + + var reader = new MessagePackReader(sequence.AsReadOnlySequence); + SerializationContext context = new(); + int propertyCount = reader.ReadMapHeader(); + bool foundTraceParent = false; + for (int i = 0; i < propertyCount; i++) + { + string? propertyName = reader.ReadString(); + Assert.NotNull(propertyName); + if (propertyName == "traceparent") + { + Assert.Equal(expectedMessagePackType, reader.NextMessagePackType); + foundTraceParent = true; + } + + reader.Skip(context); + } + + Assert.True(foundTraceParent); + var actual = Assert.IsAssignableFrom(formatter.Deserialize(sequence.AsReadOnlySequence)); + Assert.Equal(TraceParent, actual.TraceParent); + } + [Fact] public void StringsInUserDataAreInterned() {