Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions tests/experimental/test_online_dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from types import SimpleNamespace

import pytest
import torch
from datasets import Dataset, features, load_dataset
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer
from transformers.utils import is_peft_available, is_vision_available
Expand Down Expand Up @@ -407,6 +410,40 @@ def simple_reward_func(prompts, completions, completion_ids, **kwargs):
assert round(abs(trainer.reward_weights[0].item() - 0.7), 5) == 0
assert round(abs(trainer.reward_weights[1].item() - 0.3), 5) == 0

def test_generate_vllm_server_preserves_completion_token_lists(self, monkeypatch):
    """Regression test for #5514: server-mode generation must keep one token-id list per completion.

    VLLMClient.generate returns completion_ids as a list[list[int]] with one token-id list per completion (the
    same shape the colocate path produces), so _generate_vllm_server has to hand each completion back
    unchanged. Previously it re-flattened the list, turning every token into its own single-token
    "completion".
    """

    # Single-process stand-ins for the distributed helpers: both simply return their input.
    def _gather_passthrough(x):
        return x

    def _broadcast_passthrough(x, from_process=0):
        return x

    monkeypatch.setattr(
        "trl.experimental.online_dpo.online_dpo_trainer.gather_object",
        _gather_passthrough,
    )
    monkeypatch.setattr(
        "trl.experimental.online_dpo.online_dpo_trainer.broadcast_object_list",
        _broadcast_passthrough,
    )

    # Two completions of different lengths, as the stubbed server would return them.
    expected_ids = [[101, 102, 103], [201, 202]]

    def _stub_processing_class(**kwargs):
        return {"input_ids": torch.tensor([[5, 6, 7]])}

    def _stub_generate(**kwargs):
        return {"completion_ids": expected_ids}

    # Build the trainer without running __init__, then attach only the attributes this test stubs out.
    trainer = OnlineDPOTrainer.__new__(OnlineDPOTrainer)
    stubbed_attributes = {
        "accelerator": SimpleNamespace(is_main_process=True, process_index=0),
        "state": SimpleNamespace(global_step=0),
        "_last_loaded_step": 0,  # equal to global_step, so no vLLM weight reload is triggered
        "num_generations": 2,
        "repetition_penalty": 1.0,
        "temperature": 1.0,
        "top_p": 1.0,
        "top_k": None,
        "min_p": None,
        "generation_config": SimpleNamespace(max_tokens=16),
        "args": SimpleNamespace(generation_kwargs={}),
        "processing_class": _stub_processing_class,
        "vllm_client": SimpleNamespace(generate=_stub_generate),
    }
    for attribute_name, attribute_value in stubbed_attributes.items():
        setattr(trainer, attribute_name, attribute_value)

    completion_ids, _ = trainer._generate_vllm_server(["a plain prompt"])

    # The bug exploded each completion into one single-token entry per token; every completion must instead
    # come back with its full token-id list intact.
    assert completion_ids == expected_ids


@require_vision
class TestOnlineDPOVisionTrainer(TrlTestCase):
Expand Down
2 changes: 0 additions & 2 deletions trl/experimental/online_dpo/online_dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -673,8 +673,6 @@ def _generate_vllm_server(self, prompts, images=None):
else None,
generation_kwargs=self.args.generation_kwargs,
)["completion_ids"]
# Flatten: each prompt generates 2 completions
completion_ids = [[comp_id] for prompt_completions in completion_ids for comp_id in prompt_completions]
else:
completion_ids = [None] * (len(all_prompts) * 2)

Expand Down
Loading