Skip to content
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 135 additions & 0 deletions tests/distributed/_chunked_nll_allgather_worker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Companion worker for the chunked-NLL FSDP2 all-gather perf-regression test.

Launched under ``accelerate launch --config_file <fsdp2_reshard>`` by ``test_distributed.py``. It runs a single SFT
``chunked_nll`` training step on an FSDP2-sharded tiny model and counts how many all-gather collectives occur during
that step, so a regression that re-gathers ``lm_head.weight`` once per vocab chunk (the PR #6077 failure mode — correct
loss, silently slow) is caught by a bounded assertion.

Why count real collectives and not ``DTensor.full_tensor()``: under FSDP2 the parameter unshard is driven by autograd
pre-hooks / c10d collectives, not by explicit ``full_tensor()`` calls, so a ``full_tensor`` counter is blind to it. We
use ``CommDebugMode`` (``torch.distributed.tensor.debug``) — torch's purpose-built, DTensor-native comm counter, which
records ``funcol.all_gather_into_tensor`` and the ``c10d`` ``_allgather_base_`` / ``allgather_`` variants that FSDP2
emits. (An earlier version also ran a hand-rolled ``TorchDispatchMode``, but re-dispatching sharded ops from a custom
mode under FSDP2 mismatches the index/weight devices on the embedding lookup, so we rely on ``CommDebugMode`` alone.)

Prints one machine-parseable line ``CHUNKED_NLL_ALLGATHER_RESULT {json}`` that the pytest side asserts on.
Self-contained (mirrors ``tests/experimental/_async_grpo_fsdp2_worker.py``): imports only public symbols.
"""

from __future__ import annotations

import json
import math

from datasets import load_dataset

from trl import SFTConfig, SFTTrainer


MODEL_ID = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
RESULT_PREFIX = "CHUNKED_NLL_ALLGATHER_RESULT"


def _count_all_gathers(comm_counts: dict) -> int:
"""Sum the all-gather collectives from a ``CommDebugMode.get_comm_counts()`` dict.

