Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
187 changes: 187 additions & 0 deletions tests/distributed/_chunked_nll_allgather_worker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Companion worker for the chunked-NLL FSDP2 all-gather perf-regression test.

Launched under ``accelerate launch --config_file <fsdp2_reshard>`` by ``test_distributed.py``. It runs a single SFT
``chunked_nll`` training step on an FSDP2-sharded tiny model and counts how many all-gather collectives occur during
that step, so a regression that re-gathers ``lm_head.weight`` once per token chunk (the PR #6077 failure mode — correct
loss, silently slow) is caught by a bounded assertion.

``_chunked_cross_entropy_loss`` chunks over *valid tokens* (``for start in range(0, n_valid, chunk_size)``), so the
regression scales with ``ceil(n_valid / chunk_size)`` and only shows up when more than one token chunk runs. The zen
test data is tiny, so this worker shrinks the chunk size (see ``_TEST_CHUNK_SIZE``) to force many token chunks, and
derives the regression threshold from the exact ``n_valid`` captured from inside the loss path — never from vocab size.

Why count real collectives and not ``DTensor.full_tensor()``: under FSDP2 the parameter unshard is driven by autograd
pre-hooks / c10d collectives, not by explicit ``full_tensor()`` calls, so a ``full_tensor`` counter is blind to it. We
use ``CommDebugMode`` (``torch.distributed.tensor.debug``) — torch's purpose-built, DTensor-native comm counter, which
records ``funcol.all_gather_into_tensor`` and the ``c10d`` ``_allgather_base_`` / ``allgather_`` variants that FSDP2
emits. (An earlier version also ran a hand-rolled ``TorchDispatchMode``, but re-dispatching sharded ops from a custom
mode under FSDP2 mismatches the index/weight devices on the embedding lookup, so we rely on ``CommDebugMode`` alone.)

Prints one machine-parseable line ``CHUNKED_NLL_ALLGATHER_RESULT {json}`` that the pytest side asserts on.
Self-contained on purpose: it imports only public TRL symbols and runs as ``__main__`` under ``accelerate launch``.
"""

from __future__ import annotations

import json
import math
import tempfile

from datasets import load_dataset

from trl import SFTConfig, SFTTrainer


MODEL_ID = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
RESULT_PREFIX = "CHUNKED_NLL_ALLGATHER_RESULT"


def _count_all_gathers(comm_counts: dict) -> int:
"""Sum the all-gather collectives from a ``CommDebugMode.get_comm_counts()`` dict.

``CommDebugMode`` keys the dict by the comm op (funcol / c10d). FSDP2's parameter unshard shows up as
``funcol.all_gather_into_tensor`` (and the ``_allgather_base_`` / ``allgather_`` c10d variants), so we match on the
op's string name containing ``all_gather`` / ``allgather`` and total those. This is the DTensor-native counter — it
observes the autograd-hook-driven gathers that ``DTensor.full_tensor()`` is blind to, without a hand-rolled
``TorchDispatchMode`` (which mis-dispatches sharded ops under FSDP2).
"""
total = 0
for op, n in comm_counts.items():
name = str(op).lower()
if "all_gather" in name or "allgather" in name:
total += int(n)
return total


class _MeasuringSFTTrainer(SFTTrainer):
"""SFTTrainer that counts all-gather collectives during its first ``training_step``.

