diff --git a/src/megatron/bridge/data/samplers.py b/src/megatron/bridge/data/samplers.py index 137a469e8d..1c7670515d 100644 --- a/src/megatron/bridge/data/samplers.py +++ b/src/megatron/bridge/data/samplers.py @@ -198,6 +198,13 @@ class MegatronPretrainingBatchSampler: are padded to the same length, which is critical for fine-tuning with variable sequence lengths. + The sampler is epoch-aware: it advances ``consumed_samples`` as it yields so that + when it is re-iterated (e.g. wrapped in ``cyclic_iter`` for multi-epoch + fine-tuning), each new epoch starts from the beginning of the dataset instead of + permanently re-applying the resume offset. When ``shuffle`` is enabled the per-epoch + order is a deterministic, ``seed``- and epoch-derived permutation, so resuming from a + checkpoint reproduces the same order as an uninterrupted run. + Args: total_samples: Total number of samples in the dataset. consumed_samples: Number of samples already consumed (for resuming). @@ -207,6 +214,10 @@ class MegatronPretrainingBatchSampler: data_parallel_size: Total number of GPUs in the data parallel group. drop_last: If True, drops the last incomplete batch. pad_samples_to_global_batch_size: If True, pads incomplete batches with -1 indices. + shuffle: If True, reshuffle the sample order every epoch using a deterministic, + seed- and epoch-derived permutation. Defaults to True for fine-tuning. + seed: Base random seed for the per-epoch shuffle. Defaults to the current + torch global seed. Only used when ``shuffle`` is True. """ def __init__( @@ -219,6 +230,8 @@ def __init__( data_parallel_size: int, drop_last: bool = True, pad_samples_to_global_batch_size: bool = False, + shuffle: bool = True, + seed: int | None = None, ) -> None: self.total_samples = total_samples self.consumed_samples = consumed_samples @@ -227,6 +240,8 @@ def __init__( self.data_parallel_size = data_parallel_size self.drop_last = drop_last self.pad_samples_to_global_batch_size = pad_samples_to_global_batch_size + self.shuffle = shuffle + self.seed = int(torch.initial_seed() if seed is None else seed) self.micro_batch_times_data_parallel_size = self.micro_batch_size * data_parallel_size assert self.total_samples > 0, "no sample to consume: {}".format(self.total_samples) @@ -255,7 +270,9 @@ def __len__(self) -> int: Since we now yield the full global batch at once (not split into microbatches), this returns the number of global batches. """ - num_available_samples = self.total_samples - self.consumed_samples % self.total_samples + active_total_samples = self._active_total_samples() + current_epoch_samples = self.consumed_samples % active_total_samples + num_available_samples = active_total_samples - current_epoch_samples if self.drop_last: num_global_batches = num_available_samples // self._global_batch_size else: @@ -264,6 +281,23 @@ def __len__(self) -> int: # Each call to __iter__ yields one global batch return num_global_batches + def _active_total_samples(self) -> int: + """Return the sample-count unit used for epoch/offset accounting.""" + if self.drop_last: + active_total_samples = (self.total_samples // self._global_batch_size) * self._global_batch_size + assert active_total_samples > 0, ( + "drop_last=True requires at least one full global batch; " + f"got total_samples={self.total_samples}, global_batch_size={self._global_batch_size}" + ) + return active_total_samples + + if self.pad_samples_to_global_batch_size: + return ( + (self.total_samples + self._global_batch_size - 1) // self._global_batch_size + ) * self._global_batch_size + + return self.total_samples + def __iter__(self) -> Iterator[list[int]]: """Yields lists of indices for the full global batch assigned to this rank. @@ -276,10 +310,33 @@ def __iter__(self) -> Iterator[list[int]]: 1. Compute max_length across the entire global batch 2. Pad all samples to the same length 3. Then split into microbatches with consistent sequence length + + The sampler is epoch-aware. It derives the current epoch and the within-epoch + offset from ``consumed_samples`` and advances ``consumed_samples`` as it yields. + When wrapped in ``cyclic_iter`` for multi-epoch fine-tuning, each re-iteration + therefore starts a fresh epoch (offset 0) rather than repeatedly re-applying the + resume offset, which previously caused the head of every post-resume epoch to be + skipped and the tail to be replayed. """ + active_total_samples = self._active_total_samples() + epoch = self.consumed_samples // active_total_samples + current_epoch_samples = self.consumed_samples % active_total_samples + + if self.shuffle: + # Deterministic, seed- and epoch-derived permutation so a resumed run + # reproduces the same per-epoch order as an uninterrupted run. + g = torch.Generator() + g.manual_seed(self.seed + epoch) + idx_order = torch.randperm(self.total_samples, generator=g).tolist() + else: + idx_order = list(range(self.total_samples)) + + # Skip samples already consumed within the current epoch (mid-epoch resume). + idx_order = idx_order[current_epoch_samples:] + batch = [] # Last batch will be dropped if drop_last is True - for idx in range(self.consumed_samples % self.total_samples, self.total_samples): + for idx in idx_order: batch.append(idx) if len(batch) == self._global_batch_size: # Distribute indices in interleaved fashion across ranks @@ -293,6 +350,10 @@ def __iter__(self) -> Iterator[list[int]]: ] assert len(all_indices) == self._global_batch_size_on_this_data_parallel_rank + # Advance so the next re-iteration starts a fresh epoch (offset 0) rather + # than re-applying the resume offset and replaying the tail. + self.consumed_samples += self._global_batch_size + # Yield ALL indices at once (not split into microbatches) # The training loop will handle splitting after collation yield all_indices @@ -307,6 +368,11 @@ def __iter__(self) -> Iterator[list[int]]: num_pad = self._global_batch_size // self.data_parallel_size - len(all_indices) all_indices = all_indices + [-1] * num_pad + if self.pad_samples_to_global_batch_size: + self.consumed_samples += self._global_batch_size + else: + self.consumed_samples += len(batch) + # Yield ALL indices at once yield all_indices diff --git a/tests/functional_tests/test_groups/data/test_samplers.py b/tests/functional_tests/test_groups/data/test_samplers.py index 894a971d6f..dac35ab117 100644 --- a/tests/functional_tests/test_groups/data/test_samplers.py +++ b/tests/functional_tests/test_groups/data/test_samplers.py @@ -255,6 +255,7 @@ def test_batch_sampler_interleaved_distribution(self): data_parallel_rank=0, data_parallel_size=2, drop_last=True, + shuffle=False, ) # Simulate rank 1 @@ -266,6 +267,7 @@ def test_batch_sampler_interleaved_distribution(self): data_parallel_rank=1, data_parallel_size=2, drop_last=True, + shuffle=False, ) # Get indices from both ranks @@ -296,6 +298,7 @@ def test_batch_sampler_consumed_samples(self): data_parallel_rank=0, data_parallel_size=2, drop_last=True, + shuffle=False, ) batches = list(sampler) @@ -319,6 +322,7 @@ def test_batch_sampler_incomplete_batch_drop_last_true(self): data_parallel_rank=0, data_parallel_size=2, drop_last=True, + shuffle=False, ) batches = list(sampler) @@ -345,6 +349,7 @@ def test_batch_sampler_incomplete_batch_drop_last_false(self): data_parallel_size=2, drop_last=False, pad_samples_to_global_batch_size=False, + shuffle=False, ) batches = list(sampler) @@ -372,6 +377,7 @@ def test_batch_sampler_incomplete_batch_with_padding(self): data_parallel_size=2, drop_last=False, pad_samples_to_global_batch_size=True, + shuffle=False, ) batches = list(sampler) @@ -438,6 +444,274 @@ def test_batch_sampler_multiple_data_parallel_ranks(self): all_indices_sorted = sorted(all_indices) assert all_indices_sorted == list(range(32)) + def test_batch_sampler_resume_serves_full_epoch_after_first_cycle(self): + """Resuming mid-epoch must serve the remaining tail first, then full epochs. + + Regression for the resume bug where the sampler froze ``consumed_samples`` and + re-applied the resume offset on every ``cyclic_iter`` re-iteration, permanently + skipping the head of each post-resume epoch and replaying only the tail (which + repeatedly re-trained the model on the same samples and depressed the loss). + """ + from megatron.bridge.data.samplers import MegatronPretrainingBatchSampler + + # 32 samples, resume after consuming 8 (mid first epoch). global_batch_size=4. + sampler = MegatronPretrainingBatchSampler( + total_samples=32, + consumed_samples=8, + micro_batch_size=1, + global_batch_size=4, + data_parallel_rank=0, + data_parallel_size=1, + drop_last=True, + shuffle=False, + ) + + first_cycle = [idx for batch in sampler for idx in batch] + second_cycle = [idx for batch in sampler for idx in batch] + + # First (resumed) cycle serves only the remaining tail [8, 32). + assert sorted(first_cycle) == list(range(8, 32)) + # Second cycle is a full epoch starting from 0 — the head is no longer skipped. + assert sorted(second_cycle) == list(range(32)) + # And it is not a replay of the resumed tail. + assert second_cycle != first_cycle + + def test_batch_sampler_len_stays_epoch_aware_after_non_divisible_epoch(self): + """Length should use the same active epoch size as iteration. + + When ``total_samples`` is not divisible by ``global_batch_size``, the active + epoch for ``drop_last=True`` excludes the incomplete tail. After one full + active epoch, length should report the next full active epoch rather than the + stale tail from ``total_samples`` modulo arithmetic. + """ + from megatron.bridge.data.samplers import MegatronPretrainingBatchSampler + + sampler = MegatronPretrainingBatchSampler( + total_samples=10, + consumed_samples=0, + micro_batch_size=1, + global_batch_size=4, + data_parallel_rank=0, + data_parallel_size=1, + drop_last=True, + shuffle=False, + ) + + assert len(sampler) == 2 + assert [idx for batch in sampler for idx in batch] == list(range(8)) + assert len(sampler) == 2 + assert [idx for batch in sampler for idx in batch] == list(range(8)) + + def test_batch_sampler_drop_last_false_resume_serves_partial_then_full_epoch(self): + """Resuming before a padded partial batch should not skip that partial batch.""" + from megatron.bridge.data.samplers import MegatronPretrainingBatchSampler + + sampler = MegatronPretrainingBatchSampler( + total_samples=10, + consumed_samples=8, + micro_batch_size=1, + global_batch_size=4, + data_parallel_rank=0, + data_parallel_size=1, + drop_last=False, + pad_samples_to_global_batch_size=True, + shuffle=False, + ) + + first_cycle = [batch for batch in sampler] + second_cycle = [batch for batch in sampler] + + assert first_cycle == [[8, 9, -1, -1]] + assert second_cycle == [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, -1, -1]] + + def test_batch_sampler_drop_last_false_small_dataset_cycles_after_padding(self): + """A padded epoch smaller than one global batch should restart cleanly.""" + from megatron.bridge.data.samplers import MegatronPretrainingBatchSampler + + sampler = MegatronPretrainingBatchSampler( + total_samples=3, + consumed_samples=0, + micro_batch_size=1, + global_batch_size=4, + data_parallel_rank=0, + data_parallel_size=1, + drop_last=False, + pad_samples_to_global_batch_size=True, + shuffle=False, + ) + + assert [batch for batch in sampler] == [[0, 1, 2, -1]] + assert [batch for batch in sampler] == [[0, 1, 2, -1]] + + def test_batch_sampler_drop_last_false_unpadded_partial_advances_epoch(self): + """Unpadded direct users should not replay the partial tail on re-iteration.""" + from megatron.bridge.data.samplers import MegatronPretrainingBatchSampler + + def make(consumed): + return MegatronPretrainingBatchSampler( + total_samples=10, + consumed_samples=consumed, + micro_batch_size=1, + global_batch_size=4, + data_parallel_rank=0, + data_parallel_size=1, + drop_last=False, + pad_samples_to_global_batch_size=False, + shuffle=False, + ) + + sampler = make(0) + assert [batch for batch in sampler] == [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]] + assert [batch for batch in sampler] == [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]] + assert [batch for batch in make(10)] == [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]] + + def test_batch_sampler_epoch_boundary_resume_serves_full_epoch(self): + """Resuming exactly on an epoch boundary serves full epochs (offset 0).""" + from megatron.bridge.data.samplers import MegatronPretrainingBatchSampler + + sampler = MegatronPretrainingBatchSampler( + total_samples=32, + consumed_samples=32, # exactly one epoch consumed + micro_batch_size=1, + global_batch_size=4, + data_parallel_rank=0, + data_parallel_size=1, + drop_last=True, + ) + + assert sorted(idx for batch in sampler for idx in batch) == list(range(32)) + + def test_batch_sampler_shuffle_reshuffles_each_epoch(self): + """With shuffle enabled, each epoch is a distinct permutation of all samples.""" + from megatron.bridge.data.samplers import MegatronPretrainingBatchSampler + + sampler = MegatronPretrainingBatchSampler( + total_samples=16, + consumed_samples=0, + micro_batch_size=1, + global_batch_size=4, + data_parallel_rank=0, + data_parallel_size=1, + drop_last=True, + shuffle=True, + seed=1234, + ) + + epoch0 = [idx for batch in sampler for idx in batch] + epoch1 = [idx for batch in sampler for idx in batch] + + # Each epoch covers every sample exactly once... + assert sorted(epoch0) == list(range(16)) + assert sorted(epoch1) == list(range(16)) + # ...but the order differs between epochs (reshuffled). + assert epoch0 != epoch1 + + def test_batch_sampler_shuffle_is_deterministic_for_seed(self): + """A given (seed, epoch) yields a reproducible order (resume-safe).""" + from megatron.bridge.data.samplers import MegatronPretrainingBatchSampler + + def first_epoch(seed): + sampler = MegatronPretrainingBatchSampler( + total_samples=16, + consumed_samples=0, + micro_batch_size=1, + global_batch_size=4, + data_parallel_rank=0, + data_parallel_size=1, + drop_last=True, + shuffle=True, + seed=seed, + ) + return [idx for batch in sampler for idx in batch] + + assert first_epoch(1234) == first_epoch(1234) + assert first_epoch(1234) != first_epoch(5678) + + def test_batch_sampler_shuffle_resume_parity(self): + """A shuffled resume reproduces the same per-epoch order as an uninterrupted run.""" + from megatron.bridge.data.samplers import MegatronPretrainingBatchSampler + + def make(consumed): + return MegatronPretrainingBatchSampler( + total_samples=16, + consumed_samples=consumed, + micro_batch_size=1, + global_batch_size=4, + data_parallel_rank=0, + data_parallel_size=1, + drop_last=True, + shuffle=True, + seed=1234, + ) + + # Uninterrupted: collect epoch 0 then epoch 1. + uninterrupted = make(0) + _ = [idx for batch in uninterrupted for idx in batch] # epoch 0 + epoch1 = [idx for batch in uninterrupted for idx in batch] # epoch 1 + + # Resume after consuming all of epoch 0 (16) plus 8 of epoch 1. + resumed = make(24) + resumed_indices = [idx for batch in resumed for idx in batch] + + # The resumed run serves exactly the not-yet-consumed tail of epoch 1. + assert resumed_indices == epoch1[8:] + + def test_batch_sampler_shuffle_resume_parity_non_divisible_drop_last(self): + """Shuffle resume parity should use the active full-batch epoch size.""" + from megatron.bridge.data.samplers import MegatronPretrainingBatchSampler + + def make(consumed): + return MegatronPretrainingBatchSampler( + total_samples=10, + consumed_samples=consumed, + micro_batch_size=1, + global_batch_size=4, + data_parallel_rank=0, + data_parallel_size=1, + drop_last=True, + shuffle=True, + seed=1234, + ) + + uninterrupted = make(0) + epoch0 = [idx for batch in uninterrupted for idx in batch] + epoch1 = [idx for batch in uninterrupted for idx in batch] + + assert len(epoch0) == 8 + assert sorted(epoch0) == sorted(set(epoch0)) + assert len(epoch1) == 8 + assert sorted(epoch1) == sorted(set(epoch1)) + + resumed_epoch0 = [idx for batch in make(4) for idx in batch] + resumed_epoch1 = [idx for batch in make(12) for idx in batch] + + assert resumed_epoch0 == epoch0[4:] + assert resumed_epoch1 == epoch1[4:] + + def test_batch_sampler_shuffle_enabled_by_default_uses_global_seed(self): + """Default batch-sampler behavior reshuffles deterministically from the global seed.""" + import torch + + from megatron.bridge.data.samplers import MegatronPretrainingBatchSampler + + def first_epoch(): + torch.manual_seed(1234) + sampler = MegatronPretrainingBatchSampler( + total_samples=16, + consumed_samples=0, + micro_batch_size=1, + global_batch_size=4, + data_parallel_rank=0, + data_parallel_size=1, + drop_last=True, + ) + return [idx for batch in sampler for idx in batch] + + epoch = first_epoch() + assert epoch == first_epoch() + assert sorted(epoch) == list(range(16)) + assert epoch != list(range(16)) + class TestBatchUtilities: """Tests for batch handling utilities."""