Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
175 changes: 175 additions & 0 deletions tests/experimental/_async_grpo_fsdp2_worker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Companion worker launched under ``accelerate launch --config_file <fsdp2>`` by the FSDP2 case in
``test_async_grpo_trainer.py``.

It runs a couple of :class:`AsyncGRPOTrainer` steps on an FSDP2-sharded model, driven by an in-process stub rollout
worker (no vLLM server, no NCCL weight transfer), and checks that training actually progresses under FSDP2: the loss is
finite and the parameters change. It then prints one machine-parseable result line (``ASYNC_GRPO_FSDP2_RESULT {json}``)
that the pytest side asserts on.

This is a *functional* FSDP2 smoke test, not a performance microbenchmark. (An earlier version tried to count
``lm_head.weight`` all-gathers to answer PR #6077's per-chunk re-gather question, but under FSDP2 those gathers are
driven by autograd unshard hooks, not by ``DTensor.full_tensor``, and the trainer's own weight-sync path calls
``full_tensor`` on every parameter every step — so a ``full_tensor`` counter cannot isolate the chunked-logprob path.
The #6077 question is instead settled by static analysis: ``patch_chunked_lm_head`` uses a plain custom autograd
Function with no ``torch.utils.checkpoint`` recompute, so the per-chunk re-gather mechanism that PR #6077 fixed for
SFT's ``chunked_nll`` is structurally absent here.)

Self-contained on purpose (mirrors ``tests/experimental/_openreward_echo_env.py``): it imports only public TRL symbols
and carries its own stub, so it never imports pytest-internal classes across the subprocess boundary.
"""

from __future__ import annotations

import itertools
import json
import queue

import numpy as np
import torch
from datasets import load_dataset
from transformers import AutoTokenizer

from trl.experimental.async_grpo import AsyncGRPOConfig, AsyncGRPOTrainer
from trl.experimental.async_grpo.async_rollout_worker import RolloutSample


# Model identifier used in `main()` both to load the tokenizer and as the trainer's `model` argument.
MODEL_ID = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
# Prefix of the single machine-parseable result line printed by `main()`, followed by a JSON payload.
RESULT_PREFIX = "ASYNC_GRPO_FSDP2_RESULT"


def dummy_reward_func(completions, **kwargs):
    """Return one pseudo-reward in ``[0.0, 0.99]`` per completion.

    Mirrors the helper of the same name in ``tests/experimental/test_async_grpo_trainer.py``. In this module it is
    called by the stub rollout worker to pre-compute rewards and is also handed to the trainer as its
    ``reward_funcs`` argument.

    Args:
        completions: Sequence of conversations; each is a list of message dicts and only the first message's
            ``"content"`` value is used.
        **kwargs: Accepted for signature compatibility and ignored.

    Returns:
        A list of floats, one per completion, each equal to ``hash(content) % 100 / 100``.
    """
    # NOTE(review): built-in `hash()` of a str is presumably salted per interpreter process, so the values are
    # stable within one process but may differ between runs/ranks -- confirm if reproducibility matters.
    rewards = []
    for conversation in completions:
        bucket = hash(conversation[0]["content"]) % 100
        rewards.append(float(bucket) / 100.0)
    return rewards


class _StubRolloutWorker:
    """Minimal in-process rollout worker — same shape as the one in test_async_grpo_trainer.py.

    Reproduced here (rather than imported) because this module runs as ``__main__`` under ``accelerate launch``, not as
    a pytest module, so importing the test class would be fragile. Keeping it self-contained matches the openreward
    companion-script precedent.

    Samples are generated lazily by cycling over ``dataset`` and are pushed onto ``rollout_buffer`` in batches of
    ``samples_per_weight_sync``: once by :meth:`start` and once per :meth:`update_model_version` call.
    """

    def __init__(self, tokenizer, dataset, num_generations: int = 3, samples_per_weight_sync: int = 10):
        """Set up the buffer and the lazy sample generator.

        Args:
            tokenizer: Object exposing ``apply_chat_template``; used to tokenize prompts and completions.
            dataset: Iterable of rows with ``"prompt"`` and ``"completion"`` conversation fields.
            num_generations: Number of synthetic completions produced per dataset row.
            samples_per_weight_sync: Number of samples enqueued by each buffer refill.
        """
        self.rollout_buffer = queue.Queue()
        self._samples_per_weight_sync = samples_per_weight_sync
        self._model_version = 0
        # Generator object: nothing is tokenized until the first `next()` in `_fill_queue`.
        self._sample_iter = self._make_sample_iter(tokenizer, dataset, num_generations)

    @staticmethod
    def _normalize_advantages(rewards, min_std=1e-8):
        """Standardize a group's rewards to zero mean and unit (population) standard deviation.

        When the rewards have no spread (all equal up to float round-off) a plain division by the standard
        deviation would produce NaN/inf advantages and, downstream, a non-finite loss. In that case every
        advantage is defined as ``0.0`` instead.

        Args:
            rewards: 1-D array-like of per-completion rewards.
            min_std: Spreads at or below this threshold are treated as zero.

        Returns:
            A float ``numpy.ndarray`` with the same length as ``rewards``.
        """
        rewards = np.asarray(rewards, dtype=float)
        if rewards.size == 0:
            return rewards
        centered = rewards - rewards.mean()
        spread = rewards.std()
        # `not spread > min_std` also catches a NaN spread.
        if not spread > min_std:
            return np.zeros_like(centered)
        return centered / spread

    def _make_sample_iter(self, tokenizer, dataset, num_generations):
        """Yield an endless stream of ``RolloutSample`` objects, ``num_generations`` per dataset row."""
        for row in itertools.cycle(dataset):
            # Distinct completions are made by appending the generation index to the reference completion.
            completions = [
                [{"role": "assistant", "content": f"{row['completion'][0]['content']} {idx}"}]
                for idx in range(num_generations)
            ]
            prompt_completions = [row["prompt"] + completion for completion in completions]
            prompt_ids = tokenizer.apply_chat_template(
                row["prompt"], tokenize=True, add_generation_prompt=True, return_dict=False
            )
            prompt_completion_ids = tokenizer.apply_chat_template(
                prompt_completions, tokenize=True, add_generation_prompt=False, return_dict=False
            )
            rewards = np.array(dummy_reward_func(completions))
            reward_std = float(rewards.std())
            advantages = self._normalize_advantages(rewards)
            for idx in range(num_generations):
                # Completion tokens are whatever follows the prompt tokens in the full conversation encoding.
                completion_ids = prompt_completion_ids[idx][len(prompt_ids) :]
                yield RolloutSample(
                    prompt=row["prompt"],
                    completion=completions[idx],
                    input_ids=prompt_ids + completion_ids,
                    completion_mask=[0] * len(prompt_ids) + [1] * len(completion_ids),
                    old_log_probs=[0.0] * len(prompt_ids) + [-0.5] * len(completion_ids),
                    advantage=float(advantages[idx]),
                    # Read when the sample is produced, so it reflects the latest `update_model_version` call.
                    model_version=self._model_version,
                    metrics={"reward": float(rewards[idx]), "reward_std": reward_std},
                )

    def _fill_queue(self):
        """Move ``samples_per_weight_sync`` freshly generated samples into ``rollout_buffer``."""
        for _ in range(self._samples_per_weight_sync):
            self.rollout_buffer.put(next(self._sample_iter))

    def start(self):
        """Enqueue the first batch of samples."""
        self._fill_queue()

    def update_model_version(self, version):
        """Record the new model version and enqueue another batch of samples tagged with it."""
        self._model_version = version
        self._fill_queue()

    def stop(self):
        """No-op: the stub owns no threads or external resources."""
        pass

    def check_health(self, stale_after_s):
        """No-op: the stub performs no health checking; ``stale_after_s`` is ignored."""
        pass


def main() -> None:
    """Run two ``AsyncGRPOTrainer`` steps and print a one-line JSON summary on the main process.

    The summary line is ``RESULT_PREFIX`` followed by a JSON object with the keys ``steps`` (the trainer's
    ``global_step``), ``params_changed`` (whether any parameter differs from its pre-training snapshot) and
    ``train_loss_finite`` (whether the last log entry holds a finite ``train_loss``).
    """
    dataset = load_dataset("trl-internal-testing/zen", "conversational_prompt_completion", split="train")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

    def _as_full_cpu_tensor(tensor):
        # Gather a DTensor into a full tensor first, then bring the result to CPU. Per the original note, the
        # pre-training snapshot may be a plain tensor while the trained parameter is an FSDP2 DTensor on CUDA,
        # so both operands are normalized before `torch.equal` to avoid a device mismatch.
        if isinstance(tensor, torch.distributed.tensor.DTensor):
            tensor = tensor.full_tensor()
        return tensor.detach().cpu()

    # Same minimal, memory-frugal settings as the single-process test_train, but with two steps so the
    # optimizer loop runs more than once under FSDP2.
    config = AsyncGRPOConfig(
        output_dir="async_grpo_fsdp2_out",
        learning_rate=0.1,
        per_device_train_batch_size=3,
        num_generations=3,
        max_completion_length=8,
        max_steps=2,
        vllm_server_timeout=5.0,
        report_to="none",
    )
    stub_worker = _StubRolloutWorker(tokenizer, dataset, num_generations=3)
    trainer = AsyncGRPOTrainer(
        model=MODEL_ID,
        reward_funcs=dummy_reward_func,
        args=config,
        train_dataset=dataset,
        rollout_worker=stub_worker,
    )

    # Keep a copy of every parameter so an actual update can be detected after training.
    initial_params = {}
    for name, param in trainer.model.named_parameters():
        initial_params[name] = param.detach().clone()

    trainer.train()

    # Stops at the first parameter that differs from its snapshot.
    params_changed = any(
        not torch.equal(_as_full_cpu_tensor(initial_params[name]), _as_full_cpu_tensor(param))
        for name, param in trainer.model.named_parameters()
    )

    history = trainer.state.log_history
    final_entry = history[-1] if history else {}
    final_loss = final_entry.get("train_loss")
    if final_loss is None:
        loss_is_finite = False
    else:
        loss_is_finite = bool(np.isfinite(final_loss))

    summary = {
        "steps": trainer.state.global_step,
        "params_changed": params_changed,
        "train_loss_finite": loss_is_finite,
    }
    # Printed by the main process only, so the pytest side sees exactly one result line.
    if trainer.accelerator.is_main_process:
        print(f"{RESULT_PREFIX} {json.dumps(summary)}", flush=True)  # noqa: T201 - result channel for the launcher


# Script entry point: run the worker only when this file is executed directly.
if __name__ == "__main__":
    main()
29 changes: 29 additions & 0 deletions tests/experimental/data/accelerate_configs/fsdp2_reshard.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# 2-process FSDP2 config for the async-GRPO FSDP2 functional test (test_train_fsdp2).
#
# `fsdp_reshard_after_forward: true` is set explicitly so the test exercises the resharding parameter
# lifecycle rather than relying on FSDP's default. Mirrors `examples/accelerate_configs/fsdp2.yaml`
# with `num_processes: 2` for a 2-GPU node.
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
enable_cpu_affinity: false
fsdp_config:
fsdp_activation_checkpointing: false
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_cpu_ram_efficient_loading: true
fsdp_offload_params: false
fsdp_reshard_after_forward: true
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_version: 2
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 2
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
51 changes: 50 additions & 1 deletion tests/experimental/test_async_grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,11 @@
# limitations under the License.

import itertools
import json
import os
import queue
import subprocess
from pathlib import Path

import numpy as np
import pytest
Expand All @@ -25,7 +29,14 @@
from trl.experimental.async_grpo import AsyncGRPOConfig, AsyncGRPOTrainer
from trl.experimental.async_grpo.async_rollout_worker import RolloutSample

from ..testing_utils import TrlTestCase, is_ampere_or_newer
from ..testing_utils import TrlTestCase, is_ampere_or_newer, require_torch_multi_accelerator


# `parents[2]` of this resolved file path; `test_train_fsdp2` prepends it to PYTHONPATH and uses it as the
# subprocess working directory.
ROOT = Path(__file__).resolve().parents[2]
# Directory containing this test module; the worker script and accelerate config below are located relative to it.
_HERE = Path(__file__).parent
# Companion script that `test_train_fsdp2` runs through `accelerate launch`.
_FSDP2_WORKER = _HERE / "_async_grpo_fsdp2_worker.py"
# Accelerate config passed via `--config_file` for that launch.
_FSDP2_CONFIG = _HERE / "data" / "accelerate_configs" / "fsdp2_reshard.yaml"
# Prefix identifying the worker's result line in captured stdout; the rest of that line is parsed as JSON.
_FSDP2_RESULT_PREFIX = "ASYNC_GRPO_FSDP2_RESULT"


def dummy_reward_func(completions, **kwargs):
Expand Down Expand Up @@ -134,3 +145,41 @@ def test_train(self):
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
assert not torch.equal(param, new_param), f"Parameter {n} has not changed."

@require_torch_multi_accelerator
def test_train_fsdp2(self):
    # Functional smoke test: launch the companion worker under the 2-process FSDP2 accelerate config and check
    # that AsyncGRPOTrainer really trains. This drives the `patch_chunked_lm_head` chunked-logprob path on
    # FSDP2-sharded parameters end-to-end. The worker feeds the trainer from an in-process stub rollout worker
    # (no vLLM server, no NCCL weight transfer), so the FSDP2 parameter lifecycle is the only distributed
    # surface involved.
    #
    # (Deliberately NOT a #6077 all-gather microbenchmark. Under FSDP2 the per-parameter gathers come from
    # autograd unshard hooks rather than `DTensor.full_tensor`, and the trainer's weight-sync path already calls
    # `full_tensor` on every parameter every step, so a `full_tensor` counter cannot isolate the chunk path.
    # The #6077 question is answered by static analysis instead: `patch_chunked_lm_head` has no
    # `torch.utils.checkpoint` recompute, so the per-chunk re-gather that PR #6077 fixed for SFT's
    # `chunked_nll` is structurally absent here.)
    #
    # The repo root is put first on the child's PYTHONPATH because `accelerate launch` re-execs each rank via
    # torch.distributed.elastic, which makes sys.path[0] the launched script's directory
    # (tests/experimental/) rather than cwd. Otherwise a non-editable `trl` in site-packages would shadow the
    # working tree and the test would exercise the wrong code.
    child_env = dict(os.environ)
    inherited_pythonpath = child_env.get("PYTHONPATH", "")
    child_env["PYTHONPATH"] = os.pathsep.join([str(ROOT), inherited_pythonpath]).rstrip(os.pathsep)

    launch_cmd = ["accelerate", "launch", "--config_file", str(_FSDP2_CONFIG), str(_FSDP2_WORKER)]
    result = subprocess.run(launch_cmd, env=child_env, cwd=ROOT, capture_output=True, text=True)
    assert result.returncode == 0, f"FSDP2 worker failed:\nSTDOUT:\n{result.stdout}\nSTDERR:\n{result.stderr}"

    # The worker prints its summary on exactly one prefixed line; everything after the prefix is JSON.
    result_lines = []
    for line in result.stdout.splitlines():
        if line.startswith(_FSDP2_RESULT_PREFIX):
            result_lines.append(line)
    assert len(result_lines) == 1, f"expected exactly one result line, got {result_lines}\n{result.stdout}"
    payload = result_lines[0][len(_FSDP2_RESULT_PREFIX) :].strip()
    measured = json.loads(payload)

    # Training ran under FSDP2, the loss was finite, and the parameters moved.
    assert measured["steps"] >= 1, f"no training steps ran: {measured}"
    assert measured["train_loss_finite"], f"train loss not finite under FSDP2: {measured}"
    assert measured["params_changed"], f"parameters did not change under FSDP2: {measured}"
Loading