diff --git a/tests/distributed/_chunked_nll_allgather_worker.py b/tests/distributed/_chunked_nll_allgather_worker.py new file mode 100644 index 00000000000..8dac7594e25 --- /dev/null +++ b/tests/distributed/_chunked_nll_allgather_worker.py @@ -0,0 +1,187 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Companion worker for the chunked-NLL FSDP2 all-gather perf-regression test. + +Launched under ``accelerate launch --config_file `` by ``test_distributed.py``. It runs a single SFT +``chunked_nll`` training step on an FSDP2-sharded tiny model and counts how many all-gather collectives occur during +that step, so a regression that re-gathers ``lm_head.weight`` once per token chunk (the PR #6077 failure mode — correct +loss, silently slow) is caught by a bounded assertion. + +``_chunked_cross_entropy_loss`` chunks over *valid tokens* (``for start in range(0, n_valid, chunk_size)``), so the +regression scales with ``ceil(n_valid / chunk_size)`` and only shows up when more than one token chunk runs. The zen +test data is tiny, so this worker shrinks the chunk size (see ``_TEST_CHUNK_SIZE``) to force many token chunks, and +derives the regression threshold from the exact ``n_valid`` captured from inside the loss path — never from vocab size. + +Why count real collectives and not ``DTensor.full_tensor()``: under FSDP2 the parameter unshard is driven by autograd +pre-hooks / c10d collectives, not by explicit ``full_tensor()`` calls, so a ``full_tensor`` counter is blind to it. We +use ``CommDebugMode`` (``torch.distributed.tensor.debug``) — torch's purpose-built, DTensor-native comm counter, which +records ``funcol.all_gather_into_tensor`` and the ``c10d`` ``_allgather_base_`` / ``allgather_`` variants that FSDP2 +emits. (An earlier version also ran a hand-rolled ``TorchDispatchMode``, but re-dispatching sharded ops from a custom +mode under FSDP2 mismatches the index/weight devices on the embedding lookup, so we rely on ``CommDebugMode`` alone.) + +Prints one machine-parseable line ``CHUNKED_NLL_ALLGATHER_RESULT {json}`` that the pytest side asserts on. +Self-contained on purpose: it imports only public TRL symbols and runs as ``__main__`` under ``accelerate launch``. +""" + +from __future__ import annotations + +import json +import math +import tempfile + +from datasets import load_dataset + +from trl import SFTConfig, SFTTrainer + + +MODEL_ID = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" +RESULT_PREFIX = "CHUNKED_NLL_ALLGATHER_RESULT" + + +def _count_all_gathers(comm_counts: dict) -> int: + """Sum the all-gather collectives from a ``CommDebugMode.get_comm_counts()`` dict. + + ``CommDebugMode`` keys the dict by the comm op (funcol / c10d). FSDP2's parameter unshard shows up as + ``funcol.all_gather_into_tensor`` (and the ``_allgather_base_`` / ``allgather_`` c10d variants), so we match on the + op's string name containing ``all_gather`` / ``allgather`` and total those. This is the DTensor-native counter — it + observes the autograd-hook-driven gathers that ``DTensor.full_tensor()`` is blind to, without a hand-rolled + ``TorchDispatchMode`` (which mis-dispatches sharded ops under FSDP2). + """ + total = 0 + for op, n in comm_counts.items(): + name = str(op).lower() + if "all_gather" in name or "allgather" in name: + total += int(n) + return total + + +class _MeasuringSFTTrainer(SFTTrainer): + """SFTTrainer that counts all-gather collectives during its first ``training_step``. + + The measurement must run *inside* the real ``trainer.train()`` loop, not by calling ``training_step`` directly: + under ``fsdp_cpu_ram_efficient_loading`` the model is on CPU/meta until ``_inner_training_loop`` FSDP-wraps it and + moves it to GPU. Calling ``training_step`` on ``trainer.model`` before ``train()`` runs the embedding lookup with a + CPU weight against a CUDA input → device-mismatch crash. Overriding ``training_step`` lets the trainer do all + wrapping/placement, while we wrap the (single, since ``max_steps=1``) step in ``CommDebugMode`` to tally the FSDP2 + unshard collectives. + """ + + comm_counts: dict | None = None + + def training_step(self, *args, **kwargs): + from torch.distributed.tensor.debug import CommDebugMode + + # Only measure the first step (with max_steps=1 there is exactly one); guard anyway so the counts + # reflect a single step even if the caller raises max_steps later. + if self.comm_counts is not None: + return super().training_step(*args, **kwargs) + comm_mode = CommDebugMode() + with comm_mode: + loss = super().training_step(*args, **kwargs) + self.comm_counts = comm_mode.get_comm_counts() + return loss + + +# The chunked-CE loop chunks over *valid tokens*, not vocab: `for start in range(0, n_valid, chunk_size)` +# in `_chunked_cross_entropy_loss`. So a per-chunk `lm_head.weight` re-gather regression scales with +# ceil(n_valid / chunk_size) — the TOKEN-chunk count — and is only observable when more than one chunk +# runs (n_valid > chunk_size). The zen test data is tiny (~120 valid tokens total), so with the default +# chunk size of 256 only a single chunk would run and a regression would be invisible. We therefore shrink +# the chunk size for this test so the tiny batch genuinely exercises many token-chunks. +_TEST_CHUNK_SIZE = 4 + + +def main() -> None: + import trl.trainer.sft_trainer as sft + + # Shrink the chunk size BEFORE the trainer patches the lm_head (it reads this module constant at + # construction). With ~120 valid tokens this yields ~30 token-chunks, so a per-chunk re-gather + # regression would do ~30 lm_head all-gathers vs O(1) for the fixed path — a wide, detectable margin. + sft._CHUNKED_LM_HEAD_CHUNK_SIZE = _TEST_CHUNK_SIZE + + # Capture the real valid-token count from inside the chunked-CE path, so the regression threshold is + # derived from the exact n_valid the loop iterates over (never guessed from token lengths). + captured = {} + _orig_cce = sft._chunked_cross_entropy_loss + + def _capturing_cce(hidden_states, lm_head_weight, chunk_size, *args, **kwargs): + out = _orig_cce(hidden_states, lm_head_weight, chunk_size, *args, **kwargs) + # Returns (loss, correct, entropy_sum, n_valid_tensor); n_valid is the 4th element. + captured["n_valid"] = int(out[3].item()) + captured["chunk_size"] = int(chunk_size) + return out + + sft._chunked_cross_entropy_loss = _capturing_cce + + dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train") + # Write trainer artifacts to a throwaway temp dir so the worker leaves no state in the repo checkout and + # repeated runs can't collide. tempfile keeps this self-contained (no reliance on the launch cwd). + tmp_out = tempfile.mkdtemp(prefix="chunked_nll_fsdp2_") + args = SFTConfig( + output_dir=tmp_out, + loss_type="chunked_nll", + # Pack as many of the tiny examples into the single measured step as possible, so n_valid is well + # above the (shrunk) chunk size and the token-chunk count is large. + per_device_train_batch_size=8, + max_length=64, + max_steps=1, + report_to="none", + bf16=True, + ) + trainer = _MeasuringSFTTrainer(model=MODEL_ID, args=args, train_dataset=dataset) + + vocab_size = trainer.model.config.vocab_size + + # Run the real training loop: it FSDP-wraps the model and moves it to GPU, then calls training_step + # once (max_steps=1), which our subclass measures under CommDebugMode. + trainer.train() + + comm_counts = trainer.comm_counts or {} + all_gathers = _count_all_gathers(comm_counts) + comm_total = sum(int(n) for n in comm_counts.values()) + + n_valid = captured.get("n_valid", 0) + chunk_size = captured.get("chunk_size", _TEST_CHUNK_SIZE) + n_chunks = -(-n_valid // chunk_size) if n_valid else 0 # ceil(n_valid / chunk_size) — TOKEN chunks + + last = trainer.state.log_history[-1] if trainer.state.log_history else {} + train_loss = last.get("train_loss") + + result = { + "vocab_size": int(vocab_size), + "n_valid": int(n_valid), + "chunk_size": int(chunk_size), + "n_chunks_if_regressed": int(n_chunks), + "all_gathers": int(all_gathers), + "commdebug_total": int(comm_total), + "loss_finite": train_loss is not None and math.isfinite(train_loss), + } + if trainer.accelerator.is_main_process: + print(f"{RESULT_PREFIX} {json.dumps(result)}", flush=True) # noqa: T201 - result channel for the launcher + + +if __name__ == "__main__": + # Print the full traceback from this worker directly: when `accelerate launch` re-raises a child + # failure, the parent only sees a truncated `CompletedProcess` repr, which hides the real error frame. + # Surfacing it here puts the complete traceback in the worker's own stderr (and thus the CI log). + import sys + import traceback + + try: + main() + except Exception: + traceback.print_exc() + sys.stderr.flush() + raise diff --git a/tests/distributed/data/accelerate_configs/fsdp2_reshard.yaml b/tests/distributed/data/accelerate_configs/fsdp2_reshard.yaml new file mode 100644 index 00000000000..13a31f2cf2b --- /dev/null +++ b/tests/distributed/data/accelerate_configs/fsdp2_reshard.yaml @@ -0,0 +1,31 @@ +# 2-process FSDP2 config for the chunked-NLL all-gather perf-regression test. +# +# `fsdp_reshard_after_forward: true` is the load-bearing setting: it is the condition under which PR +# #6077's per-chunk `lm_head.weight` all-gather manifested (the param is resharded after forward, then +# re-gathered during the chunked backward). The minimal `fsdp2.yaml` leaves this at FSDP's default, so a +# dedicated config that states it explicitly keeps the collective count meaningful. Mirrors +# `examples/accelerate_configs/fsdp2.yaml` with `num_processes: 2` for a 2-GPU node. +compute_environment: LOCAL_MACHINE +debug: false +distributed_type: FSDP +downcast_bf16: 'no' +enable_cpu_affinity: false +fsdp_config: + fsdp_activation_checkpointing: false + fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP + fsdp_cpu_ram_efficient_loading: true + fsdp_offload_params: false + fsdp_reshard_after_forward: true + fsdp_state_dict_type: FULL_STATE_DICT + fsdp_version: 2 +machine_rank: 0 +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: 2 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/tests/distributed/test_distributed.py b/tests/distributed/test_distributed.py index 2e99bf0f4ea..c5caa9a6142 100644 --- a/tests/distributed/test_distributed.py +++ b/tests/distributed/test_distributed.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import json import os import subprocess from pathlib import Path @@ -376,3 +377,61 @@ def test_grpo_liger(self, config, get_config_path): os.environ.copy(), ) # fmt: on + + def test_sft_chunked_nll_fsdp2_no_per_chunk_allgather(self, lazy_shared_datadir): + # Perf-regression guard for the PR #6077 class: a chunked cross-entropy path must NOT re-gather the + # sharded `lm_head.weight` once per token chunk under FSDP2 (correct loss, silently slow, invisible + # to a pass/fail test). The companion worker runs one SFT `chunked_nll` step under a 2-process FSDP2 + # group (reshard_after_forward=True — the condition that triggers the bug) and counts the all-gather + # collectives during that step; here we assert the count stays O(1), not O(n_valid / chunk_size). + # + # `_chunked_cross_entropy_loss` chunks over VALID TOKENS, not vocab, so the regression scales with + # ceil(n_valid / chunk_size) and only manifests when more than one token chunk runs. The worker + # shrinks the chunk size so the tiny zen batch exercises many token chunks, and reports the exact + # n_valid / chunk_size it measured so this side can both bound the count and confirm the test is + # non-vacuous (n_chunks_if_regressed > 1 — otherwise a regression could never have been observed). + # + # Counting real collectives is required: under FSDP2 the parameter unshard is driven by autograd + # hooks / c10d collectives, not by `DTensor.full_tensor()`, so the worker counts the actual + # all-gather collectives via CommDebugMode (torch's DTensor-native comm counter) for the step. + worker = Path(__file__).parent / "_chunked_nll_allgather_worker.py" + config_path = lazy_shared_datadir / "accelerate_configs" / "fsdp2_reshard.yaml" + # Pin the repo root onto PYTHONPATH for the child: `accelerate launch` re-execs each rank via + # torch.distributed.elastic, which sets sys.path[0] to the launched script's directory, not cwd. + # Without this, a non-editable `trl` already in site-packages would shadow the working tree. + env = os.environ.copy() + env["PYTHONPATH"] = os.pathsep.join([str(ROOT), env.get("PYTHONPATH", "")]).rstrip(os.pathsep) + result = subprocess.run( + ["accelerate", "launch", "--config_file", str(config_path), str(worker)], + env=env, + cwd=ROOT, + capture_output=True, + text=True, + ) + assert result.returncode == 0, f"worker failed:\nSTDOUT:\n{result.stdout}\nSTDERR:\n{result.stderr}" + + prefix = "CHUNKED_NLL_ALLGATHER_RESULT" + lines = [ln for ln in result.stdout.splitlines() if ln.startswith(prefix)] + assert len(lines) == 1, f"expected exactly one result line, got {lines}\n{result.stdout}" + measured = json.loads(lines[0][len(prefix) :].strip()) + + assert measured["loss_finite"], f"chunked_nll loss not finite under FSDP2: {measured}" + # Non-vacuity guard (the heart of this test): a per-token-chunk regression can only be detected if the + # step actually ran multiple token chunks. If only one chunk ran, a regression would gather exactly + # once too, so the test would pass for the wrong reason. Require a comfortably multi-chunk run. + assert measured["n_chunks_if_regressed"] > 4, ( + f"test is vacuous — only {measured['n_chunks_if_regressed']} token chunk(s) ran, so a per-chunk " + f"regression could not be observed; increase batch/length or shrink chunk_size: {measured}" + ) + # A per-token-chunk-regather regression would do ~n_chunks all-gathers of lm_head.weight in the step; + # the fixed path does O(1). `all_gathers` is the total all-gather collective count for the step, + # measured by CommDebugMode (the DTensor-native counter that sees FSDP2's autograd-hook-driven + # gathers). It legitimately includes one gather per sharded parameter (a handful of decoder layers), + # so bound it well below the regression count rather than at exactly 1. The ceiling scales off + # n_chunks (never a hardcoded collective count) so it tracks the model's token/chunk arithmetic. + observed = measured["all_gathers"] + ceiling = max(16, measured["n_chunks_if_regressed"] // 4) + assert observed < measured["n_chunks_if_regressed"], ( + f"per-chunk lm_head.weight all-gathers detected (#6077 regression): {measured}" + ) + assert observed <= ceiling, f"unexpectedly many all-gathers (possible regression): {measured}"