test: guard against per-chunk lm_head all-gather in chunked_nll under… #6172

Draft
behroozazarkhalili wants to merge 2 commits into main from test/chunked-nll-fsdp2-allgather-guard

Conversation

@behroozazarkhalili

@behroozazarkhalili behroozazarkhalili commented Jun 24, 2026

Collaborator

What does this PR do?

Adds a 2-GPU distributed regression test that guards the chunked_nll cross-entropy path against re-introducing a per-chunk lm_head.weight all-gather under FSDP2 — the failure mode fixed in #6077.

Why

loss_type="chunked_nll" computes the LM loss in vocab chunks. Under FSDP2 with reshard_after_forward, a regression can re-gather the sharded lm_head.weight once per vocab chunk during the backward. The loss stays numerically correct, but each step performs O(vocab / chunk_size) extra all-gather collectives instead of O(1).

This is exactly the kind of bug a pass/fail test cannot catch — the produced numbers are identical whether the weight is gathered once or hundreds of times; only the collective count differs. So the regression is silent: correct results, quietly degraded throughput.

How

The test launches a companion worker under a 2-process FSDP2 group (reshard_after_forward=True — the condition that triggers the bug), runs a single chunked_nll step, and asserts the all-gather collective count stays bounded.
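For readers who want to picture the launch step, here is a minimal sketch of how the pytest side could assemble the `accelerate launch` invocation. The helper name `build_launch` and the `repo_root` argument are hypothetical; only `accelerate launch`, its `--config_file` flag, and the `PYTHONPATH` override come from this PR and accelerate itself.

```python
import os
from pathlib import Path


def build_launch(repo_root: Path, config: Path, worker: Path) -> tuple[list[str], dict[str, str]]:
    """Return (argv, env) for running the worker under `accelerate launch`.

    Hypothetical helper: the real test may assemble its command differently.
    """
    argv = ["accelerate", "launch", "--config_file", str(config), str(worker)]
    env = dict(os.environ)
    # Prepend the checkout so the child imports this repo's TRL,
    # not whatever copy happens to be installed in the environment.
    existing = env.get("PYTHONPATH")
    env["PYTHONPATH"] = (
        str(repo_root) if not existing else os.pathsep.join([str(repo_root), existing])
    )
    return argv, env
```

The pair would then typically be handed to `subprocess.run(argv, env=env, capture_output=True, text=True)` so the parent can inspect the worker's stdout.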

A few design points worth calling out:

  • Real collectives, not full_tensor(). Under FSDP2 the parameter unshard is driven by autograd hooks / c10d collectives, not by explicit DTensor.full_tensor() calls, so a full_tensor-counting approach is blind to it. The count is measured with CommDebugMode (torch's DTensor-native comm counter), which records funcol.all_gather_into_tensor and the c10d _allgather_base_ / allgather_ variants FSDP2 emits.
  • Measured inside the real training loop. The worker measures from within trainer.train() (via a training_step override), because under fsdp_cpu_ram_efficient_loading the model is only FSDP-wrapped and moved to GPU inside the training loop — calling training_step directly would run the embedding lookup with a CPU weight against a CUDA input.
  • No hardcoded thresholds. The bound is derived from the model's vocab_size / chunk_size arithmetic, so it tracks any model.
  • Skips unless two accelerators are present (@require_torch_multi_accelerator).
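As a rough illustration of the tallying step, the sketch below sums all-gather entries from a per-op count mapping, the shape that `CommDebugMode.get_comm_counts()` is understood to return. The substring matching and the op names in the example are assumptions made for illustration, and the counts are made up rather than measured.

```python
from collections.abc import Mapping

# Substrings treated as all-gather ops. Matching on the printed op name is an
# assumption based on the variants named in the list above.
ALL_GATHER_MARKERS = ("all_gather", "allgather")


def count_all_gathers(comm_counts: Mapping[object, int]) -> int:
    """Sum the counts of every op whose name looks like an all-gather."""
    return sum(
        n
        for op, n in comm_counts.items()
        if any(marker in str(op) for marker in ALL_GATHER_MARKERS)
    )


# Illustrative input only: these names and counts are invented, not measured.
example = {
    "c10d_functional.all_gather_into_tensor": 4,
    "c10d._allgather_base_": 2,
    "c10d.allreduce_": 1,
    "c10d._reduce_scatter_base_": 3,
}
print(count_all_gathers(example))  # 6
```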

Verification

Run on 2× H100 (trl-internal-testing/tiny-Qwen2ForCausalLM-2.5, vocab 151665, chunk 256):

CHUNKED_NLL_ALLGATHER_RESULT
{"vocab_size": 151665, "chunk_size": 256, "n_chunks_if_regressed": 593,
 "all_gathers": 10, "commdebug_total": 13, "loss_finite": true}

tests/distributed/test_distributed.py::TestDistributed::test_sft_chunked_nll_fsdp2_no_per_chunk_allgather PASSED

The fixed path does 10 all-gathers (O(1) — roughly one per sharded parameter); a per-chunk regression would do ~593 (= ⌈vocab/chunk⌉). The ~59× gap gives the assertion clear discriminating power.
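The arithmetic behind those two figures can be checked directly from the values in the result line above:

```python
import math

# Values copied from the CHUNKED_NLL_ALLGATHER_RESULT line above.
vocab_size, chunk_size, observed = 151665, 256, 10

n_chunks_if_regressed = math.ceil(vocab_size / chunk_size)
ratio = n_chunks_if_regressed / observed
print(n_chunks_if_regressed, round(ratio, 1))  # 593 59.3
```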

Related to #6077 (the chunked-CE FSDP2 fix this test guards against re-breaking).

Before submitting

AI writing disclosure

We welcome the use of AI tools to help with contributions. For transparency and to help us improve our review process, please indicate the level of AI involvement in this PR.

  • No AI usage: the PR was written entirely by a human.
  • AI-assisted: some parts were suggested or improved by AI, but the PR was written and reviewed by a human.
  • AI-generated: the PR was mostly or fully generated by an AI tool.

Who can review?

cc @qgallouedec — this follows up on the chunked-CE FSDP2 work in #6077; happy to adjust the bound or the config if you'd prefer a different shape.


Note

Low Risk
Test-only change with no production or training API modifications; CI cost is an extra multi-GPU subprocess when two accelerators are available.

Overview
Adds a 2-GPU distributed regression test so loss_type="chunked_nll" under FSDP2 with reshard_after_forward does not re-introduce extra lm_head.weight all-gathers per token chunk (the silent perf bug addressed in #6077).

A new _chunked_nll_allgather_worker.py runs one SFT step via accelerate launch, forces many token chunks by lowering _CHUNKED_LM_HEAD_CHUNK_SIZE, and tallies real all-gather collectives with CommDebugMode inside the first training_step of trainer.train(). It emits JSON on CHUNKED_NLL_ALLGATHER_RESULT for pytest to parse.

test_sft_chunked_nll_fsdp2_no_per_chunk_allgather launches that worker with fsdp2_reshard.yaml, checks finite loss, requires a non-vacuous multi-chunk run, and asserts all-gather count stays well below n_chunks_if_regressed (not O(n_valid/chunk_size)). PYTHONPATH is set so the child uses the repo’s TRL checkout.
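A minimal sketch of how the pytest side might recover that JSON payload from the worker's stdout. The function name `parse_result` is hypothetical, and taking the last matching line is an assumption about how multi-rank output should be handled.

```python
import json

MARKER = "CHUNKED_NLL_ALLGATHER_RESULT"


def parse_result(stdout: str) -> dict:
    """Return the JSON payload from the last line that starts with MARKER."""
    for line in reversed(stdout.splitlines()):
        line = line.strip()
        if line.startswith(MARKER):
            # json.loads tolerates the whitespace left after the marker.
            return json.loads(line[len(MARKER):])
    raise ValueError(f"no {MARKER} line found in worker output")
```

Raising when the marker is absent keeps a crashed worker from being mistaken for a passing run.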

Reviewed by Cursor Bugbot for commit 20841b4. Bugbot is set up for automated code reviews on this repo. Configure here.

… FSDP2

The chunked cross-entropy path (`loss_type="chunked_nll"`) computes the LM
loss in vocab chunks. Under FSDP2 with reshard_after_forward, a regression
can re-gather the sharded `lm_head.weight` once per vocab chunk during the
backward — the loss stays correct but each step does O(vocab/chunk_size)
extra all-gather collectives instead of O(1). This is invisible to a
pass/fail test (the numbers are identical) and was the failure mode fixed
in PR #6077.

Add a 2-GPU distributed test that runs one chunked_nll step under FSDP2 and
asserts the all-gather collective count stays bounded. The count is measured
with CommDebugMode (torch's DTensor-native comm counter), which observes the
autograd-hook-driven gathers that `DTensor.full_tensor()` cannot. The bound
is derived from the model's vocab/chunk arithmetic, never hardcoded, so it
tracks any model. Skips unless two accelerators are present.
Comment thread tests/distributed/_chunked_nll_allgather_worker.py Outdated
@bot-ci-comment


The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Copilot AI left a comment

Contributor

Pull request overview

Adds a 2-GPU distributed regression test to detect performance regressions where loss_type="chunked_nll" under FSDP2 (reshard_after_forward=True) re-introduces per-vocab-chunk lm_head.weight all-gathers (the #6077 failure mode), by measuring and bounding the number of all-gather collectives during a real trainer.train() step.

Changes:

  • Adds a new distributed pytest that launches an accelerate 2-process job and asserts the measured all-gather count stays O(1) (vs. O(vocab/chunk_size) under regression).
  • Adds a companion worker script that runs a single SFT step under CommDebugMode and prints a machine-parseable JSON result line.
  • Adds a dedicated accelerate config explicitly enabling fsdp_reshard_after_forward: true for the regression scenario.

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 2 comments.

| File | Description |
| --- | --- |
| tests/distributed/test_distributed.py | Adds a new distributed regression test that launches a worker and asserts all-gather collectives stay bounded. |
| tests/distributed/data/accelerate_configs/fsdp2_reshard.yaml | Introduces an explicit 2-process FSDP2 config with fsdp_reshard_after_forward: true for the perf regression guard. |
| tests/distributed/_chunked_nll_allgather_worker.py | Adds a standalone worker that measures all-gather collectives during a single SFT chunked_nll step via CommDebugMode. |

mode under FSDP2 mismatches the index/weight devices on the embedding lookup, so we rely on ``CommDebugMode`` alone.)

Prints one machine-parseable line ``CHUNKED_NLL_ALLGATHER_RESULT {json}`` that the pytest side asserts on.
Self-contained (mirrors ``tests/experimental/_async_grpo_fsdp2_worker.py``): imports only public symbols.
Comment on lines +92 to +101
dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")
args = SFTConfig(
    output_dir="chunked_nll_fsdp2_out",
    loss_type="chunked_nll",
    per_device_train_batch_size=2,
    max_length=64,
    max_steps=1,
    report_to="none",
    bf16=True,
)
@qgallouedec

Member

Thanks, do you have a repro?

Address review feedback on the chunked_nll FSDP2 all-gather guard:

- Chunk over valid tokens, not vocab (Cursor Bugbot): `_chunked_cross_entropy_loss`
  iterates `range(0, n_valid, chunk_size)`, so a per-chunk re-gather regression
  scales with ceil(n_valid / chunk_size), not vocab. The worker now shrinks the
  chunk size and derives the regression threshold from the exact `n_valid` it
  measures, and the test asserts the run is non-vacuous (>1 token chunk) so a
  regression could actually be observed.
- Drop the stale cross-file docstring reference and write trainer artifacts to a
  tempdir instead of the repo checkout (Copilot).
- Print the worker's own full traceback on failure: `accelerate launch` only
  re-raises a truncated CompletedProcess repr, hiding the real error frame.
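To make the corrected threshold concrete, here is a small sketch of the ceil(n_valid / chunk_size) arithmetic. It assumes labels use the conventional ignore index of -100 and leaves out the causal shift. The helper name is hypothetical and is not the worker's actual code.

```python
import math

IGNORE_INDEX = -100  # label value conventionally excluded from the loss


def n_token_chunks(labels: list[list[int]], chunk_size: int) -> int:
    """ceil(n_valid / chunk_size), where n_valid counts non-ignored labels."""
    n_valid = sum(tok != IGNORE_INDEX for row in labels for tok in row)
    return math.ceil(n_valid / chunk_size)


# Five valid labels with a chunk size of 2 give three token chunks.
print(n_token_chunks([[-100, 5, 6, 7], [8, 9, -100, -100]], chunk_size=2))  # 3
```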

@cursor cursor Bot left a comment

Cursor Bugbot has reviewed your changes using default effort and found 1 potential issue.

Reviewed by Cursor Bugbot for commit 20841b4. Configure here.

ceiling = max(16, measured["n_chunks_if_regressed"] // 4)
assert observed < measured["n_chunks_if_regressed"], (
    f"per-chunk lm_head.weight all-gathers detected (#6077 regression): {measured}"
)

Vacuity bound below gather baseline

Medium Severity

The regression check requires all_gathers to be strictly less than n_chunks_if_regressed, but the non-vacuity guard only requires more than four token chunks. On the fixed FSDP2 path, all_gathers stays roughly O(1) per sharded parameter (about ten in the PR’s run), so if n_chunks_if_regressed falls between that baseline and about eleven, the test fails even without a per-chunk regression.
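The window described here can be enumerated with a short sketch. It takes the baseline of ten all-gathers from the PR's run and the "more than four token chunks" guard from the comment above. Both constants are illustrative inputs, not values read from the test file.

```python
BASELINE_ALL_GATHERS = 10  # fixed-path count reported in the PR's run
MIN_CHUNKS_EXCLUSIVE = 4   # non-vacuity guard: n_chunks must exceed this


def spurious_failure(n_chunks: int, observed: int = BASELINE_ALL_GATHERS) -> bool:
    """True when a run passes the non-vacuity guard yet fails the bound,
    even though `observed` is only the fixed-path baseline."""
    passes_guard = n_chunks > MIN_CHUNKS_EXCLUSIVE
    passes_bound = observed < n_chunks
    return passes_guard and not passes_bound


print([n for n in range(1, 20) if spurious_failure(n)])  # [5, 6, 7, 8, 9, 10]
```

Under these assumptions, any run with five to ten token chunks would fail without a regression being present.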

@behroozazarkhalili

Collaborator Author

Thanks for the patience here @qgallouedec — you asked for a repro, and building one surfaced two real issues with this test that I want to lay out honestly before we decide its fate.

1. The CI failure is a torch 2.12 CommDebugMode bug, not the test logic

The failure in the Distributed smoke tests job traces to code inside torch itself (I added a worker-side traceback.print_exc() to get past accelerate launch's truncated CompletedProcess repr):

File ".../torch/distributed/tensor/debug/_comm_mode.py", line 106, in _fw_set_module_hook
    self.name = self.parent_list[-1]
IndexError: list index out of range
  • CI runs torch 2.12.1; on torch 2.11.0 the same test passes (verified on 2×H100).
  • chunked_nll computes each chunk under torch.utils.checkpoint. During the checkpoint recompute, CommDebugMode's module hooks pop parent_list more times than they push, so parent_list[-1] underflows. It's a CommDebugMode + gradient-checkpointing interaction that regressed in 2.12.
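The push/pop imbalance can be pictured with a toy stack. This is a deliberately simplified stand-in written for illustration, not torch's `_comm_mode.py`; the class and method names are hypothetical.

```python
class ToyModuleTracker:
    """Toy parent-module stack. NOT torch's implementation."""

    def __init__(self) -> None:
        self.parent_list: list[str] = ["Global"]
        self.name = "Global"

    def enter(self, module_name: str) -> None:
        self.parent_list.append(module_name)
        self.name = self.parent_list[-1]

    def exit(self) -> None:
        self.parent_list.pop()
        self.name = self.parent_list[-1]  # IndexError once the stack is empty


tracker = ToyModuleTracker()
tracker.enter("lm_head")
tracker.exit()  # balanced: the stack is back to ["Global"]
try:
    tracker.exit()  # one extra pop, standing in for the recompute pass
except IndexError as err:
    print("underflow:", err)
```

One unmatched pop is enough to empty the stack, after which the `[-1]` lookup raises the same `IndexError` seen in the traceback.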

So the measurement tool the test relies on (CommDebugMode) isn't torch-portable for gradient-checkpointed code. That's not something the test can paper over.

2. I couldn't get the guard to actually fail on a real regression

To prove the guard has teeth, I reverted the fix in a controlled run — neutralized the one full_tensor() pre-gather (sft_trainer.py:322) so the lm_head DTensor flows into the checkpointed loop, which should re-gather per chunk. On 2×H100 the all-gather count came out identical to the fixed path (10 vs 10, with ~19 token chunks). So either the per-chunk re-gather doesn't emit distinctly-countable collectives in this setup, or CommDebugMode isn't observing them the way I assumed — meaning I can't currently demonstrate the test would catch the #6077 regression it's meant to guard.

Where that leaves it

Given (1) a torch-version-fragile measurement and (2) unproven detection power, I don't think this is mergeable as-is, and I'd rather not ship a green-but-ineffective guard. Options I see:

You know the chunked-CE design and the CI matrix best — how would you prefer to guard this regression class, if at all? Happy to do the rework or close, whichever you think is right.

@behroozazarkhalili behroozazarkhalili marked this pull request as draft June 25, 2026 15:07

3 participants