Fix spurious liger token_accuracy CI warnings in SFT tests#6189
Open
albertvillanova wants to merge 1 commit into
Open
Fix spurious liger token_accuracy CI warnings in SFT tests#6189albertvillanova wants to merge 1 commit into
albertvillanova wants to merge 1 commit into
Conversation
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
This PR fixes two test mocks in test_sft_trainer.py that were emitting spurious UserWarnings about liger-kernel not returning token_accuracy.
Motivation
Both test_compute_loss_skip_logits_on_eval_without_metrics_with_liger and test_predict_does_not_skip_logits_with_liger are tests about skip_logits behaviour. Their mocks were accidentally simulating a broken liger call (returning None or no token_accuracy), which caused the SFT trainer's liger branch to emit a warning intended for genuine liger failures. The warning is correct by design; the mocks were not.
Real liger always returns token_accuracy when return_token_accuracy=True is set, even when skip_logits=True, because token accuracy is computed inside the fused kernel without materialising the full logits tensor.
Changes
Note
Low Risk
Test-only mock fixes; no changes to SFT trainer or runtime behavior.
Overview
Updates two Liger +
skip_logitstests intest_sft_trainer.pyso their patchedcompute_lossmocks behave like real liger-kernel (which always suppliestoken_accuracywhenreturn_token_accuracy=True), instead of returningNoneor outputs without that field.In
test_compute_loss_skip_logits_on_eval_without_metrics_with_liger, the mockMagicMocknow setstoken_accuracyto a tensor instead ofNone. Intest_predict_does_not_skip_logits_with_liger, the mock return value is acollections.namedtuplewithloss,logits, andtoken_accuracyso it stays tuple-indexable forprediction_stepwhile exposing the attribute the trainer reads.import collectionsis added for the namedtuple.This removes spurious CI **
UserWarning**s from the intentional liger failure path inSFTTrainer.compute_loss; production trainer logic is unchanged.Reviewed by Cursor Bugbot for commit 9aeb466. Bugbot is set up for automated code reviews on this repo. Configure here.