Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
f3401df
[PyTorch] Make tensorless quantizers opaque value objects for torch.c…
pggPL Jun 6, 2026
c4ad54c
[PyTorch] Drop quantizer value registry; reconstruct via __fx_repr__ …
pggPL Jun 6, 2026
a06324b
[PyTorch] Split dynamo.py into a dynamo/ package
pggPL Jun 7, 2026
ea5b396
[PyTorch] Raise in quantizer __fx_repr__ when a process group is stored
pggPL Jun 8, 2026
aa65e34
[PyTorch] Cover NVFP4 in quantizer value-object test
pggPL Jun 8, 2026
e1b1db6
Reject a value quantizer that carries an amax reduction group in __eq…
pggPL Jun 16, 2026
8c33d0e
Recognize value-opaque quantizers via a class flag
pggPL Jun 16, 2026
945f62d
Address review: narrow opaque-type except, add fullgraph test, fix nv…
pggPL Jun 29, 2026
e3c8f43
Restore NVFP4 rht_matrix on value-key rebuild; assert quantize round-…
pggPL Jun 29, 2026
3f68621
Enforce process-group rejection in _value_key, not __fx_repr__; add test
pggPL Jun 29, 2026
32d1768
Strengthen fullgraph test: quantize/dequantize via a custom op, not p…
pggPL Jun 29, 2026
28bde9e
Clarify comments: rht_matrix_random_sign_mask_t derivation; why the o…
pggPL Jun 29, 2026
2c3c5df
Reword opaque-flag comment: self-contained, no Linear reference
pggPL Jun 29, 2026
826f271
Cover is_opaque_value_type with the import-safety guard too
pggPL Jun 29, 2026
ad1ccce
Add TensorProto mechanism for data-free quantized tensor allocation
pggPL Jun 16, 2026
ea3df7a
[PyTorch] torch.compile: dedup cached FP8 weight from saved-for-backward
pggPL Jun 22, 2026
4997929
[PyTorch] nvfp4: emit _describe_buffers in canonical flatten order
pggPL Jun 22, 2026
50c11cd
Address review: error on undescribed buffers, gate nvfp4 test on HW s…
pggPL Jun 29, 2026
ff48e52
[PyTorch] Workaround torch.compile staticmethod guard bug in NVFP4 _d…
pggPL Jun 29, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
434 changes: 433 additions & 1 deletion tests/pytorch/test_torch_compile.py

Large diffs are not rendered by default.

15 changes: 15 additions & 0 deletions transformer_engine/pytorch/dynamo/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""torch.compile glue for Transformer Engine."""

from .quantizer_opaque import register_value_opaque_quantizer, is_value_opaque_quantizer
from .tensor_proto import TensorProto, to_tensor_proto

__all__ = [
"register_value_opaque_quantizer",
"is_value_opaque_quantizer",
"TensorProto",
"to_tensor_proto",
]
116 changes: 116 additions & 0 deletions transformer_engine/pytorch/dynamo/quantizer_opaque.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Value-opaque quantizers for torch.compile."""

from __future__ import annotations
from typing import Any, Dict, Tuple

from ..constants import DType


# Registration marks the class with this attribute rather than recording it in a
# module-level set. It looks odd but is a deliberate workaround: the check must
# stay traceable when it runs inside a torch.compile graph -- Dynamo can bake a
# ``getattr`` on the opaque quantizer into a constant, but cannot evaluate
# ``type(q) in some_set`` (no equality/hash rules for the opaque class object),
# which would graph-break under ``fullgraph=True``.
_VALUE_OPAQUE_FLAG = "_te_compile_value_opaque"


def is_value_opaque_quantizer(quantizer: Any) -> bool:
"""Whether *quantizer*'s class is registered as a torch.compile value-opaque
type."""
return getattr(quantizer, _VALUE_OPAQUE_FLAG, False)


def _rebuild_quantizer(cls: type, items: Tuple[Tuple[str, Any], ...]) -> Any:
"""Rebuild a tensorless quantizer of type *cls* from its value items.

Referenced by the ``__fx_repr__`` emitted for value-opaque quantizers; the
generated FX code calls this to materialize the quantizer constant.
"""
# Bypass ``__init__`` and restore the value attributes directly: the value
# items already capture every value-defining field (including derived ones),
# and the constructors have heterogeneous signatures / side effects.
obj = cls.__new__(cls)
field_names = set()
for name, value in items:
if name == "dtype":
value = DType.cast(value)
object.__setattr__(obj, name, value)
field_names.add(name)
# The deprecated amax-reduction group is not a value field; initialize it to
# None so attribute access keeps working on the rebuilt quantizer.
if "with_amax_reduction" in field_names and not hasattr(obj, "amax_reduction_group"):
object.__setattr__(obj, "amax_reduction_group", None)
# Restore non-value derived state that ``__init__`` would normally build but
# that cannot live in the value key (e.g. NVFP4's ``rht_matrix`` tensor).
finalize = getattr(obj, "_rebuild_derived_state", None)
if finalize is not None:
finalize()
return obj


