From 6a9f32a0ae0d373cf303666c15008a71d94da688 Mon Sep 17 00:00:00 2001 From: Vineeth Sai Date: Mon, 22 Jun 2026 20:53:40 -0700 Subject: [PATCH] Fix Online DPO vLLM server re-flattening completion_ids OnlineDPOTrainer._generate_vllm_server re-flattened the completion_ids returned by the vLLM client, which is already a list[list[int]] with one token-id list per completion (the same shape the colocate path and GRPO produce). The comprehension iterated over every completion and every token, turning each token into its own single-token completion, which corrupts the completion mask and the per-process row count. Remove the re-flatten so completion_ids is passed through unchanged. Added a CPU regression test that mocks the vLLM client and the distributed gather/broadcast and asserts each completion is preserved. Fixes #5514 --- tests/experimental/test_online_dpo_trainer.py | 37 +++++++++++++++++++ .../online_dpo/online_dpo_trainer.py | 2 - 2 files changed, 37 insertions(+), 2 deletions(-) diff --git a/tests/experimental/test_online_dpo_trainer.py b/tests/experimental/test_online_dpo_trainer.py index 355d20f2722..f9bcdb04f3a 100644 --- a/tests/experimental/test_online_dpo_trainer.py +++ b/tests/experimental/test_online_dpo_trainer.py @@ -12,7 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +from types import SimpleNamespace + import pytest +import torch from datasets import Dataset, features, load_dataset from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer from transformers.utils import is_peft_available, is_vision_available @@ -407,6 +410,40 @@ def simple_reward_func(prompts, completions, completion_ids, **kwargs): assert round(abs(trainer.reward_weights[0].item() - 0.7), 5) == 0 assert round(abs(trainer.reward_weights[1].item() - 0.3), 5) == 0 + def test_generate_vllm_server_preserves_completion_token_lists(self, monkeypatch): + # Regression test for #5514. VLLMClient.generate returns completion_ids as a list[list[int]] with one + # token-id list per completion (the same shape the colocate path produces), so _generate_vllm_server must + # pass each completion through unchanged. Previously it re-flattened the list, turning every token into its + # own single-token "completion". + monkeypatch.setattr("trl.experimental.online_dpo.online_dpo_trainer.gather_object", lambda x: x) + monkeypatch.setattr( + "trl.experimental.online_dpo.online_dpo_trainer.broadcast_object_list", + lambda x, from_process=0: x, + ) + + server_completion_ids = [[101, 102, 103], [201, 202]] + + trainer = OnlineDPOTrainer.__new__(OnlineDPOTrainer) + trainer.accelerator = SimpleNamespace(is_main_process=True, process_index=0) + trainer.state = SimpleNamespace(global_step=0) + trainer._last_loaded_step = 0 # equal to global_step, so no vLLM weight reload is triggered + trainer.num_generations = 2 + trainer.repetition_penalty = 1.0 + trainer.temperature = 1.0 + trainer.top_p = 1.0 + trainer.top_k = None + trainer.min_p = None + trainer.generation_config = SimpleNamespace(max_tokens=16) + trainer.args = SimpleNamespace(generation_kwargs={}) + trainer.processing_class = lambda **kwargs: {"input_ids": torch.tensor([[5, 6, 7]])} + trainer.vllm_client = SimpleNamespace(generate=lambda **kwargs: {"completion_ids": server_completion_ids}) + + completion_ids, _ = trainer._generate_vllm_server(["a plain prompt"]) + + # Each completion keeps its full token-id list instead of being exploded into one single-token entry per + # token (which was the bug). + assert completion_ids == server_completion_ids + @require_vision class TestOnlineDPOVisionTrainer(TrlTestCase): diff --git a/trl/experimental/online_dpo/online_dpo_trainer.py b/trl/experimental/online_dpo/online_dpo_trainer.py index f6da8910935..e220779a227 100644 --- a/trl/experimental/online_dpo/online_dpo_trainer.py +++ b/trl/experimental/online_dpo/online_dpo_trainer.py @@ -673,8 +673,6 @@ def _generate_vllm_server(self, prompts, images=None): else None, generation_kwargs=self.args.generation_kwargs, )["completion_ids"] - # Flatten: each prompt generates 2 completions - completion_ids = [[comp_id] for prompt_completions in completion_ids for comp_id in prompt_completions] else: completion_ids = [None] * (len(all_prompts) * 2)