``CommDebugMode`` keys the dict by the comm op (funcol / c10d). FSDP2's parameter unshard shows up as
``funcol.all_gather_into_tensor`` (and the ``_allgather_base_`` / ``allgather_`` c10d variants), so we match on the
op's string name containing ``all_gather`` / ``allgather`` and total those. This is the DTensor-native counter — it
observes the autograd-hook-driven gathers that ``DTensor.full_tensor()`` is blind to, without a hand-rolled
``TorchDispatchMode`` (which mis-dispatches sharded ops under FSDP2).
"""
total = 0
for op, n in comm_counts.items():
name = str(op).lower()
if "all_gather" in name or "allgather" in name:
total += int(n)
return total


class _MeasuringSFTTrainer(SFTTrainer):
    """SFTTrainer subclass that records the comm counts of its first ``training_step``.

    The measurement has to happen *inside* the real ``trainer.train()`` loop rather than through a direct
    ``training_step`` call: with ``fsdp_cpu_ram_efficient_loading`` the model stays on CPU/meta until
    ``_inner_training_loop`` FSDP-wraps it and moves it to GPU, so stepping ``trainer.model`` before ``train()``
    would run the embedding lookup with a CPU weight against a CUDA input and crash on the device mismatch.
    Overriding ``training_step`` leaves all wrapping/placement to the trainer; this class only wraps the step
    (a single one when ``max_steps=1``) in ``CommDebugMode`` to tally the FSDP2 unshard collectives.
    """

    # ``None`` until a step has been measured; afterwards the result of ``CommDebugMode.get_comm_counts()``.
    comm_counts: dict | None = None

    def training_step(self, *args, **kwargs):
        from torch.distributed.tensor.debug import CommDebugMode

        if self.comm_counts is None:
            # First step: run it under the comm tracker and keep the per-op counts on the instance.
            tracker = CommDebugMode()
            with tracker:
                step_loss = super().training_step(*args, **kwargs)
            self.comm_counts = tracker.get_comm_counts()
            return step_loss

        # Any later step runs unmeasured, so the stored counts always describe exactly one step even if
        # ``max_steps`` is raised by the caller.
        return super().training_step(*args, **kwargs)


def main() -> None:
    """Run one ``chunked_nll`` SFT step and print the all-gather measurement as a single JSON line.

    Builds a ``_MeasuringSFTTrainer`` for ``MODEL_ID``, runs ``trainer.train()`` with ``max_steps=1`` and then
    prints ``"<RESULT_PREFIX> <json>"`` on the main process only. The JSON payload carries the vocab size, the
    chunk size, the number of chunks a per-chunk-regather regression would imply, the counted all-gathers, the
    total of all recorded comm ops, and whether the logged ``train_loss`` is finite.
    """
    dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")
    args = SFTConfig(
        output_dir="chunked_nll_fsdp2_out",
        loss_type="chunked_nll",
        per_device_train_batch_size=2,
        max_length=64,
        max_steps=1,
        report_to="none",
        bf16=True,
    )
    trainer = _MeasuringSFTTrainer(model=MODEL_ID, args=args, train_dataset=dataset)

    # vocab / chunk arithmetic: a per-chunk-regather regression would do ~ceil(vocab / chunk_size)
    # gathers of lm_head.weight per step; the fixed path does O(1). Computed, never hardcoded.
    from trl.trainer.sft_trainer import _CHUNKED_LM_HEAD_CHUNK_SIZE

    vocab_size = trainer.model.config.vocab_size
    n_chunks = -(-vocab_size // _CHUNKED_LM_HEAD_CHUNK_SIZE)  # integer ceil division

    # Run the real training loop: it FSDP-wraps the model and moves it to GPU, then calls training_step
    # once (max_steps=1), which our subclass measures under CommDebugMode.
    trainer.train()

    # ``comm_counts`` stays None if no step ran; fall back to an empty mapping so the totals are 0.
    comm_counts = trainer.comm_counts or {}
    all_gathers = _count_all_gathers(comm_counts)
    comm_total = sum(int(n) for n in comm_counts.values())

    # The last log entry is read for ``train_loss``; a missing key or empty history yields loss_finite=False.
    last = trainer.state.log_history[-1] if trainer.state.log_history else {}
    train_loss = last.get("train_loss")

    result = {
        "vocab_size": int(vocab_size),
        "chunk_size": int(_CHUNKED_LM_HEAD_CHUNK_SIZE),
        "n_chunks_if_regressed": int(n_chunks),
        "all_gathers": int(all_gathers),
        "commdebug_total": int(comm_total),
        "loss_finite": train_loss is not None and math.isfinite(train_loss),
    }
    # Only one rank reports, so the launcher sees exactly one result line.
    if trainer.accelerator.is_main_process:
        print(f"{RESULT_PREFIX} {json.dumps(result)}", flush=True)  # noqa: T201 - result channel for the launcher


if __name__ == "__main__":
main()
31 changes: 31 additions & 0 deletions tests/distributed/data/accelerate_configs/fsdp2_reshard.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# 2-process FSDP2 config for the chunked-NLL all-gather perf-regression test.
#
# `fsdp_reshard_after_forward: true` is the load-bearing setting: it is the condition under which PR
# #6077's per-chunk `lm_head.weight` all-gather manifested (the param is resharded after forward, then
# re-gathered during the chunked backward). The minimal `fsdp2.yaml` leaves this at FSDP's default, so a
# dedicated config that states it explicitly keeps the collective count meaningful. Mirrors
# `examples/accelerate_configs/fsdp2.yaml` with `num_processes: 2` for a 2-GPU node.
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
enable_cpu_affinity: false
fsdp_config:
  fsdp_activation_checkpointing: false
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_cpu_ram_efficient_loading: true
  fsdp_offload_params: false
  fsdp_reshard_after_forward: true
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_version: 2
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 2
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
46 changes: 46 additions & 0 deletions tests/distributed/test_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
import subprocess
from pathlib import Path
Expand Down Expand Up @@ -376,3 +377,48 @@ def test_grpo_liger(self, config, get_config_path):
os.environ.copy(),
)
# fmt: on

def test_sft_chunked_nll_fsdp2_no_per_chunk_allgather(self, lazy_shared_datadir):
    # Perf-regression guard for the PR #6077 class: a chunked cross-entropy path must NOT re-gather the
    # sharded `lm_head.weight` once per vocab chunk under FSDP2 (correct loss, silently slow, invisible
    # to a pass/fail test). The companion worker runs one SFT `chunked_nll` step under a 2-process FSDP2
    # group (reshard_after_forward=True — the condition that triggers the bug) and counts the all-gather
    # collectives during that step; here we assert the count stays O(1), not O(vocab / chunk_size).
    #
    # Counting real collectives is required: under FSDP2 the parameter unshard is driven by autograd
    # hooks / c10d collectives, not by `DTensor.full_tensor()`, so the worker counts the actual
    # all-gather collectives via CommDebugMode (torch's DTensor-native comm counter) for the step.
    worker = Path(__file__).parent / "_chunked_nll_allgather_worker.py"
    config_path = lazy_shared_datadir / "accelerate_configs" / "fsdp2_reshard.yaml"
    # Pin the repo root onto PYTHONPATH for the child: `accelerate launch` re-execs each rank via
    # torch.distributed.elastic, which sets sys.path[0] to the launched script's directory, not cwd.
    # Without this, a non-editable `trl` already in site-packages would shadow the working tree.
    env = os.environ.copy()
    env["PYTHONPATH"] = os.pathsep.join([str(ROOT), env.get("PYTHONPATH", "")]).rstrip(os.pathsep)
    result = subprocess.run(
        ["accelerate", "launch", "--config_file", str(config_path), str(worker)],
        env=env,
        cwd=ROOT,
        capture_output=True,
        text=True,
    )
    assert result.returncode == 0, f"worker failed:\nSTDOUT:\n{result.stdout}\nSTDERR:\n{result.stderr}"

    prefix = "CHUNKED_NLL_ALLGATHER_RESULT"
    lines = [ln for ln in result.stdout.splitlines() if ln.startswith(prefix)]
    assert len(lines) == 1, f"expected exactly one result line, got {lines}\n{result.stdout}"
    measured = json.loads(lines[0][len(prefix) :].strip())

    assert measured["loss_finite"], f"chunked_nll loss not finite under FSDP2: {measured}"
    # A per-chunk-regather regression would do ~n_chunks all-gathers of lm_head.weight in the step; the
    # fixed path does O(1). `all_gathers` is the total all-gather collective count for the step, measured
    # by CommDebugMode (the DTensor-native counter that sees FSDP2's autograd-hook-driven gathers). It
    # legitimately includes one gather per sharded parameter (a handful of decoder layers), so bound it
    # well below the regression count rather than at exactly 1. The ceiling scales off n_chunks (never a
    # hardcoded collective count) so it tracks the model's vocab/chunk arithmetic.
    observed = measured["all_gathers"]
    n_chunks = measured["n_chunks_if_regressed"]
    baseline_allowance = 16  # headroom for the per-parameter gathers of the fixed path
    ceiling = max(baseline_allowance, n_chunks // 4)

    # Non-vacuity guard: the bound only separates "fixed" from "regressed" when the regression count sits
    # above the ceiling. If n_chunks is at or below it, the fixed path's own baseline gathers could reach
    # n_chunks and the check below would report a regression that did not happen — fail with an accurate
    # diagnosis of the test setup instead.
    assert n_chunks > ceiling, (
        f"vacuous bound: n_chunks_if_regressed ({n_chunks}) is not above the all-gather ceiling ({ceiling}), "
        f"so a per-chunk regression cannot be told apart from the baseline: {measured}"
    )
    assert observed < n_chunks, (
        f"per-chunk lm_head.weight all-gathers detected (#6077 regression): {measured}"
    )
    assert observed <= ceiling, f"unexpectedly many all-gathers (possible regression): {measured}"
Loading