Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 68 additions & 2 deletions src/megatron/bridge/data/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,13 @@ class MegatronPretrainingBatchSampler:
are padded to the same length, which is critical for fine-tuning with variable
sequence lengths.

The sampler is epoch-aware: it advances ``consumed_samples`` as it yields so that
when it is re-iterated (e.g. wrapped in ``cyclic_iter`` for multi-epoch
fine-tuning), each new epoch starts from the beginning of the dataset instead of
permanently re-applying the resume offset. When ``shuffle`` is enabled the per-epoch
order is a deterministic, ``seed``- and epoch-derived permutation, so resuming from a
checkpoint reproduces the same order as an uninterrupted run.

Args:
total_samples: Total number of samples in the dataset.
consumed_samples: Number of samples already consumed (for resuming).
Expand All @@ -207,6 +214,10 @@ class MegatronPretrainingBatchSampler:
data_parallel_size: Total number of GPUs in the data parallel group.
drop_last: If True, drops the last incomplete batch.
pad_samples_to_global_batch_size: If True, pads incomplete batches with -1 indices.
shuffle: If True, reshuffle the sample order every epoch using a deterministic,
seed- and epoch-derived permutation. Defaults to True for fine-tuning.
seed: Base random seed for the per-epoch shuffle. Defaults to the current
torch global seed. Only used when ``shuffle`` is True.
"""

def __init__(
Expand All @@ -219,6 +230,8 @@ def __init__(
data_parallel_size: int,
drop_last: bool = True,
pad_samples_to_global_batch_size: bool = False,
shuffle: bool = True,
seed: int | None = None,
) -> None:
self.total_samples = total_samples
self.consumed_samples = consumed_samples
Expand All @@ -227,6 +240,8 @@ def __init__(
self.data_parallel_size = data_parallel_size
self.drop_last = drop_last
self.pad_samples_to_global_batch_size = pad_samples_to_global_batch_size
self.shuffle = shuffle
self.seed = int(torch.initial_seed() if seed is None else seed)
self.micro_batch_times_data_parallel_size = self.micro_batch_size * data_parallel_size

assert self.total_samples > 0, "no sample to consume: {}".format(self.total_samples)
Expand Down Expand Up @@ -255,7 +270,9 @@ def __len__(self) -> int:
Since we now yield the full global batch at once (not split into microbatches),
this returns the number of global batches.
"""
num_available_samples = self.total_samples - self.consumed_samples % self.total_samples
active_total_samples = self._active_total_samples()
current_epoch_samples = self.consumed_samples % active_total_samples
num_available_samples = active_total_samples - current_epoch_samples
if self.drop_last:
num_global_batches = num_available_samples // self._global_batch_size
else:
Expand All @@ -264,6 +281,23 @@ def __len__(self) -> int:
# Each call to __iter__ yields one global batch
return num_global_batches

def _active_total_samples(self) -> int:
"""Return the sample-count unit used for epoch/offset accounting."""
if self.drop_last:
active_total_samples = (self.total_samples // self._global_batch_size) * self._global_batch_size
assert active_total_samples > 0, (
"drop_last=True requires at least one full global batch; "
f"got total_samples={self.total_samples}, global_batch_size={self._global_batch_size}"
)
return active_total_samples

if self.pad_samples_to_global_batch_size:
return (
(self.total_samples + self._global_batch_size - 1) // self._global_batch_size
) * self._global_batch_size

return self.total_samples

def __iter__(self) -> Iterator[list[int]]:
"""Yields lists of indices for the full global batch assigned to this rank.

Expand All @@ -276,10 +310,33 @@ def __iter__(self) -> Iterator[list[int]]:
1. Compute max_length across the entire global batch
2. Pad all samples to the same length
3. Then split into microbatches with consistent sequence length

The sampler is epoch-aware. It derives the current epoch and the within-epoch
offset from ``consumed_samples`` and advances ``consumed_samples`` as it yields.
When wrapped in ``cyclic_iter`` for multi-epoch fine-tuning, each re-iteration
therefore starts a fresh epoch (offset 0) rather than repeatedly re-applying the
resume offset, which previously caused the head of every post-resume epoch to be
skipped and the tail to be replayed.
"""
active_total_samples = self._active_total_samples()
epoch = self.consumed_samples // active_total_samples
current_epoch_samples = self.consumed_samples % active_total_samples

if self.shuffle:
# Deterministic, seed- and epoch-derived permutation so a resumed run
# reproduces the same per-epoch order as an uninterrupted run.
g = torch.Generator()
g.manual_seed(self.seed + epoch)
idx_order = torch.randperm(self.total_samples, generator=g).tolist()
else:
idx_order = list(range(self.total_samples))

# Skip samples already consumed within the current epoch (mid-epoch resume).
idx_order = idx_order[current_epoch_samples:]

batch = []
# Last batch will be dropped if drop_last is True
for idx in range(self.consumed_samples % self.total_samples, self.total_samples):
for idx in idx_order:
batch.append(idx)
if len(batch) == self._global_batch_size:
# Distribute indices in interleaved fashion across ranks
Expand All @@ -293,6 +350,10 @@ def __iter__(self) -> Iterator[list[int]]:
]
assert len(all_indices) == self._global_batch_size_on_this_data_parallel_rank

# Advance so the next re-iteration starts a fresh epoch (offset 0) rather
# than re-applying the resume offset and replaying the tail.
self.consumed_samples += self._global_batch_size

# Yield ALL indices at once (not split into microbatches)
# The training loop will handle splitting after collation
yield all_indices
Expand All @@ -307,6 +368,11 @@ def __iter__(self) -> Iterator[list[int]]:
num_pad = self._global_batch_size // self.data_parallel_size - len(all_indices)
all_indices = all_indices + [-1] * num_pad

if self.pad_samples_to_global_batch_size:
self.consumed_samples += self._global_batch_size
else:
self.consumed_samples += len(batch)

# Yield ALL indices at once
yield all_indices

Expand Down
Loading
Loading