def _quantizer_fx_repr(self: Any) -> Tuple[str, Dict[str, Any]]:
"""``__fx_repr__`` for value-opaque quantizers (attached at registration).

Returns an evaluable expression that rebuilds the quantizer via
:func:`_rebuild_quantizer`, capturing both the helper and the quantizer
class itself in the FX globals so codegen can resolve them with no global
registry and no qualname collisions.

Raises ``TypeError`` (via :meth:`Quantizer._value_key`) if the quantizer
stores a process group (e.g. a non-``None`` deprecated
``amax_reduction_group``): live distributed state must never be baked into
the graph as a constant. Pass the reduction group per quantize call instead
of storing it on the quantizer.
"""
cls = type(self)
items = self._value_key()[1]
return (
f"_rebuild_quantizer({cls.__name__}, {items!r})",
{"_rebuild_quantizer": _rebuild_quantizer, cls.__name__: cls},
)


def register_value_opaque_quantizer(cls: type) -> None:
"""Register a tensorless quantizer class as a torch.compile value opaque type.

Attaches ``__fx_repr__`` and registers the class with
``torch._library.opaque_object``. Safe to call on any PyTorch build: on
versions without the opaque-object API it only attaches ``__fx_repr__``
(harmless), so Transformer Engine keeps importing and running in eager mode.

The quantizer class must already provide value ``__eq__`` / ``__hash__`` and
a non-``None`` ``_value_fields`` (see
:class:`transformer_engine.pytorch.quantized_tensor.Quantizer`).
"""
# Stamp the class so it can be recognized as value-opaque in dynamo-traced
# code (used to fall back to eager for unregistered quantizers).
setattr(cls, _VALUE_OPAQUE_FLAG, True)

# ``register_opaque_type`` requires ``__fx_repr__`` to already exist on the
# class, so attach it before registering.
if "__fx_repr__" not in cls.__dict__:
cls.__fx_repr__ = _quantizer_fx_repr

try:
from torch._library.opaque_object import ( # pylint: disable=import-outside-toplevel
register_opaque_type,
is_opaque_value_type,
)
except (ImportError, AttributeError):
# Older PyTorch without the opaque-object API: eager value semantics
# still work; torch.compile specialization on the quantizer does not.
return

try:
if not is_opaque_value_type(cls):
register_opaque_type(cls, typ="value")
except (RuntimeError, TypeError):
# Keep TE importable: neither the opaque-type query nor the registration
# must crash the import, e.g. on PyTorch versions with only partial /
# experimental opaque-object support.
pass
177 changes: 177 additions & 0 deletions transformer_engine/pytorch/dynamo/tensor_proto.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""TensorProto: a data-free description of a tensor / quantized tensor."""

from __future__ import annotations
import copy as _copy
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple

import torch


def _contiguous_stride(shape: Tuple[int, ...]) -> Tuple[int, ...]:
Comment thread
pggPL marked this conversation as resolved.
"""Row-major (contiguous) stride for ``shape``."""
stride: list = []
acc = 1
for dim in reversed(shape):
stride.append(acc)
acc *= dim
return tuple(reversed(stride))


@dataclass
class TensorProto:
"""A data-free *prototype* of a tensor or quantized tensor.

Captures ``shape`` / ``dtype`` and, for quantized tensors, the
(value-opaque) ``quantizer`` -- enough to rebuild a tensor without holding
storage. The common abstraction over plain ``torch.Tensor``,
``QuantizedTensorStorage`` and ``QuantizedTensor``, used for custom-op fake
impls and for reassembling a quantized tensor from bare buffers.
"""

shape: Tuple[int, ...]
dtype: torch.dtype
quantizer: Optional[Any] = None
requires_grad: bool = False
device: Optional[torch.device] = field(default=None)

def __post_init__(self) -> None:
# Own a private copy of the quantizer so usage changes (update_usage)
# never touch the shared, value-opaque quantizer. The copy inherits the
# quantizer's current row-/column-wise usage as this proto's layout.
if self.quantizer is not None:
q = self.quantizer
self.quantizer = q.copy() if hasattr(q, "copy") else _copy.copy(q)

@property
def is_quantized(self) -> bool:
"""Whether this proto describes a quantized tensor."""
return self.quantizer is not None

def update_usage(
self,
*,
rowwise_usage: Optional[bool] = None,
columnwise_usage: Optional[bool] = None,
) -> None:
"""Mirror ``QuantizedTensor.update_usage`` on the proto's buffer layout.

