Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/megatron/bridge/data/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,11 @@ def worker_init_fn(_):
data_parallel_rank=dp_rank,
data_parallel_size=dp_size,
global_batch_size=cfg.train.global_batch_size,
# Reshuffle the fine-tuning ('batch') training order every epoch so multi-epoch
# runs do not repeat the same order each pass. Seeded for resume reproducibility.
# Ignored by the 'single'/'cyclic'/'external' samplers.
shuffle=cfg.dataset.dataloader_type == "batch",
seed=getattr(cfg.dataset, "seed", cfg.rng.seed),
)
eval_gbs = (
cfg.validation.eval_global_batch_size
Expand Down
58 changes: 57 additions & 1 deletion src/megatron/bridge/data/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ def build_pretraining_data_loader(
data_parallel_size: int = 1,
drop_last: Optional[bool] = True,
global_batch_size: Optional[int] = None,
shuffle: bool = False,
seed: int = 0,
) -> Optional[DataLoader]:
"""Build a dataloader for pretraining.

Expand All @@ -49,6 +51,9 @@ def build_pretraining_data_loader(
drop_last: Whether to drop last incomplete batch.
global_batch_size: Total batch size across all data parallel ranks.
Required for 'batch' dataloader_type.
shuffle: Whether to reshuffle the sample order every epoch. Only honored by the
'batch' dataloader_type (fine-tuning); ignored by other samplers.
seed: Base random seed for the per-epoch shuffle of the 'batch' sampler.

Returns:
A PyTorch DataLoader instance, or the dataset itself if dataloader_type is
Expand Down Expand Up @@ -97,6 +102,8 @@ def build_pretraining_data_loader(
data_parallel_size=data_parallel_size,
drop_last=drop_last,
pad_samples_to_global_batch_size=not drop_last,
shuffle=shuffle,
seed=seed,
)
elif dataloader_type == "external":
# External dataloaders are passed through. User is expected to provide a
Expand Down Expand Up @@ -198,6 +205,13 @@ class MegatronPretrainingBatchSampler:
are padded to the same length, which is critical for fine-tuning with variable
sequence lengths.

The sampler is epoch-aware: it advances ``consumed_samples`` as it yields so that
when it is re-iterated (e.g. wrapped in ``cyclic_iter`` for multi-epoch
fine-tuning), each new epoch starts from the beginning of the dataset instead of
permanently re-applying the resume offset. When ``shuffle`` is enabled the per-epoch
order is a deterministic, ``seed``- and epoch-derived permutation, so resuming from a
checkpoint reproduces the same order as an uninterrupted run.

Args:
total_samples: Total number of samples in the dataset.
consumed_samples: Number of samples already consumed (for resuming).
Expand All @@ -207,6 +221,9 @@ class MegatronPretrainingBatchSampler:
data_parallel_size: Total number of GPUs in the data parallel group.
drop_last: If True, drops the last incomplete batch.
pad_samples_to_global_batch_size: If True, pads incomplete batches with -1 indices.
shuffle: If True, reshuffle the sample order every epoch using a deterministic,
seed- and epoch-derived permutation. Defaults to False (sequential order).
seed: Base random seed for the per-epoch shuffle. Only used when ``shuffle`` is True.
"""

def __init__(
Expand All @@ -219,6 +236,8 @@ def __init__(
data_parallel_size: int,
drop_last: bool = True,
pad_samples_to_global_batch_size: bool = False,
shuffle: bool = False,
seed: int = 0,
) -> None:
self.total_samples = total_samples
self.consumed_samples = consumed_samples
Expand All @@ -227,6 +246,8 @@ def __init__(
self.data_parallel_size = data_parallel_size
self.drop_last = drop_last
self.pad_samples_to_global_batch_size = pad_samples_to_global_batch_size
self.shuffle = shuffle
self.seed = seed
self.micro_batch_times_data_parallel_size = self.micro_batch_size * data_parallel_size

assert self.total_samples > 0, "no sample to consume: {}".format(self.total_samples)
Expand Down Expand Up @@ -276,10 +297,41 @@ def __iter__(self) -> Iterator[list[int]]:
1. Compute max_length across the entire global batch
2. Pad all samples to the same length
3. Then split into microbatches with consistent sequence length

The sampler is epoch-aware. It derives the current epoch and the within-epoch
offset from ``consumed_samples`` and advances ``consumed_samples`` as it yields.
When wrapped in ``cyclic_iter`` for multi-epoch fine-tuning, each re-iteration
therefore starts a fresh epoch (offset 0) rather than repeatedly re-applying the
resume offset, which previously caused the head of every post-resume epoch to be
skipped and the tail to be replayed.
"""
# Largest multiple of the global batch size that fits in one epoch. Used for
# epoch / offset accounting so that cyclic re-iteration wraps cleanly instead of
# getting stuck re-serving the same partial window after a mid-epoch resume.
active_total_samples = (self.total_samples // self._global_batch_size) * self._global_batch_size
if active_total_samples == 0:
# Dataset smaller than a single global batch (only reachable with
# drop_last=False); fall back to the full dataset for accounting.
active_total_samples = self.total_samples

epoch = self.consumed_samples // active_total_samples
current_epoch_samples = self.consumed_samples % active_total_samples

if self.shuffle:
# Deterministic, seed- and epoch-derived permutation so a resumed run
# reproduces the same per-epoch order as an uninterrupted run.
g = torch.Generator()
g.manual_seed(self.seed + epoch)
idx_order = torch.randperm(self.total_samples, generator=g).tolist()
else:
idx_order = list(range(self.total_samples))

# Skip samples already consumed within the current epoch (mid-epoch resume).
idx_order = idx_order[current_epoch_samples:]

batch = []
# Last batch will be dropped if drop_last is True
for idx in range(self.consumed_samples % self.total_samples, self.total_samples):
for idx in idx_order:
batch.append(idx)
if len(batch) == self._global_batch_size:
# Distribute indices in interleaved fashion across ranks
Expand All @@ -293,6 +345,10 @@ def __iter__(self) -> Iterator[list[int]]:
]
assert len(all_indices) == self._global_batch_size_on_this_data_parallel_rank

# Advance so the next re-iteration starts a fresh epoch (offset 0) rather
# than re-applying the resume offset and replaying the tail.
self.consumed_samples += self._global_batch_size

# Yield ALL indices at once (not split into microbatches)
# The training loop will handle splitting after collation
yield all_indices
Expand Down
138 changes: 138 additions & 0 deletions tests/functional_tests/test_groups/data/test_samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,144 @@ def test_batch_sampler_multiple_data_parallel_ranks(self):
all_indices_sorted = sorted(all_indices)
assert all_indices_sorted == list(range(32))

def test_batch_sampler_resume_serves_full_epoch_after_first_cycle(self):
"""Resuming mid-epoch must serve the remaining tail first, then full epochs.

Regression for the resume bug where the sampler froze ``consumed_samples`` and
re-applied the resume offset on every ``cyclic_iter`` re-iteration, permanently
skipping the head of each post-resume epoch and replaying only the tail (which
repeatedly re-trained the model on the same samples and depressed the loss).
"""
from megatron.bridge.data.samplers import MegatronPretrainingBatchSampler

# 32 samples, resume after consuming 8 (mid first epoch). global_batch_size=4.
sampler = MegatronPretrainingBatchSampler(
total_samples=32,
consumed_samples=8,
micro_batch_size=1,
global_batch_size=4,
data_parallel_rank=0,
data_parallel_size=1,
drop_last=True,
)

first_cycle = [idx for batch in sampler for idx in batch]
second_cycle = [idx for batch in sampler for idx in batch]

# First (resumed) cycle serves only the remaining tail [8, 32).
assert sorted(first_cycle) == list(range(8, 32))
# Second cycle is a full epoch starting from 0 — the head is no longer skipped.
assert sorted(second_cycle) == list(range(32))
# And it is not a replay of the resumed tail.
assert second_cycle != first_cycle

def test_batch_sampler_epoch_boundary_resume_serves_full_epoch(self):
"""Resuming exactly on an epoch boundary serves full epochs (offset 0)."""
from megatron.bridge.data.samplers import MegatronPretrainingBatchSampler

sampler = MegatronPretrainingBatchSampler(
total_samples=32,
consumed_samples=32, # exactly one epoch consumed
micro_batch_size=1,
global_batch_size=4,
data_parallel_rank=0,
data_parallel_size=1,
drop_last=True,
)

assert sorted(idx for batch in sampler for idx in batch) == list(range(32))

def test_batch_sampler_shuffle_reshuffles_each_epoch(self):
"""With shuffle enabled, each epoch is a distinct permutation of all samples."""
from megatron.bridge.data.samplers import MegatronPretrainingBatchSampler

sampler = MegatronPretrainingBatchSampler(
total_samples=16,
consumed_samples=0,
micro_batch_size=1,
global_batch_size=4,
data_parallel_rank=0,
data_parallel_size=1,
drop_last=True,
shuffle=True,
seed=1234,
)

epoch0 = [idx for batch in sampler for idx in batch]
epoch1 = [idx for batch in sampler for idx in batch]

# Each epoch covers every sample exactly once...
assert sorted(epoch0) == list(range(16))
assert sorted(epoch1) == list(range(16))
# ...but the order differs between epochs (reshuffled).
assert epoch0 != epoch1

def test_batch_sampler_shuffle_is_deterministic_for_seed(self):
"""A given (seed, epoch) yields a reproducible order (resume-safe)."""
from megatron.bridge.data.samplers import MegatronPretrainingBatchSampler

def first_epoch(seed):
sampler = MegatronPretrainingBatchSampler(
total_samples=16,
consumed_samples=0,
micro_batch_size=1,
global_batch_size=4,
data_parallel_rank=0,
data_parallel_size=1,
drop_last=True,
shuffle=True,
seed=seed,
)
return [idx for batch in sampler for idx in batch]

assert first_epoch(1234) == first_epoch(1234)
assert first_epoch(1234) != first_epoch(5678)

def test_batch_sampler_shuffle_resume_parity(self):
"""A shuffled resume reproduces the same per-epoch order as an uninterrupted run."""
from megatron.bridge.data.samplers import MegatronPretrainingBatchSampler

def make(consumed):
return MegatronPretrainingBatchSampler(
total_samples=16,
consumed_samples=consumed,
micro_batch_size=1,
global_batch_size=4,
data_parallel_rank=0,
data_parallel_size=1,
drop_last=True,
shuffle=True,
seed=1234,
)

# Uninterrupted: collect epoch 0 then epoch 1.
uninterrupted = make(0)
_ = [idx for batch in uninterrupted for idx in batch] # epoch 0
epoch1 = [idx for batch in uninterrupted for idx in batch] # epoch 1

# Resume after consuming all of epoch 0 (16) plus 8 of epoch 1.
resumed = make(24)
resumed_indices = [idx for batch in resumed for idx in batch]

# The resumed run serves exactly the not-yet-consumed tail of epoch 1.
assert resumed_indices == epoch1[8:]

def test_batch_sampler_shuffle_disabled_by_default(self):
"""Default behavior stays sequential (no shuffle) for backward compatibility."""
from megatron.bridge.data.samplers import MegatronPretrainingBatchSampler

sampler = MegatronPretrainingBatchSampler(
total_samples=16,
consumed_samples=0,
micro_batch_size=1,
global_batch_size=4,
data_parallel_rank=0,
data_parallel_size=1,
drop_last=True,
)

assert [idx for batch in sampler for idx in batch] == list(range(16))


class TestFinetuningUtilities:
"""Tests for finetuning data handling utilities."""
Expand Down