diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py index 7e52f700b49..fbf5a6bf8a3 100644 --- a/tests/test_sft_trainer.py +++ b/tests/test_sft_trainer.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import collections import copy import gc import json @@ -950,7 +951,7 @@ def mock_super_compute_loss(model, inputs, return_outputs=False, num_items_in_ba captured["skip_logits"] = inputs.get("skip_logits") dummy_loss = torch.tensor(1.0, requires_grad=True) dummy_outputs = MagicMock() - dummy_outputs.token_accuracy = None + dummy_outputs.token_accuracy = torch.tensor(0.5) dummy_outputs.logits = torch.randn(1, 5, trainer.model.config.vocab_size) return (dummy_loss, dummy_outputs) @@ -994,7 +995,12 @@ def test_predict_does_not_skip_logits_with_liger(self): def mock_super_compute_loss(model, inputs, return_outputs=False, num_items_in_batch=None): captured["skip_logits"] = inputs.get("skip_logits") dummy_loss = torch.tensor(1.0, requires_grad=True) - dummy_outputs = (dummy_loss, torch.randn(1, 5, trainer.model.config.vocab_size)) + DummyOutput = collections.namedtuple("DummyOutput", ["loss", "logits", "token_accuracy"]) + dummy_outputs = DummyOutput( + loss=dummy_loss, + logits=torch.randn(1, 5, trainer.model.config.vocab_size), + token_accuracy=torch.tensor(0.5), + ) return (dummy_loss, dummy_outputs) with patch("transformers.Trainer.compute_loss", side_effect=mock_super_compute_loss):