Applied to the proto's own quantizer copy, so the shared (value-opaque)
quantizer is never mutated. No-op for plain (non-quantized) protos.
"""
if self.quantizer is None:
return
self.quantizer.set_usage(rowwise=rowwise_usage, columnwise=columnwise_usage)

def inner_names(self) -> Tuple[str, ...]:
"""Names of the flat tensor buffers backing this proto, in order.

The real op flattens a quantized output via the storage's
``__tensor_flatten__`` -- i.e. ``_FLATTEN_TENSOR_BUFFERS`` order, keeping
only the present buffers. ``_describe_buffers`` may emit the same buffers
in a different (per-usage) order (e.g. NVFP4 groups each amax right after
its scale), so reorder to the canonical flatten order here to keep the
fake layout aligned with the real one slot-for-slot.
"""
if self.quantizer is None:
return ("data",)
# pylint: disable=protected-access
described = list(self.quantizer._describe_buffers(tuple(self.shape)).keys())
storage_cls = self.quantizer._storage_metadata(self.dtype)["cls"]
flatten_order = [attr for attr, _ in storage_cls._FLATTEN_TENSOR_BUFFERS]
extra = [name for name in described if name not in flatten_order]
if extra:
raise RuntimeError(
f"{storage_cls.__name__} describes buffer(s) {extra} absent from its "
f"_FLATTEN_TENSOR_BUFFERS {flatten_order}; the fake layout cannot be "
"aligned with the real one slot-for-slot."
)
return tuple(name for name in flatten_order if name in described)

def create_metadata(self) -> Dict[str, Any]:
"""Data-free ``__tensor_unflatten__`` context describing this tensor."""
if self.quantizer is None:
return {
"is_tensor": True,
"is_quantized": False,
"dtype": self.dtype,
"requires_grad": self.requires_grad,
}
return self.quantizer.create_metadata(
tuple(self.shape), dtype=self.dtype, requires_grad=self.requires_grad
)

def create_inner_tensors(self) -> List[torch.Tensor]:
"""Materialize the flat inner buffers (in :meth:`inner_names` order).

Under ``register_fake`` the ``torch.empty`` calls produce ``FakeTensor``s;
``requires_grad`` is left default (managed by ``register_autograd``).
"""
device = self.device if self.device is not None else torch.device("cuda")
if self.quantizer is None:
return [torch.empty(tuple(self.shape), dtype=self.dtype, device=device)]
inner = self.quantizer.alloc_tensors(tuple(self.shape), device=device)
return [inner[name] for name in self.inner_names()]

def create_tensor(self) -> torch.Tensor:
"""Materialize an (uninitialized) tensor matching this proto (traceable).

Quantized protos reassemble the :meth:`create_inner_tensors` buffers via
the storage's ``__tensor_unflatten__``.
"""
if self.quantizer is None:
device = self.device if self.device is not None else torch.device("cuda")
return torch.empty(
tuple(self.shape),
dtype=self.dtype,
device=device,
requires_grad=self.requires_grad,
)
from ..quantized_tensor import ( # pylint: disable=import-outside-toplevel
_STORAGE_REGISTRY,
)

shape = tuple(self.shape)
ctx = self.create_metadata()
inner = dict(zip(self.inner_names(), self.create_inner_tensors()))
storage_cls = _STORAGE_REGISTRY[ctx["cls"]]
return storage_cls.__tensor_unflatten__(inner, ctx, shape, _contiguous_stride(shape))


def to_tensor_proto(tensor: Any) -> TensorProto:
"""Build a :class:`TensorProto` describing ``tensor``.

Works for plain ``torch.Tensor`` and for ``QuantizedTensorStorage`` /
``QuantizedTensor``. A *bare* storage exposes its shape via ``.size()`` and
its (fake) dtype via ``_dtype`` rather than ``.shape`` / ``.dtype``.
"""
from ..quantized_tensor import ( # pylint: disable=import-outside-toplevel
QuantizedTensorStorage,
)

requires_grad = bool(getattr(tensor, "requires_grad", False))
if isinstance(tensor, QuantizedTensorStorage):
shape = getattr(tensor, "shape", None)
if shape is None:
shape = tensor.size()
dtype = getattr(tensor, "dtype", None)
if dtype is None:
dtype = getattr(tensor, "_dtype", None)
return TensorProto(
shape=tuple(shape),
dtype=dtype,
quantizer=getattr(tensor, "_quantizer", None),
requires_grad=requires_grad,
device=tensor.device,
)
return TensorProto(
shape=tuple(tensor.shape),
dtype=tensor.dtype,
quantizer=None,
requires_grad=requires_grad,
device=tensor.device,
)
Loading
Loading