The measurement must run *inside* the real ``trainer.train()`` loop, not by calling ``training_step`` directly:
under ``fsdp_cpu_ram_efficient_loading`` the model is on CPU/meta until ``_inner_training_loop`` FSDP-wraps it and
moves it to GPU. Calling ``training_step`` on ``trainer.model`` before ``train()`` runs the embedding lookup with a
CPU weight against a CUDA input → device-mismatch crash. Overriding ``training_step`` lets the trainer do all
wrapping/placement, while we wrap the (single, since ``max_steps=1``) step in ``CommDebugMode`` to tally the FSDP2
unshard collectives.
"""

comm_counts: dict | None = None

def training_step(self, *args, **kwargs):
from torch.distributed.tensor.debug import CommDebugMode

# Only measure the first step (with max_steps=1 there is exactly one); guard anyway so the counts
# reflect a single step even if the caller raises max_steps later.
if self.comm_counts is not None:
return super().training_step(*args, **kwargs)
comm_mode = CommDebugMode()
with comm_mode:
loss = super().training_step(*args, **kwargs)
self.comm_counts = comm_mode.get_comm_counts()
return loss


# The chunked-CE loop chunks over *valid tokens*, not vocab: `for start in range(0, n_valid, chunk_size)`
# in `_chunked_cross_entropy_loss`. So a per-chunk `lm_head.weight` re-gather regression scales with
# ceil(n_valid / chunk_size) — the TOKEN-chunk count — and is only observable when more than one chunk
# runs (n_valid > chunk_size). The zen test data is tiny (~120 valid tokens total), so with the default
# chunk size of 256 only a single chunk would run and a regression would be invisible. We therefore shrink
# the chunk size for this test so the tiny batch genuinely exercises many token-chunks.
_TEST_CHUNK_SIZE = 4


def main() -> None:
import trl.trainer.sft_trainer as sft

# Shrink the chunk size BEFORE the trainer patches the lm_head (it reads this module constant at
# construction). With ~120 valid tokens this yields ~30 token-chunks, so a per-chunk re-gather
# regression would do ~30 lm_head all-gathers vs O(1) for the fixed path — a wide, detectable margin.
sft._CHUNKED_LM_HEAD_CHUNK_SIZE = _TEST_CHUNK_SIZE

# Capture the real valid-token count from inside the chunked-CE path, so the regression threshold is
# derived from the exact n_valid the loop iterates over (never guessed from token lengths).
captured = {}
_orig_cce = sft._chunked_cross_entropy_loss

def _capturing_cce(hidden_states, lm_head_weight, chunk_size, *args, **kwargs):
out = _orig_cce(hidden_states, lm_head_weight, chunk_size, *args, **kwargs)
# Returns (loss, correct, entropy_sum, n_valid_tensor); n_valid is the 4th element.
captured["n_valid"] = int(out[3].item())
captured["chunk_size"] = int(chunk_size)
return out

sft._chunked_cross_entropy_loss = _capturing_cce

dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")
# Write trainer artifacts to a throwaway temp dir so the worker leaves no state in the repo checkout and
# repeated runs can't collide. tempfile keeps this self-contained (no reliance on the launch cwd).
tmp_out = tempfile.mkdtemp(prefix="chunked_nll_fsdp2_")
args = SFTConfig(
output_dir=tmp_out,
loss_type="chunked_nll",
# Pack as many of the tiny examples into the single measured step as possible, so n_valid is well
# above the (shrunk) chunk size and the token-chunk count is large.
per_device_train_batch_size=8,
max_length=64,
max_steps=1,
report_to="none",
bf16=True,
)
trainer = _MeasuringSFTTrainer(model=MODEL_ID, args=args, train_dataset=dataset)

vocab_size = trainer.model.config.vocab_size

# Run the real training loop: it FSDP-wraps the model and moves it to GPU, then calls training_step
# once (max_steps=1), which our subclass measures under CommDebugMode.
trainer.train()

comm_counts = trainer.comm_counts or {}
all_gathers = _count_all_gathers(comm_counts)
comm_total = sum(int(n) for n in comm_counts.values())

n_valid = captured.get("n_valid", 0)
chunk_size = captured.get("chunk_size", _TEST_CHUNK_SIZE)
n_chunks = -(-n_valid // chunk_size) if n_valid else 0 # ceil(n_valid / chunk_size) — TOKEN chunks

last = trainer.state.log_history[-1] if trainer.state.log_history else {}
train_loss = last.get("train_loss")

result = {
"vocab_size": int(vocab_size),
"n_valid": int(n_valid),
"chunk_size": int(chunk_size),
"n_chunks_if_regressed": int(n_chunks),
"all_gathers": int(all_gathers),
"commdebug_total": int(comm_total),
"loss_finite": train_loss is not None and math.isfinite(train_loss),
}
if trainer.accelerator.is_main_process:
print(f"{RESULT_PREFIX} {json.dumps(result)}", flush=True) # noqa: T201 - result channel for the launcher


if __name__ == "__main__":
# Print the full traceback from this worker directly: when `accelerate launch` re-raises a child
# failure, the parent only sees a truncated `CompletedProcess` repr, which hides the real error frame.
# Surfacing it here puts the complete traceback in the worker's own stderr (and thus the CI log).
import sys
import traceback

try:
main()
except Exception:
traceback.print_exc()
sys.stderr.flush()
raise
31 changes: 31 additions & 0 deletions tests/distributed/data/accelerate_configs/fsdp2_reshard.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# 2-process FSDP2 config for the chunked-NLL all-gather perf-regression test.
#
# `fsdp_reshard_after_forward: true` is the load-bearing setting: it is the condition under which PR
# #6077's per-chunk `lm_head.weight` all-gather manifested (the param is resharded after forward, then
# re-gathered during the chunked backward). The minimal `fsdp2.yaml` leaves this at FSDP's default, so a
# dedicated config that states it explicitly keeps the collective count meaningful. Mirrors
# `examples/accelerate_configs/fsdp2.yaml` with `num_processes: 2` for a 2-GPU node.
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
enable_cpu_affinity: false
fsdp_config:
fsdp_activation_checkpointing: false
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_cpu_ram_efficient_loading: true
fsdp_offload_params: false
fsdp_reshard_after_forward: true
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_version: 2
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 2
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
59 changes: 59 additions & 0 deletions tests/distributed/test_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
import subprocess
from pathlib import Path
Expand Down Expand Up @@ -376,3 +377,61 @@ def test_grpo_liger(self, config, get_config_path):
os.environ.copy(),
)
# fmt: on

def test_sft_chunked_nll_fsdp2_no_per_chunk_allgather(self, lazy_shared_datadir):
# Perf-regression guard for the PR #6077 class: a chunked cross-entropy path must NOT re-gather the
# sharded `lm_head.weight` once per token chunk under FSDP2 (correct loss, silently slow, invisible
# to a pass/fail test). The companion worker runs one SFT `chunked_nll` step under a 2-process FSDP2
# group (reshard_after_forward=True — the condition that triggers the bug) and counts the all-gather
# collectives during that step; here we assert the count stays O(1), not O(n_valid / chunk_size).
#
# `_chunked_cross_entropy_loss` chunks over VALID TOKENS, not vocab, so the regression scales with
# ceil(n_valid / chunk_size) and only manifests when more than one token chunk runs. The worker
# shrinks the chunk size so the tiny zen batch exercises many token chunks, and reports the exact
# n_valid / chunk_size it measured so this side can both bound the count and confirm the test is
# non-vacuous (n_chunks_if_regressed > 1 — otherwise a regression could never have been observed).
#
# Counting real collectives is required: under FSDP2 the parameter unshard is driven by autograd
# hooks / c10d collectives, not by `DTensor.full_tensor()`, so the worker counts the actual
# all-gather collectives via CommDebugMode (torch's DTensor-native comm counter) for the step.
worker = Path(__file__).parent / "_chunked_nll_allgather_worker.py"
config_path = lazy_shared_datadir / "accelerate_configs" / "fsdp2_reshard.yaml"
# Pin the repo root onto PYTHONPATH for the child: `accelerate launch` re-execs each rank via
# torch.distributed.elastic, which sets sys.path[0] to the launched script's directory, not cwd.
# Without this, a non-editable `trl` already in site-packages would shadow the working tree.
env = os.environ.copy()
env["PYTHONPATH"] = os.pathsep.join([str(ROOT), env.get("PYTHONPATH", "")]).rstrip(os.pathsep)
result = subprocess.run(
["accelerate", "launch", "--config_file", str(config_path), str(worker)],
env=env,
cwd=ROOT,
capture_output=True,
text=True,
)
assert result.returncode == 0, f"worker failed:\nSTDOUT:\n{result.stdout}\nSTDERR:\n{result.stderr}"

prefix = "CHUNKED_NLL_ALLGATHER_RESULT"
lines = [ln for ln in result.stdout.splitlines() if ln.startswith(prefix)]
assert len(lines) == 1, f"expected exactly one result line, got {lines}\n{result.stdout}"
measured = json.loads(lines[0][len(prefix) :].strip())

assert measured["loss_finite"], f"chunked_nll loss not finite under FSDP2: {measured}"
# Non-vacuity guard (the heart of this test): a per-token-chunk regression can only be detected if the
# step actually ran multiple token chunks. If only one chunk ran, a regression would gather exactly
# once too, so the test would pass for the wrong reason. Require a comfortably multi-chunk run.
assert measured["n_chunks_if_regressed"] > 4, (
f"test is vacuous — only {measured['n_chunks_if_regressed']} token chunk(s) ran, so a per-chunk "
f"regression could not be observed; increase batch/length or shrink chunk_size: {measured}"
)
# A per-token-chunk-regather regression would do ~n_chunks all-gathers of lm_head.weight in the step;
# the fixed path does O(1). `all_gathers` is the total all-gather collective count for the step,
# measured by CommDebugMode (the DTensor-native counter that sees FSDP2's autograd-hook-driven
# gathers). It legitimately includes one gather per sharded parameter (a handful of decoder layers),
# so bound it well below the regression count rather than at exactly 1. The ceiling scales off
# n_chunks (never a hardcoded collective count) so it tracks the model's token/chunk arithmetic.
observed = measured["all_gathers"]
ceiling = max(16, measured["n_chunks_if_regressed"] // 4)
assert observed < measured["n_chunks_if_regressed"], (
f"per-chunk lm_head.weight all-gathers detected (#6077 regression): {measured}"
)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Vacuity bound below gather baseline

Medium Severity

The regression check requires all_gathers to be strictly less than n_chunks_if_regressed, but the non-vacuity guard only requires more than four token chunks. On the fixed FSDP2 path, all_gathers stays roughly O(1) per sharded parameter (about ten in the PR’s run), so if n_chunks_if_regressed falls between that baseline and about eleven, the test fails even without a per-chunk regression.

Fix in Cursor Fix in Web

Reviewed by Cursor Bugbot for commit 20841b4. Configure here.

assert observed <= ceiling, f"unexpectedly many all-gathers (possible regression): {measured}"
Loading