Skip to content

Fix spurious liger token_accuracy CI warnings in SFT tests#6189

Open
albertvillanova wants to merge 1 commit into
mainfrom
fix-warnings-liger-token_accuracy
Open

Fix spurious liger token_accuracy CI warnings in SFT tests#6189
albertvillanova wants to merge 1 commit into
mainfrom
fix-warnings-liger-token_accuracy

Conversation

@albertvillanova

@albertvillanova albertvillanova commented Jun 26, 2026

Copy link
Copy Markdown
Member

This PR fixes two test mocks in test_sft_trainer.py that were emitting spurious UserWarnings about liger-kernel not returning token_accuracy.

Motivation

Both test_compute_loss_skip_logits_on_eval_without_metrics_with_liger and test_predict_does_not_skip_logits_with_liger are tests about skip_logits behaviour. Their mocks were accidentally simulating a broken liger call (returning None or no token_accuracy), which caused the SFT trainer's liger branch to emit a warning intended for genuine liger failures. The warning is correct by design; the mocks were not.

Real liger always returns token_accuracy when return_token_accuracy=True is set, even when skip_logits=True, because token accuracy is computed inside the fused kernel without materialising the full logits tensor.

Changes

  • In test_compute_loss_skip_logits_on_eval_without_metrics_with_liger: set dummy_outputs.token_accuracy = torch.tensor(0.5) instead of None
  • In test_predict_does_not_skip_logits_with_liger: replace the bare tuple with a namedtuple that exposes a token_accuracy field while remaining tuple-indexable (required by transformers' prediction_step which slices outputs as outputs[1:])

Note

Low Risk
Test-only mock fixes; no changes to SFT trainer or runtime behavior.

Overview
Updates two Liger + skip_logits tests in test_sft_trainer.py so their patched compute_loss mocks behave like real liger-kernel (which always supplies token_accuracy when return_token_accuracy=True), instead of returning None or outputs without that field.

In test_compute_loss_skip_logits_on_eval_without_metrics_with_liger, the mock MagicMock now sets token_accuracy to a tensor instead of None. In test_predict_does_not_skip_logits_with_liger, the mock return value is a collections.namedtuple with loss, logits, and token_accuracy so it stays tuple-indexable for prediction_step while exposing the attribute the trainer reads. import collections is added for the namedtuple.

This removes spurious CI **UserWarning**s from the intentional liger failure path in SFTTrainer.compute_loss; production trainer logic is unchanged.

Reviewed by Cursor Bugbot for commit 9aeb466. Bugbot is set up for automated code reviews on this repo. Configure here.

@bot-ci-comment

Copy link
Copy Markdown

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant