-
Notifications
You must be signed in to change notification settings - Fork 759
[PyTorch][torch.compile] Add TensorProto mechanism #3153
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
pggPL
wants to merge
19
commits into
NVIDIA:main
Choose a base branch
from
pggPL:tensor_proto_mechanism
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
19 commits
Select commit
Hold shift + click to select a range
f3401df
[PyTorch] Make tensorless quantizers opaque value objects for torch.c…
pggPL c4ad54c
[PyTorch] Drop quantizer value registry; reconstruct via __fx_repr__ …
pggPL a06324b
[PyTorch] Split dynamo.py into a dynamo/ package
pggPL ea5b396
[PyTorch] Raise in quantizer __fx_repr__ when a process group is stored
pggPL aa65e34
[PyTorch] Cover NVFP4 in quantizer value-object test
pggPL e1b1db6
Reject a value quantizer that carries an amax reduction group in __eq…
pggPL 8c33d0e
Recognize value-opaque quantizers via a class flag
pggPL 945f62d
Address review: narrow opaque-type except, add fullgraph test, fix nv…
pggPL e3c8f43
Restore NVFP4 rht_matrix on value-key rebuild; assert quantize round-…
pggPL 3f68621
Enforce process-group rejection in _value_key, not __fx_repr__; add test
pggPL 32d1768
Strengthen fullgraph test: quantize/dequantize via a custom op, not p…
pggPL 28bde9e
Clarify comments: rht_matrix_random_sign_mask_t derivation; why the o…
pggPL 2c3c5df
Reword opaque-flag comment: self-contained, no Linear reference
pggPL 826f271
Cover is_opaque_value_type with the import-safety guard too
pggPL ad1ccce
Add TensorProto mechanism for data-free quantized tensor allocation
pggPL ea3df7a
[PyTorch] torch.compile: dedup cached FP8 weight from saved-for-backward
pggPL 4997929
[PyTorch] nvfp4: emit _describe_buffers in canonical flatten order
pggPL 50c11cd
Address review: error on undescribed buffers, gate nvfp4 test on HW s…
pggPL ff48e52
[PyTorch] Workaround torch.compile staticmethod guard bug in NVFP4 _d…
pggPL File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,15 @@ | ||
| # Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # | ||
| # See LICENSE for license information. | ||
|
|
||
| """torch.compile glue for Transformer Engine.""" | ||
|
|
||
| from .quantizer_opaque import register_value_opaque_quantizer, is_value_opaque_quantizer | ||
| from .tensor_proto import TensorProto, to_tensor_proto | ||
|
|
||
| __all__ = [ | ||
| "register_value_opaque_quantizer", | ||
| "is_value_opaque_quantizer", | ||
| "TensorProto", | ||
| "to_tensor_proto", | ||
| ] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,116 @@ | ||
| # Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # | ||
| # See LICENSE for license information. | ||
|
|
||
| """Value-opaque quantizers for torch.compile.""" | ||
|
|
||
| from __future__ import annotations | ||
| from typing import Any, Dict, Tuple | ||
|
|
||
| from ..constants import DType | ||
|
|
||
|
|
||
| # Registration marks the class with this attribute rather than recording it in a | ||
| # module-level set. It looks odd but is a deliberate workaround: the check must | ||
| # stay traceable when it runs inside a torch.compile graph -- Dynamo can bake a | ||
| # ``getattr`` on the opaque quantizer into a constant, but cannot evaluate | ||
| # ``type(q) in some_set`` (no equality/hash rules for the opaque class object), | ||
| # which would graph-break under ``fullgraph=True``. | ||
| _VALUE_OPAQUE_FLAG = "_te_compile_value_opaque" | ||
|
|
||
|
|
||
| def is_value_opaque_quantizer(quantizer: Any) -> bool: | ||
| """Whether *quantizer*'s class is registered as a torch.compile value-opaque | ||
| type.""" | ||
| return getattr(quantizer, _VALUE_OPAQUE_FLAG, False) | ||
|
|
||
|
|
||
| def _rebuild_quantizer(cls: type, items: Tuple[Tuple[str, Any], ...]) -> Any: | ||
| """Rebuild a tensorless quantizer of type *cls* from its value items. | ||
|
|
||
| Referenced by the ``__fx_repr__`` emitted for value-opaque quantizers; the | ||
| generated FX code calls this to materialize the quantizer constant. | ||
| """ | ||
| # Bypass ``__init__`` and restore the value attributes directly: the value | ||
| # items already capture every value-defining field (including derived ones), | ||
| # and the constructors have heterogeneous signatures / side effects. | ||
| obj = cls.__new__(cls) | ||
| field_names = set() | ||
| for name, value in items: | ||
| if name == "dtype": | ||
| value = DType.cast(value) | ||
| object.__setattr__(obj, name, value) | ||
| field_names.add(name) | ||
| # The deprecated amax-reduction group is not a value field; initialize it to | ||
| # None so attribute access keeps working on the rebuilt quantizer. | ||
| if "with_amax_reduction" in field_names and not hasattr(obj, "amax_reduction_group"): | ||
| object.__setattr__(obj, "amax_reduction_group", None) | ||
| # Restore non-value derived state that ``__init__`` would normally build but | ||
| # that cannot live in the value key (e.g. NVFP4's ``rht_matrix`` tensor). | ||
| finalize = getattr(obj, "_rebuild_derived_state", None) | ||
| if finalize is not None: | ||
| finalize() | ||
| return obj | ||
|
|
||
|
|
||
| def _quantizer_fx_repr(self: Any) -> Tuple[str, Dict[str, Any]]: | ||
| """``__fx_repr__`` for value-opaque quantizers (attached at registration). | ||
|
|
||
| Returns an evaluable expression that rebuilds the quantizer via | ||
| :func:`_rebuild_quantizer`, capturing both the helper and the quantizer | ||
| class itself in the FX globals so codegen can resolve them with no global | ||
| registry and no qualname collisions. | ||
|
|
||
| Raises ``TypeError`` (via :meth:`Quantizer._value_key`) if the quantizer | ||
| stores a process group (e.g. a non-``None`` deprecated | ||
| ``amax_reduction_group``): live distributed state must never be baked into | ||
| the graph as a constant. Pass the reduction group per quantize call instead | ||
| of storing it on the quantizer. | ||
| """ | ||
| cls = type(self) | ||
| items = self._value_key()[1] | ||
| return ( | ||
| f"_rebuild_quantizer({cls.__name__}, {items!r})", | ||
| {"_rebuild_quantizer": _rebuild_quantizer, cls.__name__: cls}, | ||
| ) | ||
|
|
||
|
|
||
| def register_value_opaque_quantizer(cls: type) -> None: | ||
| """Register a tensorless quantizer class as a torch.compile value opaque type. | ||
|
|
||
| Attaches ``__fx_repr__`` and registers the class with | ||
| ``torch._library.opaque_object``. Safe to call on any PyTorch build: on | ||
| versions without the opaque-object API it only attaches ``__fx_repr__`` | ||
| (harmless), so Transformer Engine keeps importing and running in eager mode. | ||
|
|
||
| The quantizer class must already provide value ``__eq__`` / ``__hash__`` and | ||
| a non-``None`` ``_value_fields`` (see | ||
| :class:`transformer_engine.pytorch.quantized_tensor.Quantizer`). | ||
| """ | ||
| # Stamp the class so it can be recognized as value-opaque in dynamo-traced | ||
| # code (used to fall back to eager for unregistered quantizers). | ||
| setattr(cls, _VALUE_OPAQUE_FLAG, True) | ||
|
|
||
| # ``register_opaque_type`` requires ``__fx_repr__`` to already exist on the | ||
| # class, so attach it before registering. | ||
| if "__fx_repr__" not in cls.__dict__: | ||
| cls.__fx_repr__ = _quantizer_fx_repr | ||
|
|
||
| try: | ||
| from torch._library.opaque_object import ( # pylint: disable=import-outside-toplevel | ||
| register_opaque_type, | ||
| is_opaque_value_type, | ||
| ) | ||
| except (ImportError, AttributeError): | ||
| # Older PyTorch without the opaque-object API: eager value semantics | ||
| # still work; torch.compile specialization on the quantizer does not. | ||
| return | ||
|
|
||
| try: | ||
| if not is_opaque_value_type(cls): | ||
| register_opaque_type(cls, typ="value") | ||
| except (RuntimeError, TypeError): | ||
| # Keep TE importable: neither the opaque-type query nor the registration | ||
| # must crash the import, e.g. on PyTorch versions with only partial / | ||
| # experimental opaque-object support. | ||
| pass |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,177 @@ | ||
| # Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # | ||
| # See LICENSE for license information. | ||
|
|
||
| """TensorProto: a data-free description of a tensor / quantized tensor.""" | ||
|
|
||
| from __future__ import annotations | ||
| import copy as _copy | ||
| from dataclasses import dataclass, field | ||
| from typing import Any, Dict, List, Optional, Tuple | ||
|
|
||
| import torch | ||
|
|
||
|
|
||
| def _contiguous_stride(shape: Tuple[int, ...]) -> Tuple[int, ...]: | ||
| """Row-major (contiguous) stride for ``shape``.""" | ||
| stride: list = [] | ||
| acc = 1 | ||
| for dim in reversed(shape): | ||
| stride.append(acc) | ||
| acc *= dim | ||
| return tuple(reversed(stride)) | ||
|
|
||
|
|
||
| @dataclass | ||
| class TensorProto: | ||
| """A data-free *prototype* of a tensor or quantized tensor. | ||
|
|
||
| Captures ``shape`` / ``dtype`` and, for quantized tensors, the | ||
| (value-opaque) ``quantizer`` -- enough to rebuild a tensor without holding | ||
| storage. The common abstraction over plain ``torch.Tensor``, | ||
| ``QuantizedTensorStorage`` and ``QuantizedTensor``, used for custom-op fake | ||
| impls and for reassembling a quantized tensor from bare buffers. | ||
| """ | ||
|
|
||
| shape: Tuple[int, ...] | ||
| dtype: torch.dtype | ||
| quantizer: Optional[Any] = None | ||
| requires_grad: bool = False | ||
| device: Optional[torch.device] = field(default=None) | ||
|
|
||
| def __post_init__(self) -> None: | ||
| # Own a private copy of the quantizer so usage changes (update_usage) | ||
| # never touch the shared, value-opaque quantizer. The copy inherits the | ||
| # quantizer's current row-/column-wise usage as this proto's layout. | ||
| if self.quantizer is not None: | ||
| q = self.quantizer | ||
| self.quantizer = q.copy() if hasattr(q, "copy") else _copy.copy(q) | ||
|
|
||
| @property | ||
| def is_quantized(self) -> bool: | ||
| """Whether this proto describes a quantized tensor.""" | ||
| return self.quantizer is not None | ||
|
|
||
| def update_usage( | ||
| self, | ||
| *, | ||
| rowwise_usage: Optional[bool] = None, | ||
| columnwise_usage: Optional[bool] = None, | ||
| ) -> None: | ||
| """Mirror ``QuantizedTensor.update_usage`` on the proto's buffer layout. | ||
|
|
||
| Applied to the proto's own quantizer copy, so the shared (value-opaque) | ||
| quantizer is never mutated. No-op for plain (non-quantized) protos. | ||
| """ | ||
| if self.quantizer is None: | ||
| return | ||
| self.quantizer.set_usage(rowwise=rowwise_usage, columnwise=columnwise_usage) | ||
|
|
||
| def inner_names(self) -> Tuple[str, ...]: | ||
| """Names of the flat tensor buffers backing this proto, in order. | ||
|
|
||
| The real op flattens a quantized output via the storage's | ||
| ``__tensor_flatten__`` -- i.e. ``_FLATTEN_TENSOR_BUFFERS`` order, keeping | ||
| only the present buffers. ``_describe_buffers`` may emit the same buffers | ||
| in a different (per-usage) order (e.g. NVFP4 groups each amax right after | ||
| its scale), so reorder to the canonical flatten order here to keep the | ||
| fake layout aligned with the real one slot-for-slot. | ||
| """ | ||
| if self.quantizer is None: | ||
| return ("data",) | ||
| # pylint: disable=protected-access | ||
| described = list(self.quantizer._describe_buffers(tuple(self.shape)).keys()) | ||
| storage_cls = self.quantizer._storage_metadata(self.dtype)["cls"] | ||
| flatten_order = [attr for attr, _ in storage_cls._FLATTEN_TENSOR_BUFFERS] | ||
| extra = [name for name in described if name not in flatten_order] | ||
| if extra: | ||
| raise RuntimeError( | ||
| f"{storage_cls.__name__} describes buffer(s) {extra} absent from its " | ||
| f"_FLATTEN_TENSOR_BUFFERS {flatten_order}; the fake layout cannot be " | ||
| "aligned with the real one slot-for-slot." | ||
| ) | ||
| return tuple(name for name in flatten_order if name in described) | ||
|
|
||
| def create_metadata(self) -> Dict[str, Any]: | ||
| """Data-free ``__tensor_unflatten__`` context describing this tensor.""" | ||
| if self.quantizer is None: | ||
| return { | ||
| "is_tensor": True, | ||
| "is_quantized": False, | ||
| "dtype": self.dtype, | ||
| "requires_grad": self.requires_grad, | ||
| } | ||
| return self.quantizer.create_metadata( | ||
| tuple(self.shape), dtype=self.dtype, requires_grad=self.requires_grad | ||
| ) | ||
|
|
||
| def create_inner_tensors(self) -> List[torch.Tensor]: | ||
| """Materialize the flat inner buffers (in :meth:`inner_names` order). | ||
|
|
||
| Under ``register_fake`` the ``torch.empty`` calls produce ``FakeTensor``s; | ||
| ``requires_grad`` is left default (managed by ``register_autograd``). | ||
| """ | ||
| device = self.device if self.device is not None else torch.device("cuda") | ||
| if self.quantizer is None: | ||
| return [torch.empty(tuple(self.shape), dtype=self.dtype, device=device)] | ||
| inner = self.quantizer.alloc_tensors(tuple(self.shape), device=device) | ||
| return [inner[name] for name in self.inner_names()] | ||
|
|
||
| def create_tensor(self) -> torch.Tensor: | ||
| """Materialize an (uninitialized) tensor matching this proto (traceable). | ||
|
|
||
| Quantized protos reassemble the :meth:`create_inner_tensors` buffers via | ||
| the storage's ``__tensor_unflatten__``. | ||
| """ | ||
| if self.quantizer is None: | ||
| device = self.device if self.device is not None else torch.device("cuda") | ||
| return torch.empty( | ||
| tuple(self.shape), | ||
| dtype=self.dtype, | ||
| device=device, | ||
| requires_grad=self.requires_grad, | ||
| ) | ||
| from ..quantized_tensor import ( # pylint: disable=import-outside-toplevel | ||
| _STORAGE_REGISTRY, | ||
| ) | ||
|
|
||
| shape = tuple(self.shape) | ||
| ctx = self.create_metadata() | ||
| inner = dict(zip(self.inner_names(), self.create_inner_tensors())) | ||
| storage_cls = _STORAGE_REGISTRY[ctx["cls"]] | ||
| return storage_cls.__tensor_unflatten__(inner, ctx, shape, _contiguous_stride(shape)) | ||
|
|
||
|
|
||
| def to_tensor_proto(tensor: Any) -> TensorProto: | ||
| """Build a :class:`TensorProto` describing ``tensor``. | ||
|
|
||
| Works for plain ``torch.Tensor`` and for ``QuantizedTensorStorage`` / | ||
| ``QuantizedTensor``. A *bare* storage exposes its shape via ``.size()`` and | ||
| its (fake) dtype via ``_dtype`` rather than ``.shape`` / ``.dtype``. | ||
| """ | ||
| from ..quantized_tensor import ( # pylint: disable=import-outside-toplevel | ||
| QuantizedTensorStorage, | ||
| ) | ||
|
|
||
| requires_grad = bool(getattr(tensor, "requires_grad", False)) | ||
| if isinstance(tensor, QuantizedTensorStorage): | ||
| shape = getattr(tensor, "shape", None) | ||
| if shape is None: | ||
| shape = tensor.size() | ||
| dtype = getattr(tensor, "dtype", None) | ||
| if dtype is None: | ||
| dtype = getattr(tensor, "_dtype", None) | ||
| return TensorProto( | ||
| shape=tuple(shape), | ||
| dtype=dtype, | ||
| quantizer=getattr(tensor, "_quantizer", None), | ||
| requires_grad=requires_grad, | ||
| device=tensor.device, | ||
| ) | ||
| return TensorProto( | ||
| shape=tuple(tensor.shape), | ||
| dtype=tensor.dtype, | ||
| quantizer=None, | ||
| requires_grad=requires_grad, | ||
| device=tensor.device, | ||
| ) | ||
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.