Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions mongo/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -956,6 +956,10 @@ type ClientBulkWrite struct {
Database string
Collection string
Model ClientWriteModel

// Deprecated: This option is for internal use only and should not be set. It may be changed or removed in any
// release.
Internal optionsutil.Options
}

// BulkWrite performs a client-level bulk write operation.
Expand Down Expand Up @@ -1019,16 +1023,20 @@ func (c *Client) BulkWrite(ctx context.Context, writes []ClientBulkWrite,
}
selector := makePinnedSelector(sess, writeSelector)

writePairs := make([]clientBulkWritePair, len(writes))
writeOps := make([]clientBulkWriteOp, len(writes))
for i, w := range writes {
writePairs[i] = clientBulkWritePair{
p := clientBulkWriteOp{
namespace: fmt.Sprintf("%s.%s", w.Database, w.Collection),
model: w.Model,
}
if uuid, ok := optionsutil.Value(w.Internal, "collectionUUID").([]byte); ok {
p.collectionUUID = uuid
}
writeOps[i] = p
}

op := clientBulkWrite{
writePairs: writePairs,
writeOps: writeOps,
ordered: bwo.Ordered,
bypassDocumentValidation: bwo.BypassDocumentValidation,
comment: bwo.Comment,
Expand Down
59 changes: 36 additions & 23 deletions mongo/client_bulk_write.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,14 @@ const (
database = "admin"
)

type clientBulkWritePair struct {
namespace string
model any
type clientBulkWriteOp struct {
namespace string
model any
collectionUUID []byte
}

type clientBulkWrite struct {
writePairs []clientBulkWritePair
writeOps []clientBulkWriteOp
errorsOnly bool
ordered *bool
bypassDocumentValidation *bool
Expand All @@ -54,10 +55,10 @@ type clientBulkWrite struct {
}

func (bw *clientBulkWrite) execute(ctx context.Context) error {
if len(bw.writePairs) == 0 {
if len(bw.writeOps) == 0 {
return fmt.Errorf("invalid writes: %w", ErrEmptySlice)
}
for i, m := range bw.writePairs {
for i, m := range bw.writeOps {
if m.model == nil {
return fmt.Errorf("error from model at index %d: %w", i, ErrNilDocument)
}
Expand All @@ -66,7 +67,7 @@ func (bw *clientBulkWrite) execute(ctx context.Context) error {
session: bw.session,
client: bw.client,
ordered: bw.ordered == nil || *bw.ordered,
writePairs: bw.writePairs,
writeOps: bw.writeOps,
result: &bw.result,
retryMode: driver.RetryOnce,
}
Expand Down Expand Up @@ -117,7 +118,7 @@ func (bw *clientBulkWrite) execute(ctx context.Context) error {
_, ok := batches.writeErrors[0]
hasSuccess = !ok
} else {
hasSuccess = len(batches.writeErrors) < len(bw.writePairs)
hasSuccess = len(batches.writeErrors) < len(bw.writeOps)
}
if hasSuccess {
exception.PartialResult = batches.result
Expand Down Expand Up @@ -200,7 +201,7 @@ type modelBatches struct {
client *Client

ordered bool
writePairs []clientBulkWritePair
writeOps []clientBulkWriteOp

offset int

Expand All @@ -221,16 +222,16 @@ func (mb *modelBatches) IsOrdered() *bool {

func (mb *modelBatches) AdvanceBatches(n int) {
mb.offset += n
if mb.offset > len(mb.writePairs) {
mb.offset = len(mb.writePairs)
if mb.offset > len(mb.writeOps) {
mb.offset = len(mb.writeOps)
}
}

func (mb *modelBatches) Size() int {
if mb.offset > len(mb.writePairs) {
if mb.offset > len(mb.writeOps) {
return 0
}
return len(mb.writePairs) - mb.offset
return len(mb.writeOps) - mb.offset
}

func (mb *modelBatches) AppendBatchSequence(dst []byte, maxCount, totalSize int) (int, []byte, error) {
Expand Down Expand Up @@ -281,15 +282,20 @@ func (mb *modelBatches) appendBatches(fn functionSet, dst []byte, maxCount, tota
mb.cursorHandlers = mb.cursorHandlers[:0]
mb.newIDMap = make(map[int]any)

nsMap := make(map[string]int)
getNsIndex := func(namespace string) (int, bool) {
v, ok := nsMap[namespace]
type nsInfoKey struct {
namespace string
uuidKey string // string(uuid), or "" when no UUID
}
nsMap := make(map[nsInfoKey]int)
getNsIndex := func(namespace string, uuid []byte) (int, bool) {
key := nsInfoKey{namespace: namespace, uuidKey: string(uuid)}
v, ok := nsMap[key]
if ok {
return v, ok
}
nsIdx := len(nsMap)
nsMap[namespace] = nsIdx
return nsIdx, ok
idx := len(nsMap)
nsMap[key] = idx
return idx, ok
}

canRetry := true
Expand All @@ -301,17 +307,18 @@ func (mb *modelBatches) appendBatches(fn functionSet, dst []byte, maxCount, tota
totalSize -= 1000
size := len(dst) + len(nsDst)
var n int
for i := mb.offset; i < len(mb.writePairs); i++ {
for i := mb.offset; i < len(mb.writeOps); i++ {
if n == maxCount {
break
}

ns := mb.writePairs[i].namespace
nsIdx, exists := getNsIndex(ns)
ns := mb.writeOps[i].namespace
opUUID := mb.writeOps[i].collectionUUID
nsIdx, exists := getNsIndex(ns, opUUID)

var doc bsoncore.Document
var err error
switch model := mb.writePairs[i].model.(type) {
switch model := mb.writeOps[i].model.(type) {
case *ClientInsertOneModel:
mb.cursorHandlers = append(mb.cursorHandlers, mb.appendInsertResult)
var id any
Expand Down Expand Up @@ -393,6 +400,9 @@ func (mb *modelBatches) appendBatches(fn functionSet, dst []byte, maxCount, tota
length := len(doc)
if !exists {
length += len(ns)
if opUUID != nil {
length += len(opUUID)
}
}
size += length
if size >= totalSize {
Expand All @@ -403,6 +413,9 @@ func (mb *modelBatches) appendBatches(fn functionSet, dst []byte, maxCount, tota
if !exists {
idx, doc := bsoncore.AppendDocumentStart(nil)
doc = bsoncore.AppendStringElement(doc, "ns", ns)
if opUUID != nil {
doc = bsoncore.AppendBinaryElement(doc, "collectionUUID", bson.TypeBinaryUUID, opUUID)
}
doc, _ = bsoncore.AppendDocumentEnd(doc, idx)
nsDst = fn.appendDocument(nsDst, strconv.Itoa(n), doc)
}
Expand Down
Loading
Loading