Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions src/bentoml/_internal/frameworks/common/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,14 @@ def _run(self: PytorchModelRunnable, *args: t.Any, **kwargs: t.Any) -> torch.Ten
params = Params(*args, **kwargs)

def _mapping(item: T) -> torch.Tensor | T:
# ``torch.Tensor(...)`` (the type constructor) ignores the input
# dtype and produces ``float32``. Use ``torch.from_numpy`` instead
# so an ``np.float16``/``np.int64``/``np.bool_`` array is handed to
# the model with a matching torch dtype (see #4266).
if LazyType["ext.NpNDArray"]("numpy.ndarray").isinstance(item):
return torch.Tensor(item, device=self.device_id)
return torch.from_numpy(item).to(self.device_id)
if LazyType["ext.PdDataFrame"]("pandas.DataFrame").isinstance(item):
return torch.Tensor(item.to_numpy(), device=self.device_id)
return torch.from_numpy(item.to_numpy()).to(self.device_id)
if LazyType["torch.Tensor"]("torch.Tensor").isinstance(item):
return item.to(self.device_id)
else:
Expand Down
6 changes: 5 additions & 1 deletion src/bentoml/_internal/frameworks/detectron.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,8 +304,12 @@ def _run(self: Detectron2Runnable, *args: t.Any, **kwargs: t.Any) -> t.Any:
)

def mapping(item: ext.NpNDArray | torch.Tensor) -> t.Any:
# Mirror the fix from common/pytorch.py: ``torch.Tensor(arr)``
# silently upcasts every input to ``float32``, breaking models
# whose weights are ``float16``/``int64``/``bool``. Use
# ``torch.from_numpy`` to preserve the numpy dtype (#4266).
if LazyType["ext.NpNDArray"]("numpy.ndarray").isinstance(item):
return torch.Tensor(item, device=self.device_id)
return torch.from_numpy(item).to(self.device_id)
elif isinstance(item, torch.Tensor):
return item.to(self.device_id)
else:
Expand Down
66 changes: 66 additions & 0 deletions tests/integration/frameworks/test_pytorch_unit.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from __future__ import annotations

import numpy as np
import pytest
import torch

from bentoml._internal.frameworks.common.pytorch import make_pytorch_runnable_method
from bentoml._internal.frameworks.pytorch import PyTorchTensorContainer
from bentoml._internal.runner.container import AutoContainer

Expand Down Expand Up @@ -39,3 +41,67 @@ def test_pytorch_container(batch_axis: int):
AutoContainer.from_payload(AutoContainer.to_payload(one_batch, batch_dim=0))
== one_batch
).all()


class _IdentityModel(torch.nn.Module):
"""Minimal nn.Module that returns its first input — used to inspect the
tensor the runnable hands to the model after numpy/pandas conversion."""

def forward(self, x: torch.Tensor) -> torch.Tensor:
return x


class _FakeRunnable:
"""A stand-in for PytorchModelRunnable that lets us drive
make_pytorch_runnable_method without a saved bento model."""

def __init__(self) -> None:
self.device_id = "cpu"
self.model = _IdentityModel()


@pytest.mark.parametrize(
"np_dtype, expected_torch_dtype",
[
(np.float16, torch.float16),
(np.float32, torch.float32),
(np.float64, torch.float64),
(np.int8, torch.int8),
(np.int32, torch.int32),
(np.int64, torch.int64),
(np.uint8, torch.uint8),
(np.bool_, torch.bool),
],
)
def test_pytorch_runnable_method_preserves_numpy_dtype(
np_dtype: np.dtype, expected_torch_dtype: torch.dtype
) -> None:
"""Regression for #4266: the numpy -> torch conversion inside the pytorch
runnable used ``torch.Tensor(arr)``, which silently upcast every input to
``float32`` and broke models whose weights were ``float16``, ``int64``,
etc."""
runnable = _FakeRunnable()
method = make_pytorch_runnable_method("forward")
arr = (
np.array([True, False, True], dtype=np_dtype)
if np_dtype is np.bool_
else np.array([1, 2, 3], dtype=np_dtype)
)

result = method(runnable, arr)

assert isinstance(result, torch.Tensor)
assert result.dtype == expected_torch_dtype


def test_pytorch_runnable_method_preserves_pandas_dtype() -> None:
"""Same dtype-preservation contract for pandas DataFrame inputs."""
pd = pytest.importorskip("pandas")
runnable = _FakeRunnable()
method = make_pytorch_runnable_method("forward")
df = pd.DataFrame({"a": np.array([1, 2, 3], dtype=np.int64)})

result = method(runnable, df)

assert isinstance(result, torch.Tensor)
assert result.dtype == torch.int64