diff --git a/mongo/client.go b/mongo/client.go index 14b65b68bc..82a3441fb7 100644 --- a/mongo/client.go +++ b/mongo/client.go @@ -956,6 +956,10 @@ type ClientBulkWrite struct { Database string Collection string Model ClientWriteModel + + // Deprecated: This option is for internal use only and should not be set. It may be changed or removed in any + // release. + Internal optionsutil.Options } // BulkWrite performs a client-level bulk write operation. @@ -1019,16 +1023,20 @@ func (c *Client) BulkWrite(ctx context.Context, writes []ClientBulkWrite, } selector := makePinnedSelector(sess, writeSelector) - writePairs := make([]clientBulkWritePair, len(writes)) + writeOps := make([]clientBulkWriteOp, len(writes)) for i, w := range writes { - writePairs[i] = clientBulkWritePair{ + p := clientBulkWriteOp{ namespace: fmt.Sprintf("%s.%s", w.Database, w.Collection), model: w.Model, } + if uuid, ok := optionsutil.Value(w.Internal, "collectionUUID").([]byte); ok { + p.collectionUUID = uuid + } + writeOps[i] = p } op := clientBulkWrite{ - writePairs: writePairs, + writeOps: writeOps, ordered: bwo.Ordered, bypassDocumentValidation: bwo.BypassDocumentValidation, comment: bwo.Comment, diff --git a/mongo/client_bulk_write.go b/mongo/client_bulk_write.go index bf3a1e530e..8dcb6f3175 100644 --- a/mongo/client_bulk_write.go +++ b/mongo/client_bulk_write.go @@ -28,13 +28,14 @@ const ( database = "admin" ) -type clientBulkWritePair struct { - namespace string - model any +type clientBulkWriteOp struct { + namespace string + model any + collectionUUID []byte } type clientBulkWrite struct { - writePairs []clientBulkWritePair + writeOps []clientBulkWriteOp errorsOnly bool ordered *bool bypassDocumentValidation *bool @@ -54,10 +55,10 @@ type clientBulkWrite struct { } func (bw *clientBulkWrite) execute(ctx context.Context) error { - if len(bw.writePairs) == 0 { + if len(bw.writeOps) == 0 { return fmt.Errorf("invalid writes: %w", ErrEmptySlice) } - for i, m := range bw.writePairs { + for i, m := range bw.writeOps { if m.model == nil { return fmt.Errorf("error from model at index %d: %w", i, ErrNilDocument) } @@ -66,7 +67,7 @@ func (bw *clientBulkWrite) execute(ctx context.Context) error { session: bw.session, client: bw.client, ordered: bw.ordered == nil || *bw.ordered, - writePairs: bw.writePairs, + writeOps: bw.writeOps, result: &bw.result, retryMode: driver.RetryOnce, } @@ -117,7 +118,7 @@ func (bw *clientBulkWrite) execute(ctx context.Context) error { _, ok := batches.writeErrors[0] hasSuccess = !ok } else { - hasSuccess = len(batches.writeErrors) < len(bw.writePairs) + hasSuccess = len(batches.writeErrors) < len(bw.writeOps) } if hasSuccess { exception.PartialResult = batches.result @@ -200,7 +201,7 @@ type modelBatches struct { client *Client ordered bool - writePairs []clientBulkWritePair + writeOps []clientBulkWriteOp offset int @@ -221,16 +222,16 @@ func (mb *modelBatches) IsOrdered() *bool { func (mb *modelBatches) AdvanceBatches(n int) { mb.offset += n - if mb.offset > len(mb.writePairs) { - mb.offset = len(mb.writePairs) + if mb.offset > len(mb.writeOps) { + mb.offset = len(mb.writeOps) } } func (mb *modelBatches) Size() int { - if mb.offset > len(mb.writePairs) { + if mb.offset > len(mb.writeOps) { return 0 } - return len(mb.writePairs) - mb.offset + return len(mb.writeOps) - mb.offset } func (mb *modelBatches) AppendBatchSequence(dst []byte, maxCount, totalSize int) (int, []byte, error) { @@ -281,15 +282,20 @@ func (mb *modelBatches) appendBatches(fn functionSet, dst []byte, maxCount, tota mb.cursorHandlers = mb.cursorHandlers[:0] mb.newIDMap = make(map[int]any) - nsMap := make(map[string]int) - getNsIndex := func(namespace string) (int, bool) { - v, ok := nsMap[namespace] + type nsInfoKey struct { + namespace string + uuidKey string // string(uuid), or "" when no UUID + } + nsMap := make(map[nsInfoKey]int) + getNsIndex := func(namespace string, uuid []byte) (int, bool) { + key := nsInfoKey{namespace: namespace, uuidKey: string(uuid)} + v, ok := nsMap[key] if ok { return v, ok } - nsIdx := len(nsMap) - nsMap[namespace] = nsIdx - return nsIdx, ok + idx := len(nsMap) + nsMap[key] = idx + return idx, ok } canRetry := true @@ -301,17 +307,18 @@ func (mb *modelBatches) appendBatches(fn functionSet, dst []byte, maxCount, tota totalSize -= 1000 size := len(dst) + len(nsDst) var n int - for i := mb.offset; i < len(mb.writePairs); i++ { + for i := mb.offset; i < len(mb.writeOps); i++ { if n == maxCount { break } - ns := mb.writePairs[i].namespace - nsIdx, exists := getNsIndex(ns) + ns := mb.writeOps[i].namespace + opUUID := mb.writeOps[i].collectionUUID + nsIdx, exists := getNsIndex(ns, opUUID) var doc bsoncore.Document var err error - switch model := mb.writePairs[i].model.(type) { + switch model := mb.writeOps[i].model.(type) { case *ClientInsertOneModel: mb.cursorHandlers = append(mb.cursorHandlers, mb.appendInsertResult) var id any @@ -393,6 +400,9 @@ func (mb *modelBatches) appendBatches(fn functionSet, dst []byte, maxCount, tota length := len(doc) if !exists { length += len(ns) + if opUUID != nil { + length += len(opUUID) + } } size += length if size >= totalSize { @@ -403,6 +413,9 @@ func (mb *modelBatches) appendBatches(fn functionSet, dst []byte, maxCount, tota if !exists { idx, doc := bsoncore.AppendDocumentStart(nil) doc = bsoncore.AppendStringElement(doc, "ns", ns) + if opUUID != nil { + doc = bsoncore.AppendBinaryElement(doc, "collectionUUID", bson.TypeBinaryUUID, opUUID) + } doc, _ = bsoncore.AppendDocumentEnd(doc, idx) nsDst = fn.appendDocument(nsDst, strconv.Itoa(n), doc) } diff --git a/mongo/client_bulk_write_test.go b/mongo/client_bulk_write_test.go index 7eb4fd9907..c5662af3cd 100644 --- a/mongo/client_bulk_write_test.go +++ b/mongo/client_bulk_write_test.go @@ -12,13 +12,14 @@ import ( "go.mongodb.org/mongo-driver/v2/bson" "go.mongodb.org/mongo-driver/v2/internal/assert" "go.mongodb.org/mongo-driver/v2/internal/require" + "go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore" ) func TestBatches(t *testing.T) { t.Parallel() batches := &modelBatches{ - writePairs: make([]clientBulkWritePair, 2), + writeOps: make([]clientBulkWriteOp, 2), } batches.AdvanceBatches(3) size := batches.Size() @@ -33,16 +34,16 @@ func TestAppendBatchSequence(t *testing.T) { require.NoError(t, err, "NewClient error: %v", err) return &modelBatches{ client: client, - writePairs: []clientBulkWritePair{ - {"ns0", nil}, - {"ns1", &ClientInsertOneModel{ + writeOps: []clientBulkWriteOp{ + {namespace: "ns0", model: nil}, + {namespace: "ns1", model: &ClientInsertOneModel{ Document: bson.D{{"foo", 42}}, }}, - {"ns2", &ClientReplaceOneModel{ + {namespace: "ns2", model: &ClientReplaceOneModel{ Filter: bson.D{{"foo", "bar"}}, Replacement: bson.D{{"foo", "baz"}}, }}, - {"ns1", &ClientDeleteOneModel{ + {namespace: "ns1", model: &ClientDeleteOneModel{ Filter: bson.D{{"qux", "quux"}}, }}, }, @@ -120,3 +121,188 @@ func TestAppendBatchSequence(t *testing.T) { assert.False(t, ok, "expected an delete results") }) } + +func TestAppendBatchesNamespaceUUIDs(t *testing.T) { + t.Parallel() + + uuid1 := []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} + uuid2 := []byte{17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32} + + type batchResult struct { + Ops []bson.Raw + NsInfo []bson.Raw + } + + // decodeBatches runs AppendBatchArray and returns both the ops and nsInfo + // arrays as slices of raw BSON documents. + decodeBatches := func(t *testing.T, batches *modelBatches) batchResult { + t.Helper() + _, data, err := batches.AppendBatchArray(nil, 100, 16_000) + require.NoError(t, err, "AppendBatchArray error: %v", err) + + // AppendBatchArray returns two concatenated array-element payloads ("ops" + // then "nsInfo"). Wrap them in a document so bson.Unmarshal can parse them. + idx, doc := bsoncore.AppendDocumentStart(nil) + doc = append(doc, data...) + doc, _ = bsoncore.AppendDocumentEnd(doc, idx) + + var result struct { + Ops []bson.Raw `bson:"ops"` + NsInfo []bson.Raw `bson:"nsInfo"` + } + require.NoError(t, bson.Unmarshal(doc, &result), "unmarshal error") + return batchResult{Ops: result.Ops, NsInfo: result.NsInfo} + } + + // decodeNsInfo returns the namespace string and, if present, the raw UUID bytes + // from a collectionUUID binary element (subtype is not checked here). + decodeNsInfo := func(t *testing.T, raw bson.Raw) (ns string, uuid []byte) { + t.Helper() + var entry struct { + Ns string `bson:"ns"` + CollectionUUID *bson.Binary `bson:"collectionUUID"` + } + require.NoError(t, bson.Unmarshal(raw, &entry)) + if entry.CollectionUUID != nil { + return entry.Ns, entry.CollectionUUID.Data + } + return entry.Ns, nil + } + + // nsIdxFromOp returns the namespace index embedded in any op document + // (the value of whichever of "insert"/"update"/"delete" is present). + nsIdxFromOp := func(t *testing.T, raw bson.Raw) int { + t.Helper() + var op struct { + Insert *int32 `bson:"insert"` + Update *int32 `bson:"update"` + Delete *int32 `bson:"delete"` + } + require.NoError(t, bson.Unmarshal(raw, &op)) + switch { + case op.Insert != nil: + return int(*op.Insert) + case op.Update != nil: + return int(*op.Update) + case op.Delete != nil: + return int(*op.Delete) + default: + t.Fatal("op has no insert/update/delete field") + return -1 + } + } + + client, err := newClient() + require.NoError(t, err, "newClient error: %v", err) + + t.Run("UUID on single entry", func(t *testing.T) { + t.Parallel() + + batches := &modelBatches{ + client: client, + writeOps: []clientBulkWriteOp{ + {namespace: "db.coll", model: &ClientInsertOneModel{Document: bson.D{{"x", 1}}}, collectionUUID: uuid1}, + }, + result: &ClientBulkWriteResult{Acknowledged: true}, + } + res := decodeBatches(t, batches) + require.Len(t, res.NsInfo, 1) + ns, uuid := decodeNsInfo(t, res.NsInfo[0]) + assert.Equal(t, "db.coll", ns) + require.NotNil(t, uuid) + assert.Equal(t, uuid1, uuid) + }) + + t.Run("no UUID set", func(t *testing.T) { + t.Parallel() + + batches := &modelBatches{ + client: client, + writeOps: []clientBulkWriteOp{ + {namespace: "db.coll", model: &ClientInsertOneModel{Document: bson.D{{"x", 1}}}}, + }, + result: &ClientBulkWriteResult{Acknowledged: true}, + } + res := decodeBatches(t, batches) + require.Len(t, res.NsInfo, 1) + ns, uuid := decodeNsInfo(t, res.NsInfo[0]) + assert.Equal(t, "db.coll", ns) + assert.Nil(t, uuid, "expected collectionUUID to be absent") + }) + + t.Run("mixed: some entries have UUID, some do not", func(t *testing.T) { + t.Parallel() + + batches := &modelBatches{ + client: client, + writeOps: []clientBulkWriteOp{ + {namespace: "db.with_uuid", model: &ClientInsertOneModel{Document: bson.D{{"x", 1}}}, collectionUUID: uuid1}, + {namespace: "db.no_uuid", model: &ClientInsertOneModel{Document: bson.D{{"x", 2}}}}, + }, + result: &ClientBulkWriteResult{Acknowledged: true}, + } + res := decodeBatches(t, batches) + require.Len(t, res.NsInfo, 2) + + ns0, uuid0 := decodeNsInfo(t, res.NsInfo[0]) + assert.Equal(t, "db.with_uuid", ns0) + require.NotNil(t, uuid0) + assert.Equal(t, uuid1, uuid0) + + ns1, uuid1Got := decodeNsInfo(t, res.NsInfo[1]) + assert.Equal(t, "db.no_uuid", ns1) + assert.Nil(t, uuid1Got, "expected collectionUUID to be absent for db.no_uuid") + }) + + t.Run("same namespace string, different UUIDs → two nsInfo entries, each op points to its UUID", func(t *testing.T) { + t.Parallel() + + // Simulates two collections with the same db.collection name but different + // UUIDs (e.g. the collection was dropped and recreated between events). + batches := &modelBatches{ + client: client, + writeOps: []clientBulkWriteOp{ + {namespace: "db.coll", model: &ClientInsertOneModel{Document: bson.D{{"x", 1}}}, collectionUUID: uuid1}, + {namespace: "db.coll", model: &ClientInsertOneModel{Document: bson.D{{"x", 2}}}, collectionUUID: uuid2}, + }, + result: &ClientBulkWriteResult{Acknowledged: true}, + } + res := decodeBatches(t, batches) + + // Verify nsInfo: one entry per UUID, both with the same namespace string. + require.Len(t, res.NsInfo, 2, "expected one nsInfo entry per distinct UUID") + ns0, u0 := decodeNsInfo(t, res.NsInfo[0]) + assert.Equal(t, "db.coll", ns0) + require.NotNil(t, u0) + assert.Equal(t, uuid1, u0) + ns1, u1 := decodeNsInfo(t, res.NsInfo[1]) + assert.Equal(t, "db.coll", ns1) + require.NotNil(t, u1) + assert.Equal(t, uuid2, u1) + + // Verify each op's namespace index correctly references its UUID's nsInfo entry. + require.Len(t, res.Ops, 2) + assert.Equal(t, 0, nsIdxFromOp(t, res.Ops[0]), "op[0] should reference nsInfo[0] (uuid1)") + assert.Equal(t, 1, nsIdxFromOp(t, res.Ops[1]), "op[1] should reference nsInfo[1] (uuid2)") + }) + + t.Run("same namespace and UUID → deduplicated nsInfo", func(t *testing.T) { + t.Parallel() + + batches := &modelBatches{ + client: client, + writeOps: []clientBulkWriteOp{ + {namespace: "db.coll", model: &ClientInsertOneModel{Document: bson.D{{"x", 1}}}, collectionUUID: uuid1}, + {namespace: "db.coll", model: &ClientInsertOneModel{Document: bson.D{{"x", 2}}}, collectionUUID: uuid1}, + }, + result: &ClientBulkWriteResult{Acknowledged: true}, + } + res := decodeBatches(t, batches) + require.Len(t, res.NsInfo, 1, "expected one nsInfo entry for same namespace+UUID") + + ns, uuid := decodeNsInfo(t, res.NsInfo[0]) + assert.Equal(t, "db.coll", ns) + require.NotNil(t, uuid) + assert.Equal(t, uuid1, uuid) + }) +} diff --git a/x/mongo/driver/xoptions/options.go b/x/mongo/driver/xoptions/options.go index c3a0938ebe..f185c64956 100644 --- a/x/mongo/driver/xoptions/options.go +++ b/x/mongo/driver/xoptions/options.go @@ -126,6 +126,35 @@ func SetInternalClientBulkWriteOptions(a *options.ClientBulkWriteOptionsBuilder, return nil } +// SetInternalClientBulkWriteEntry sets internal options on the Internal field of a +// mongo.ClientBulkWrite entry. Pass a pointer to the Internal field directly: +// +// xoptions.SetInternalClientBulkWriteEntry(&write.Internal, "collectionUUID", uuid) +// +// Supported keys: +// +// - "collectionUUID" ([]byte): attaches a collection UUID to the +// nsInfo entry for this write in the bulkWrite command. Two writes that share a +// namespace string but carry different UUIDs produce separate nsInfo entries, +// allowing the server to distinguish a collection from a same-named replacement +// created after a drop. +func SetInternalClientBulkWriteEntry(internal *optionsutil.Options, key string, option any) error { + typeErrFunc := func(t string) error { + return fmt.Errorf("unexpected type for %q: %T is not %s", key, option, t) + } + switch key { + case "collectionUUID": + b, ok := option.([]byte) + if !ok { + return typeErrFunc("[]byte") + } + *internal = optionsutil.WithValue(*internal, key, b) + default: + return fmt.Errorf("unsupported option: %q", key) + } + return nil +} + // SetInternalCountOptions sets internal options for CountOptions. func SetInternalCountOptions(a *options.CountOptionsBuilder, key string, option any) error { typeErrFunc := func(t string) error { diff --git a/x/mongo/driver/xoptions/options_test.go b/x/mongo/driver/xoptions/options_test.go index 4a4e1a371e..bafd06affd 100644 --- a/x/mongo/driver/xoptions/options_test.go +++ b/x/mongo/driver/xoptions/options_test.go @@ -87,3 +87,36 @@ func TestSetInternalClientOptions(t *testing.T) { require.EqualError(t, err, "unsupported option: \"unsupported\"") }) } + +func TestSetInternalClientBulkWriteEntry(t *testing.T) { + t.Parallel() + + t.Run("set collectionUUID", func(t *testing.T) { + t.Parallel() + + uuid := []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} + + var internal optionsutil.Options + err := SetInternalClientBulkWriteEntry(&internal, "collectionUUID", uuid) + require.NoError(t, err, "SetInternalClientBulkWriteEntry error: %v", err) + + v := optionsutil.Value(internal, "collectionUUID") + require.Equal(t, uuid, v) + }) + + t.Run("set collectionUUID - wrong type", func(t *testing.T) { + t.Parallel() + + var internal optionsutil.Options + err := SetInternalClientBulkWriteEntry(&internal, "collectionUUID", "not-a-slice") + require.EqualError(t, err, "unexpected type for \"collectionUUID\": string is not []byte") + }) + + t.Run("set unsupported option", func(t *testing.T) { + t.Parallel() + + var internal optionsutil.Options + err := SetInternalClientBulkWriteEntry(&internal, "unsupported", "value") + require.EqualError(t, err, "unsupported option: \"unsupported\"") + }) +}