Align KTO collator keys with the DPO convention #6182

Open
qgallouedec wants to merge 3 commits into main from align-kto-collator-keys

qgallouedec (Member) commented Jun 25, 2026

KTO's collator named the assembled prompt+completion model input completion_input_ids / completion_attention_mask, which is misleading (it's not completion-only), and the trainer renamed it straight back to input_ids before every forward. This aligns the keys with DPO: the policy branch is unprefixed, and a KL_ prefix marks only the KL branch.

  • completion_input_ids → input_ids, completion_attention_mask → attention_mask
  • KL_completion_input_ids → KL_input_ids, KL_completion_attention_mask → KL_attention_mask
  • KL_completion_token_type_ids → KL_token_type_ids, KL_completion_mm_token_type_ids → KL_mm_token_type_ids
  • completion_mask / KL_completion_mask unchanged (they correctly mark completion positions); dataset columns (prompt_ids, completion_ids, KL_completion_ids) unchanged.
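
To make the naming concrete, here is a minimal sketch of how the renamed keys relate to the dataset columns. It is a toy, unpadded, pure-Python illustration and not the actual collator: the helper names assemble_policy_fields and assemble_kl_fields are hypothetical, and it assumes the KL branch is assembled from prompt_ids plus KL_completion_ids in the same way as the policy branch.

```python
def assemble_policy_fields(prompt_ids, completion_ids):
    """Toy assembly of the policy-branch fields for one unpadded example.

    Hypothetical helper for illustration only; the real collator also pads
    and batches, which is omitted here.
    """
    input_ids = list(prompt_ids) + list(completion_ids)
    return {
        # Full prompt+completion sequence, hence the unprefixed name.
        "input_ids": input_ids,
        "attention_mask": [1] * len(input_ids),
        # Marks completion positions only, so the old name still fits.
        "completion_mask": [0] * len(prompt_ids) + [1] * len(completion_ids),
    }


def assemble_kl_fields(prompt_ids, kl_completion_ids):
    """Same assembly for the KL branch, with every key prefixed by KL_."""
    fields = assemble_policy_fields(prompt_ids, kl_completion_ids)
    return {f"KL_{name}": value for name, value in fields.items()}


# Dataset columns keep their existing names.
example = {
    "prompt_ids": [11, 12, 13],
    "completion_ids": [21, 22],
    "KL_completion_ids": [31],
}

batch_row = {
    **assemble_policy_fields(example["prompt_ids"], example["completion_ids"]),
    **assemble_kl_fields(example["prompt_ids"], example["KL_completion_ids"]),
}
print(sorted(batch_row))
```

In this sketch input_ids spans all five tokens while completion_mask is nonzero only on the last two, which is why the completion_ prefix was dropped from the former and kept on the latter.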

Removes the now-dead model_kwargs.pop("completion_input_ids") renames in the forward and the redundant output.pop("input_ids") in the vision collator (now identical to the DPO/SFT vision collators).
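
A rough before/after sketch of the removed rename, assuming the shim had the pop-and-reassign shape quoted above. toy_model, forward_before, and forward_after are hypothetical stand-ins and do not reproduce the trainer's code.

```python
def toy_model(input_ids, attention_mask):
    """Stand-in for model(**kwargs): returns the attended token ids."""
    return [tok for tok, keep in zip(input_ids, attention_mask) if keep]


def forward_before(batch):
    """Old shape (assumed): rename collator keys right before the call."""
    model_kwargs = dict(batch)
    model_kwargs["input_ids"] = model_kwargs.pop("completion_input_ids")
    model_kwargs["attention_mask"] = model_kwargs.pop("completion_attention_mask")
    return toy_model(**model_kwargs)


def forward_after(batch):
    """New shape: the collator already emits the names the model expects."""
    return toy_model(**batch)


old_batch = {"completion_input_ids": [1, 2, 3, 0], "completion_attention_mask": [1, 1, 1, 0]}
new_batch = {"input_ids": [1, 2, 3, 0], "attention_mask": [1, 1, 1, 0]}

print(forward_before(old_batch) == forward_after(new_batch))
```

Both paths feed the model the same tensors; the change only removes the intermediate rename.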


Note

Medium Risk
Renames every batch key on the KTO training path (text, VLM, Liger, ref log-prob precompute); behavior should be equivalent but any external code or forks expecting completion_input_ids will break.
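
For external code or forks that still produce or consume the old names, a small compatibility shim along these lines could bridge the gap. This is only a sketch: OLD_TO_NEW_KEYS and migrate_kto_batch_keys are hypothetical names that are not part of this PR, and the mapping covers only the renames listed in the description.

```python
# Old key -> new key, taken from the rename list in the PR description.
OLD_TO_NEW_KEYS = {
    "completion_input_ids": "input_ids",
    "completion_attention_mask": "attention_mask",
    "KL_completion_input_ids": "KL_input_ids",
    "KL_completion_attention_mask": "KL_attention_mask",
    "KL_completion_token_type_ids": "KL_token_type_ids",
    "KL_completion_mm_token_type_ids": "KL_mm_token_type_ids",
}


def migrate_kto_batch_keys(batch):
    """Return a copy of batch with old KTO keys renamed to the new ones.

    Hypothetical helper. Keys not in the mapping (completion_mask,
    KL_completion_mask, and anything else) pass through unchanged.
    """
    migrated = {}
    for key, value in batch.items():
        new_key = OLD_TO_NEW_KEYS.get(key, key)
        if new_key in migrated:
            # Both the old and the new name were present; refuse to guess.
            raise ValueError(f"conflicting entries for key {new_key!r}")
        migrated[new_key] = value
    return migrated


legacy_batch = {
    "completion_input_ids": [1, 2, 3],
    "completion_attention_mask": [1, 1, 1],
    "completion_mask": [0, 1, 1],
    "KL_completion_input_ids": [1, 9],
}
print(migrate_kto_batch_keys(legacy_batch))
```

Raising on a collision rather than silently overwriting is a deliberate choice here, since a batch carrying both names most likely indicates a partially migrated pipeline.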

Overview
Renames KTO collator and trainer batch tensor keys so the policy forward path matches DPO/SFT: full prompt+completion sequences are exposed as input_ids / attention_mask instead of completion_input_ids / completion_attention_mask, with KL branches under KL_input_ids, KL_attention_mask, and (for VLMs) KL_token_type_ids / KL_mm_token_type_ids.

completion_mask and KL_completion_mask stay as-is; dataset columns (prompt_ids, completion_ids, KL_completion_ids) are unchanged.

The trainer no longer renames collator keys before model(**kwargs); the pop("completion_input_ids")-style shims are removed. The vision collator writes assembled tensors directly into the output dict instead of dropping the processor's input_ids and aliasing them under the old names. Tests are updated to assert the new keys.

Reviewed by Cursor Bugbot for commit 1e16213.
