test: guard against per-chunk lm_head all-gather in chunked_nll under…#6172
test: guard against per-chunk lm_head all-gather in chunked_nll under…#6172behroozazarkhalili wants to merge 2 commits into
Conversation
… FSDP2 The chunked cross-entropy path (`loss_type="chunked_nll"`) computes the LM loss in vocab chunks. Under FSDP2 with reshard_after_forward, a regression can re-gather the sharded `lm_head.weight` once per vocab chunk during the backward — the loss stays correct but each step does O(vocab/chunk_size) extra all-gather collectives instead of O(1). This is invisible to a pass/fail test (the numbers are identical) and was the failure mode fixed in PR #6077. Add a 2-GPU distributed test that runs one chunked_nll step under FSDP2 and asserts the all-gather collective count stays bounded. The count is measured with CommDebugMode (torch's DTensor-native comm counter), which observes the autograd-hook-driven gathers that `DTensor.full_tensor()` cannot. The bound is derived from the model's vocab/chunk arithmetic, never hardcoded, so it tracks any model. Skips unless two accelerators are present.
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
There was a problem hiding this comment.
Pull request overview
Adds a 2-GPU distributed regression test to detect performance regressions where loss_type="chunked_nll" under FSDP2 (reshard_after_forward=True) re-introduces per-vocab-chunk lm_head.weight all-gathers (the #6077 failure mode), by measuring and bounding the number of all-gather collectives during a real trainer.train() step.
Changes:
- Adds a new distributed pytest that launches an
accelerate2-process job and asserts the measured all-gather count stays O(1) (vs. O(vocab/chunk_size) under regression). - Adds a companion worker script that runs a single SFT step under
CommDebugModeand prints a machine-parseable JSON result line. - Adds a dedicated
accelerateconfig explicitly enablingfsdp_reshard_after_forward: truefor the regression scenario.
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
tests/distributed/test_distributed.py |
Adds a new distributed regression test that launches a worker and asserts all-gather collectives stay bounded. |
tests/distributed/data/accelerate_configs/fsdp2_reshard.yaml |
Introduces an explicit 2-process FSDP2 config with fsdp_reshard_after_forward: true for the perf regression guard. |
tests/distributed/_chunked_nll_allgather_worker.py |
Adds a standalone worker that measures all-gather collectives during a single SFT chunked_nll step via CommDebugMode. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| mode under FSDP2 mismatches the index/weight devices on the embedding lookup, so we rely on ``CommDebugMode`` alone.) | ||
|
|
||
| Prints one machine-parseable line ``CHUNKED_NLL_ALLGATHER_RESULT {json}`` that the pytest side asserts on. | ||
| Self-contained (mirrors ``tests/experimental/_async_grpo_fsdp2_worker.py``): imports only public symbols. |
| dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train") | ||
| args = SFTConfig( | ||
| output_dir="chunked_nll_fsdp2_out", | ||
| loss_type="chunked_nll", | ||
| per_device_train_batch_size=2, | ||
| max_length=64, | ||
| max_steps=1, | ||
| report_to="none", | ||
| bf16=True, | ||
| ) |
|
Thanks, do you have a repro? |
Address review feedback on the chunked_nll FSDP2 all-gather guard: - Chunk over valid tokens, not vocab (Cursor Bugbot): `_chunked_cross_entropy_loss` iterates `range(0, n_valid, chunk_size)`, so a per-chunk re-gather regression scales with ceil(n_valid / chunk_size), not vocab. The worker now shrinks the chunk size and derives the regression threshold from the exact `n_valid` it measures, and the test asserts the run is non-vacuous (>1 token chunk) so a regression could actually be observed. - Drop the stale cross-file docstring reference and write trainer artifacts to a tempdir instead of the repo checkout (Copilot). - Print the worker's own full traceback on failure: `accelerate launch` only re-raises a truncated CompletedProcess repr, hiding the real error frame.
There was a problem hiding this comment.
Cursor Bugbot has reviewed your changes using default effort and found 1 potential issue.
❌ Bugbot Autofix is OFF. To automatically fix reported issues with cloud agents, enable autofix in the Cursor dashboard.
Reviewed by Cursor Bugbot for commit 20841b4. Configure here.
| ceiling = max(16, measured["n_chunks_if_regressed"] // 4) | ||
| assert observed < measured["n_chunks_if_regressed"], ( | ||
| f"per-chunk lm_head.weight all-gathers detected (#6077 regression): {measured}" | ||
| ) |
There was a problem hiding this comment.
Vacuity bound below gather baseline
Medium Severity
The regression check requires all_gathers to be strictly less than n_chunks_if_regressed, but the non-vacuity guard only requires more than four token chunks. On the fixed FSDP2 path, all_gathers stays roughly O(1) per sharded parameter (about ten in the PR’s run), so if n_chunks_if_regressed falls between that baseline and about eleven, the test fails even without a per-chunk regression.
Reviewed by Cursor Bugbot for commit 20841b4. Configure here.
|
Thanks for the patience here @qgallouedec — you asked for a repro, and building one surfaced two real issues with this test that I want to lay out honestly before we decide its fate. 1. The CI failure is a torch 2.12
|


What does this PR do?
Adds a 2-GPU distributed regression test that guards the
chunked_nllcross-entropy path against re-introducing a per-chunklm_head.weightall-gather under FSDP2 — the failure mode fixed in #6077.Why
loss_type="chunked_nll"computes the LM loss in vocab chunks. Under FSDP2 withreshard_after_forward, a regression can re-gather the shardedlm_head.weightonce per vocab chunk during the backward. The loss stays numerically correct, but each step performsO(vocab / chunk_size)extra all-gather collectives instead ofO(1).This is exactly the kind of bug a pass/fail test cannot catch — the produced numbers are identical whether the weight is gathered once or hundreds of times; only the collective count differs. So the regression is silent: correct results, quietly degraded throughput.
How
The test launches a companion worker under a 2-process FSDP2 group (
reshard_after_forward=True— the condition that triggers the bug), runs a singlechunked_nllstep, and asserts the all-gather collective count stays bounded.A few design points worth calling out:
full_tensor(). Under FSDP2 the parameter unshard is driven by autograd hooks / c10d collectives, not by explicitDTensor.full_tensor()calls, so afull_tensor-counting approach is blind to it. The count is measured withCommDebugMode(torch's DTensor-native comm counter), which recordsfuncol.all_gather_into_tensorand the c10d_allgather_base_/allgather_variants FSDP2 emits.trainer.train()(via atraining_stepoverride), because underfsdp_cpu_ram_efficient_loadingthe model is only FSDP-wrapped and moved to GPU inside the training loop — callingtraining_stepdirectly would run the embedding lookup with a CPU weight against a CUDA input.vocab_size / chunk_sizearithmetic, so it tracks any model.@require_torch_multi_accelerator).Verification
Run on 2× H100 (
trl-internal-testing/tiny-Qwen2ForCausalLM-2.5, vocab 151665, chunk 256):The fixed path does 10 all-gathers (O(1) — roughly one per sharded parameter); a per-chunk regression would do ~593 (= ⌈vocab/chunk⌉). The ~59× gap gives the assertion clear discriminating power.
Related to #6077 (the chunked-CE FSDP2 fix this test guards against re-breaking).
Before submitting
AI writing disclosure
We welcome the use of AI tools to help with contributions. For transparency and to help us improve our review process, please indicate the level of AI involvement in this PR.
Who can review?
cc @qgallouedec — this follows up on the chunked-CE FSDP2 work in #6077; happy to adjust the bound or the config if you'd prefer a different shape.
Note
Low Risk
Test-only change with no production or training API modifications; CI cost is an extra multi-GPU subprocess when two accelerators are available.
Overview
Adds a 2-GPU distributed regression test so
loss_type="chunked_nll"under FSDP2 withreshard_after_forwarddoes not re-introduce extralm_head.weightall-gathers per token chunk (the silent perf bug addressed in #6077).A new
_chunked_nll_allgather_worker.pyruns one SFT step viaaccelerate launch, forces many token chunks by lowering_CHUNKED_LM_HEAD_CHUNK_SIZE, and tallies real all-gather collectives withCommDebugModeinside the firsttraining_stepoftrainer.train(). It emits JSON onCHUNKED_NLL_ALLGATHER_RESULTfor pytest to parse.test_sft_chunked_nll_fsdp2_no_per_chunk_allgatherlaunches that worker withfsdp2_reshard.yaml, checks finite loss, requires a non-vacuous multi-chunk run, and asserts all-gather count stays well belown_chunks_if_regressed(not O(n_valid/chunk_size)).PYTHONPATHis set so the child uses the repo’s TRL checkout.Reviewed by Cursor Bugbot for commit 20841b4. Bugbot is set up for automated code reviews on this repo. Configure here.