diff --git a/bson/array_codec.go b/bson/array_codec.go index 5714b0e81..c369fc49f 100644 --- a/bson/array_codec.go +++ b/bson/array_codec.go @@ -31,8 +31,16 @@ func (ac *arrayCodec) DecodeValue(_ DecodeContext, vr ValueReader, val reflect.V if !val.CanSet() || val.Type() != tCoreArray { return ValueDecoderError{Name: "CoreArrayDecodeValue", Types: []reflect.Type{tCoreArray}, Received: val} } - if vrType := vr.Type(); vrType != TypeArray { - return fmt.Errorf("cannot decode %v into a %s", vrType, val.Type()) + switch vr.Type() { + case TypeArray: + case TypeNull: + val.Set(reflect.Zero(val.Type())) + return vr.ReadNull() + case TypeUndefined: + val.Set(reflect.Zero(val.Type())) + return vr.ReadUndefined() + default: + return fmt.Errorf("cannot decode %v into a %s", vr.Type(), val.Type()) } if val.IsNil() { diff --git a/bson/default_value_decoders.go b/bson/default_value_decoders.go index 1dc598dde..91c84095b 100644 --- a/bson/default_value_decoders.go +++ b/bson/default_value_decoders.go @@ -119,6 +119,9 @@ func dDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { case TypeNull: val.Set(reflect.Zero(val.Type())) return vr.ReadNull() + case TypeUndefined: + val.Set(reflect.Zero(val.Type())) + return vr.ReadUndefined() default: return fmt.Errorf("cannot decode %v into a D", vrType) } @@ -1310,9 +1313,15 @@ func coreDocumentDecodeValue(_ DecodeContext, vr ValueReader, val reflect.Value) if !val.CanSet() || val.Type() != tCoreDocument { return ValueDecoderError{Name: "CoreDocumentDecodeValue", Types: []reflect.Type{tCoreDocument}, Received: val} } - vrType := vr.Type() - isDocument := vrType == Type(0) || vrType == TypeEmbeddedDocument || vrType == TypeArray - if !isDocument { + switch vrType := vr.Type(); vrType { + case Type(0), TypeEmbeddedDocument, TypeArray: + case TypeNull: + val.Set(reflect.Zero(val.Type())) + return vr.ReadNull() + case TypeUndefined: + val.Set(reflect.Zero(val.Type())) + return vr.ReadUndefined() + default: return fmt.Errorf("cannot decode %v into a %s", vrType, val.Type()) } diff --git a/bson/default_value_decoders_test.go b/bson/default_value_decoders_test.go index 5b955158b..9f845370a 100644 --- a/bson/default_value_decoders_test.go +++ b/bson/default_value_decoders_test.go @@ -2302,6 +2302,22 @@ func TestDefaultValueDecoders(t *testing.T) { readDocument, errors.New("copy error"), }, + { + "decode null", + bsoncore.Document(nil), + nil, + &valueReaderWriter{BSONType: TypeNull}, + readNull, + nil, + }, + { + "decode undefined", + bsoncore.Document(nil), + nil, + &valueReaderWriter{BSONType: TypeUndefined}, + readUndefined, + nil, + }, }, }, { @@ -2423,6 +2439,22 @@ func TestDefaultValueDecoders(t *testing.T) { Received: reflect.New(reflect.TypeOf((*bsoncore.Array)(nil))).Elem(), }, }, + { + "decode null", + bsoncore.Array(nil), + nil, + &valueReaderWriter{BSONType: TypeNull}, + readNull, + nil, + }, + { + "decode undefined", + bsoncore.Array(nil), + nil, + &valueReaderWriter{BSONType: TypeUndefined}, + readUndefined, + nil, + }, }, }, } diff --git a/bson/primitive_codecs.go b/bson/primitive_codecs.go index ffb9f0344..1d9b56a78 100644 --- a/bson/primitive_codecs.go +++ b/bson/primitive_codecs.go @@ -78,9 +78,15 @@ func rawDecodeValue(_ DecodeContext, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tRaw { return ValueDecoderError{Name: "RawDecodeValue", Types: []reflect.Type{tRaw}, Received: val} } - vrType := vr.Type() - isDocument := vrType == Type(0) || vrType == TypeEmbeddedDocument || vrType == TypeArray - if !isDocument { + switch vrType := vr.Type(); vrType { + case Type(0), TypeEmbeddedDocument, TypeArray: + case TypeNull: + val.Set(reflect.Zero(val.Type())) + return vr.ReadNull() + case TypeUndefined: + val.Set(reflect.Zero(val.Type())) + return vr.ReadUndefined() + default: return fmt.Errorf("cannot decode %v into a %s", vrType, val.Type()) } diff --git a/bson/primitive_codecs_test.go b/bson/primitive_codecs_test.go index 5f380157e..9ffc27a7a 100644 --- a/bson/primitive_codecs_test.go +++ b/bson/primitive_codecs_test.go @@ -574,6 +574,22 @@ func TestPrimitiveValueDecoders(t *testing.T) { readDocument, errors.New("copy error"), }, + { + "decode null", + Raw(nil), + nil, + &valueReaderWriter{BSONType: TypeNull}, + readNull, + nil, + }, + { + "decode undefined", + Raw(nil), + nil, + &valueReaderWriter{BSONType: TypeUndefined}, + readUndefined, + nil, + }, }, }, } diff --git a/bson/unmarshal_test.go b/bson/unmarshal_test.go index be9d7e238..b53bdbee3 100644 --- a/bson/unmarshal_test.go +++ b/bson/unmarshal_test.go @@ -931,5 +931,55 @@ func TestUnmarshalTypeCompatibility(t *testing.T) { assert.NoError(t, err) assert.Equal(t, want, []byte(val.Foo)) }) + t.Run("bson.D into bson.Raw", func(t *testing.T) { + t.Parallel() + + data := docToBytes(D{{"foo", D{{"bar", int32(42)}}}}) + want := []byte(bsoncore.NewDocumentBuilder().AppendInt32("bar", 42).Build()) + + var val struct { + Foo Raw + } + + err := Unmarshal(data, &val) + assert.NoError(t, err) + assert.Equal(t, want, []byte(val.Foo)) + }) + t.Run("nil into bson.Raw", func(t *testing.T) { + t.Parallel() + + data := docToBytes(D{{"foo", nil}}) + + var val struct { + Foo Raw + } + err := Unmarshal(data, &val) + assert.NoError(t, err) + assert.Nil(t, val.Foo) + }) + t.Run("nil into bsoncore.Document", func(t *testing.T) { + t.Parallel() + + data := docToBytes(D{{"foo", nil}}) + + var val struct { + Foo bsoncore.Document + } + err := Unmarshal(data, &val) + assert.NoError(t, err) + assert.Nil(t, val.Foo) + }) + t.Run("nil into bsoncore.Array", func(t *testing.T) { + t.Parallel() + + data := docToBytes(D{{"foo", nil}}) + + var val struct { + Foo bsoncore.Array + } + err := Unmarshal(data, &val) + assert.NoError(t, err) + assert.Nil(t, val.Foo) + }) }) }