[PyTorch][torch.compile] Add TensorProto mechanism #3153

Open
pggPL wants to merge 19 commits into NVIDIA:main from pggPL:tensor_proto_mechanism

Conversation

@pggPL pggPL commented Jun 29, 2026

Collaborator

Note

This PR is stacked on top of #3152 ([PyTorch][torch.compile] Make quantizers opaque value objects).
Until #3152 is merged, the diff below also includes its changes. Review/merge #3152 first.

Description

This PR introduces TensorProto — a data-free prototype of a tensor (or quantized tensor) that captures everything needed to reason about and rebuild a tensor without holding any storage: its logical shape/dtype and, for quantized tensors, the value-opaque quantizer defining the layout.

The key property is that TensorProto.create_tensor() materializes a quantized tensor purely in Python (via Quantizer.alloc_tensors + the storage's __tensor_unflatten__), so it traces under torch.compile(fullgraph=True) with no graph break — unlike make_empty, which goes through the opaque C++ tex.create_empty_quantized_tensor. This is the foundation for writing torch.library custom-op fake implementations of quantized ops.

This builds on the value-opaque quantizer work (so a TensorProto is itself safe to treat as a compile-time constant).
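The idea of a data-free prototype can be sketched in plain Python. This is only an illustration under loud assumptions: `ToyTensorProto`, `ToyBlockQuantizer`, `describe_buffers` and the byte-buffer "tensors" are hypothetical stand-ins, not the PR's classes, and the block-scaling layout is invented for the example. The point it shows is that a frozen, hashable description is enough to allocate every buffer later, without calling into an opaque extension.

```python
from dataclasses import dataclass
from math import prod
from typing import Dict, Optional, Tuple

# Assumed element sizes for the high-precision dtypes used in this sketch.
DTYPE_BYTES = {"bfloat16": 2, "float32": 4}


@dataclass(frozen=True)
class ToyBlockQuantizer:
    """Hypothetical value-like quantizer: only plain config, no tensors."""

    block_len: int = 32

    def describe_buffers(self, shape: Tuple[int, ...]) -> Dict[str, Tuple[int, ...]]:
        """Return {buffer_name: buffer_shape} without allocating anything."""
        rows, cols = prod(shape[:-1]), shape[-1]
        blocks_per_row = -(-cols // self.block_len)  # ceiling division
        return {"data": (rows, cols), "scale_inv": (rows, blocks_per_row)}

    def alloc_tensors(self, shape: Tuple[int, ...]) -> Dict[str, bytearray]:
        """Allocate one byte per element for every described buffer."""
        return {name: bytearray(prod(s)) for name, s in self.describe_buffers(shape).items()}


@dataclass(frozen=True)
class ToyTensorProto:
    """Shape, dtype and optional quantizer; holds no storage itself."""

    shape: Tuple[int, ...]
    dtype: str
    quantizer: Optional[ToyBlockQuantizer] = None

    @property
    def is_quantized(self) -> bool:
        return self.quantizer is not None

    def create_tensor(self) -> Dict[str, bytearray]:
        if not self.is_quantized:
            return {"data": bytearray(prod(self.shape) * DTYPE_BYTES[self.dtype])}
        return self.quantizer.alloc_tensors(self.shape)


proto = ToyTensorProto((4, 64), "bfloat16", ToyBlockQuantizer(block_len=32))
buffers = proto.create_tensor()
print({name: len(buf) for name, buf in buffers.items()})
```

Because both dataclasses are frozen and hold only hashable fields, two protos with the same description compare and hash equal, which is the property that lets such an object be treated as a constant.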

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • dynamo.py: Add TensorProto dataclass (shape, dtype, quantizer, requires_grad, device) with is_quantized, inner_names(), create_metadata() and create_tensor(), plus a to_tensor_proto() helper that builds a proto from a plain torch.Tensor or a QuantizedTensorStorage/QuantizedTensor.
  • quantized_tensor.py:
    • Add the PyTorch wrapper-subclass flatten protocol (__tensor_flatten__ / __tensor_unflatten__) to QuantizedTensorStorage, driven by a per-class _FLATTEN_TENSOR_BUFFERS declaration of (attribute_name, constructor_kwarg) pairs.
    • Add a _STORAGE_REGISTRY (populated via __init_subclass__) so __tensor_unflatten__ can resolve a concrete storage/wrapper class from its qualname inside an FX graph.
    • Add pure-Python, traceable allocation hooks to Quantizer: alloc_tensors, create_metadata, and the opt-in overrides _describe_buffers, _storage_scalars, _resolve_storage_cls.
  • Quantizers: Implement the allocation hooks for Float8CurrentScalingQuantizer, MXFP8Quantizer and Float8BlockQuantizer.
  • Storage classes: Declare _FLATTEN_TENSOR_BUFFERS for Float8TensorStorage, MXFP8TensorStorage and Float8BlockwiseQTensorStorage.
  • ops/basic/basic_linear.py: Add allocation-free _functional_forward_fake / _functional_backward_fake that operate on TensorProto and return output/gradient protos, as a basis for custom-op fake impls (single-device only; TP/SP shape effects not yet modeled).
  • Tests: Add tests/pytorch/test_tensor_proto.py (CPU smoke tests for _describe_buffers/alloc_tensors/create_metadata, flatten round-trip, and to_tensor_proto) and torch.compile fullgraph tests in test_torch_compile.py.
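The registry and flatten/unflatten pieces listed above can be pictured with a small stand-alone sketch. It is not the PR's code: `StorageBase`, `ToyFp8Storage`, `STORAGE_REGISTRY` and `FLATTEN_BUFFERS` are hypothetical names, plain Python lists stand in for tensors, and the method signatures are simplified relative to PyTorch's `__tensor_flatten__` / `__tensor_unflatten__` protocol.

```python
from typing import Any, ClassVar, Dict, List, Tuple

# Hypothetical stand-in for a qualname -> class registry.
STORAGE_REGISTRY: Dict[str, type] = {}


class StorageBase:
    # (attribute_name, constructor_kwarg) pairs, declared by each subclass.
    FLATTEN_BUFFERS: ClassVar[Tuple[Tuple[str, str], ...]] = ()

    def __init_subclass__(cls, **kwargs: Any) -> None:
        # Every subclass registers itself at class-creation time.
        super().__init_subclass__(**kwargs)
        STORAGE_REGISTRY[cls.__qualname__] = cls

    def nontensor_kwargs(self) -> Dict[str, Any]:
        return {}

    def flatten(self) -> Tuple[List[str], Dict[str, Any]]:
        """Return names of the buffers that are set, plus a plain context dict."""
        names = [attr for attr, _ in self.FLATTEN_BUFFERS if getattr(self, attr) is not None]
        ctx = {"cls": type(self).__qualname__, "nontensor_kwargs": self.nontensor_kwargs()}
        return names, ctx

    @staticmethod
    def unflatten(inner: Dict[str, Any], ctx: Dict[str, Any]) -> "StorageBase":
        """Resolve the concrete class from the context and rebuild the object."""
        cls = STORAGE_REGISTRY[ctx["cls"]]
        kwargs = {kwarg: inner.get(attr) for attr, kwarg in cls.FLATTEN_BUFFERS}
        return cls(**kwargs, **ctx["nontensor_kwargs"])


class ToyFp8Storage(StorageBase):
    FLATTEN_BUFFERS = (("_data", "data"), ("_scale_inv", "scale_inv"))

    def __init__(self, data=None, scale_inv=None, fp8_dtype="e4m3"):
        self._data = data
        self._scale_inv = scale_inv
        self._fp8_dtype = fp8_dtype

    def nontensor_kwargs(self) -> Dict[str, Any]:
        return {"fp8_dtype": self._fp8_dtype}


original = ToyFp8Storage(data=[1, 2, 3], scale_inv=[0.5], fp8_dtype="e5m2")
names, ctx = original.flatten()
inner = {name: getattr(original, name) for name in names}
rebuilt = StorageBase.unflatten(inner, ctx)
print(type(rebuilt).__name__, rebuilt._data, rebuilt._fp8_dtype)
```

Keying the registry by qualname keeps the context dict free of class objects, at the cost that two classes with the same qualname in different modules would overwrite each other in this toy version.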

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

pggPL and others added 8 commits June 29, 2026 11:25
…ompile

Give tensorless quantizers (MXFP8, FP8 blockwise, FP8 current-scaling,
NVFP4) value-object semantics so torch.compile can treat them as baked-in
constants:

- Add opt-in value identity to the base Quantizer (_value_fields /
  _value_key / __eq__ / __hash__). Quantizers holding live tensors
  (delayed-scaling Float8Quantizer) and custom quantizers keep identity
  semantics.
- New transformer_engine/pytorch/dynamo.py houses the torch.compile glue:
  __fx_repr__, value-key reconstruction and register_value_opaque_quantizer
  (gracefully a no-op on PyTorch builds without the opaque-object API).
- Register the four tensorless quantizers as value opaque types.

Also fix CustomRecipe state caching in TransformerEngineBaseModule:
set_meta_tensor now rebuilds quantizers when the CustomRecipe instance
changes (e.g. nested te.autocast regions) instead of reusing the first
recipe's state, since every CustomRecipe shares the CustomRecipeState type
but carries its own qfactory.

Move the quantizer value-object tests into tests/pytorch/test_torch_compile.py
and add that file to the L0 pytorch unittest QA suite.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
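The opt-in value identity described in this commit message can be sketched as follows. The method names `_value_fields` and `_value_key` come from the message; everything else (the toy subclasses, the exact key layout, returning `None` to opt out) is an assumption made for illustration and may differ from the real implementation.

```python
class Quantizer:
    """Base class: identity semantics unless a subclass opts in."""

    def _value_fields(self):
        # None means "no value identity"; subclasses may return field names.
        return None

    def _value_key(self):
        fields = self._value_fields()
        if fields is None:
            return None
        return (type(self).__qualname__, tuple((f, getattr(self, f)) for f in fields))

    def __eq__(self, other):
        key = self._value_key()
        if key is None:
            return self is other
        return isinstance(other, Quantizer) and key == other._value_key()

    def __hash__(self):
        key = self._value_key()
        return id(self) if key is None else hash(key)


class ToyBlockQuantizer(Quantizer):
    """Tensorless: fully described by hashable configuration."""

    def __init__(self, block_len, rowwise=True):
        self.block_len = block_len
        self.rowwise = rowwise

    def _value_fields(self):
        return ("block_len", "rowwise")


class ToyDelayedQuantizer(Quantizer):
    """Holds live state (a list stands in for a tensor), so it stays identity-based."""

    def __init__(self, scale):
        self.scale = scale


a, b, c = ToyBlockQuantizer(32), ToyBlockQuantizer(32), ToyBlockQuantizer(128)
d1, d2 = ToyDelayedQuantizer([1.0]), ToyDelayedQuantizer([1.0])
print(a == b, a == c, d1 == d2)
```

Equal keys imply equal hashes, so two separately constructed but identically configured quantizers collapse to one entry in a set or dict, while stateful quantizers never compare equal to anything but themselves.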
…globals

Follow-up to the value-opaque quantizer support:

- Remove the module-level _QUANTIZER_VALUE_REGISTRY (qualname -> class) and
  _quantizer_from_value_key. __fx_repr__ now captures the quantizer class
  directly in the FX globals and reconstructs via _rebuild_quantizer(cls, items),
  matching how PyTorch's own value opaque types (e.g. DTensor placements)
  reconstruct themselves. This removes global mutable state and the qualname
  collision risk.
- Consolidate the quantizer value-object tests in test_torch_compile.py down to
  two functions and exercise reconstruction through the public __fx_repr__ path
  instead of internal helpers.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Replace the single dynamo.py module with a dynamo/ package so the
torch.compile glue can grow with a clear responsibility split across the
stacked branches. This branch owns the value-opaque quantizer layer.

  * dynamo/quantizer_opaque.py -- register_value_opaque_quantizer and helpers
  * dynamo/__init__.py -- re-exports the public API so callers keep importing
    from transformer_engine.pytorch.dynamo unchanged

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
A value-opaque quantizer must not carry live distributed state. Scan the
quantizer attributes in __fx_repr__ and raise TypeError if any holds a
torch.distributed.ProcessGroup (e.g. a non-None deprecated amax_reduction_group),
so it cannot be silently baked into a torch.compile FX graph. Clarify the related
comments accordingly.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
NVFP4Quantizer is registered as a value-opaque quantizer but was missing
from the value-semantics / __fx_repr__ round-trip test. Add it to
_VALUE_QUANTIZERS (skipped without CUDA, which it needs to construct).

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
…__/__hash__

The amax reduction group is excluded from the value key, so a value quantizer
that stored one would compare/hash equal to a groupless one and let torch.compile
reuse a graph that skips the reduction. __eq__/__hash__ now raise (mirroring
__fx_repr__, which already rejects any process-group-bearing quantizer). The
group should be passed per quantize call, not stored on the quantizer.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Add is_value_opaque_quantizer() + the _te_compile_value_opaque flag stamped at
registration, so dynamo-traced code can detect registered quantizers (and fall
back to eager for unregistered ones).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
…fp4 value key

- Narrow register_opaque_type except to (RuntimeError, TypeError): the API is
  already imported above, so ImportError/AttributeError there only mask real errors.
- Add test_quantizer_value_object_fullgraph exercising torch.compile(fullgraph=True)
  end-to-end to verify opaque-type registration took effect.
- Restore missing NVFP4Quantizer._with_random_sign_mask assignment required by
  _value_fields()/_value_key().

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
@pggPL pggPL requested a review from ksivaman as a code owner June 29, 2026 09:39
greptile-apps Bot commented Jun 29, 2026

Contributor

Greptile Summary

This PR introduces TensorProto — a data-free prototype of a tensor that captures shape, dtype, and quantizer, enabling pure-Python quantized-tensor allocation via __tensor_flatten__/__tensor_unflatten__. This is foundational infrastructure for writing torch.library custom-op fake implementations that trace under torch.compile(fullgraph=True) without graph breaks.

  • TensorProto (dynamo/tensor_proto.py): new dataclass with create_tensor() / inner_names() / create_metadata() that materialises quantized tensors purely in Python through Quantizer.alloc_tensors + storage __tensor_unflatten__.
  • QuantizedTensorStorage (quantized_tensor.py): adds __tensor_flatten__/__tensor_unflatten__ protocol, _FLATTEN_TENSOR_BUFFERS class-level declarations, and _STORAGE_REGISTRY auto-populated via __init_subclass__; Quantizer gains _describe_buffers, alloc_tensors, and create_metadata pure-Python allocation hooks.
  • Quantizers (float8_tensor.py, mxfp8_tensor.py, float8_blockwise_tensor.py, nvfp4_tensor.py): implement the allocation hooks and register as value-opaque types; NVFP4Quantizer adds _rebuild_derived_state to restore rht_matrix after FX reconstruction; NVFP4Quantizer also adds _with_random_sign_mask field.
  • BasicLinear (module/linear.py): preparatory _linear_forward_impl_fake / _linear_backward_impl_fake shape-propagation stubs for future custom-op fake registrations.

Confidence Score: 4/5

Safe to merge after fixing the wrong flag in _rebuild_derived_state; the copy() omission of with_random_sign_mask is a related follow-up. The core TensorProto and flatten/unflatten machinery is correct and well-tested.

The NVFP4Quantizer._rebuild_derived_state method calls get_rht_matrix(self._with_random_sign_mask, ...), but the original __init__ calls get_rht_matrix(with_rht, ...). These are distinct parameters: with_rht gates whether the transform is used at all; _with_random_sign_mask only controls the sign pattern within it. When they differ — e.g. with_rht=True, with_random_sign_mask=False — the rebuilt quantizer silently receives a different matrix and produces wrong quantization results. The default path (both True) passes unchanged, which is why the existing tests don't catch it.

transformer_engine/pytorch/tensor/nvfp4_tensor.py — both _rebuild_derived_state (wrong flag) and copy() (missing with_random_sign_mask forwarding) need attention.

Important Files Changed

| Filename | Overview |
| --- | --- |
| transformer_engine/pytorch/tensor/nvfp4_tensor.py | Adds _value_fields, _storage_metadata, _describe_buffers, _rebuild_derived_state, and registration. _rebuild_derived_state uses self._with_random_sign_mask as the flag to get_rht_matrix, but __init__ uses with_rht; when they differ these produce different matrices. copy() also omits with_random_sign_mask, leaving the copy with the wrong value for that field. |
| transformer_engine/pytorch/dynamo/tensor_proto.py | New file implementing TensorProto dataclass; core logic is sound. Minor: storage_cls lookup in create_tensor is redundant since __tensor_unflatten__ re-resolves from the context dict. |
| transformer_engine/pytorch/quantized_tensor.py | Adds flatten/unflatten protocol and _STORAGE_REGISTRY to QuantizedTensorStorage; the __init_subclass__ registration and __tensor_unflatten__ static dispatch are correct. Registry uses bare qualname without module path, which could silently collide for two classes with identical names across modules. |
| transformer_engine/pytorch/dynamo/quantizer_opaque.py | Adds _rebuild_quantizer with _rebuild_derived_state hook mechanism. The process-group guard and LRU-cache-based device-local reconstruction are well thought out. |
| transformer_engine/pytorch/module/linear.py | Adds _linear_forward_impl_fake and _linear_backward_impl_fake shape-propagation stubs (not yet wired to a custom op). The requires_grad fix (args.*_requires_grad instead of live tensor flags) is correct and important. New alias tracking for the cached FP8 weight is correctly handled in both real and fake paths. |
| tests/pytorch/test_torch_compile.py | Comprehensive tests for value-opaque quantizers, TensorProto primitives, flatten/unflatten roundtrips, and fullgraph compile. Tests cover the default NVFP4 path (both with_rht=True and with_random_sign_mask=True) but do not exercise the with_rht=True/with_random_sign_mask=False case that exposes the _rebuild_derived_state bug. |

Sequence Diagram

%%{init: {'theme': 'neutral'}}%%
sequenceDiagram
    participant C as Caller / register_fake
    participant TP as TensorProto
    participant Q as Quantizer
    participant S as QuantizedTensorStorage
    participant R as _STORAGE_REGISTRY

    Note over C,R: TensorProto.create_tensor() flow
    C->>TP: TensorProto(shape, dtype, quantizer)
    TP->>Q: copy() [__post_init__]

    C->>TP: create_tensor()
    TP->>TP: create_metadata()
    TP->>Q: create_metadata(shape, dtype)
    Q->>Q: _storage_metadata(dtype)
    Q-->>TP: "ctx {cls: qualname, nontensor_kwargs, ...}"

    TP->>TP: create_inner_tensors()
    TP->>Q: alloc_tensors(shape, device)
    Q->>Q: _describe_buffers(shape)
    Q-->>TP: "{attr: torch.empty(buf_shape, buf_dtype)}"

    TP->>TP: inner_names() [reorder to _FLATTEN_TENSOR_BUFFERS order]
    TP->>R: _STORAGE_REGISTRY[ctx["cls"]]
    R-->>TP: storage_cls

    TP->>S: __tensor_unflatten__(inner, ctx, shape, stride)
    S->>R: _STORAGE_REGISTRY[ctx["cls"]]
    R-->>S: cls
    S-->>C: QuantizedTensorStorage instance

    Note over C,R: FX graph rebuild via _rebuild_quantizer
    C->>C: eval(quantizer.__fx_repr__())
    C->>C: _rebuild_quantizer(cls, items)
    C->>C: object.__setattr__(obj, field, value) for each field
    C->>Q: _rebuild_derived_state() [NVFP4 only]
    Q->>Q: get_rht_matrix(_with_random_sign_mask, device)

Comments Outside Diff (1)

  1. transformer_engine/pytorch/tensor/nvfp4_tensor.py, line 231-245 (link)

    P1 copy() does not forward with_random_sign_mask to the constructor, so the copy always gets _with_random_sign_mask = True (the default) regardless of the original's value. If _rebuild_derived_state is fixed to use self.with_rht, this inconsistency becomes harmless for reconstruction; but a copy used directly (without going through FX reconstruction) would still carry the wrong _with_random_sign_mask and rht_matrix_random_sign_mask_t, which can affect future behavior if those fields are ever inspected independently.

Reviews (7): Last reviewed commit: "[PyTorch] Workaround torch.compile stati..."

Comment thread transformer_engine/pytorch/tensor/mxfp8_tensor.py Outdated
Comment thread transformer_engine/pytorch/dynamo/tensor_proto.py
…trip

_rebuild_quantizer only restores value-key fields, so a reconstructed
NVFP4Quantizer was missing the derived rht_matrix tensor (not hashable, so not
in the value key) and failed at copy()/quantize time. Add a _rebuild_derived_state
hook (called by _rebuild_quantizer) that NVFP4Quantizer uses to rebuild rht_matrix
from _with_random_sign_mask (lru_cache -> cheap).

Extend test_quantizer_value_object to also quantize with the original and the
rebuilt quantizer and require bit-exact results (gated on HW support), so a
field the kernel needs but the value key omits can no longer slip through.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
@pggPL pggPL force-pushed the tensor_proto_mechanism branch 3 times, most recently from 5131ebc to 77831be Compare June 29, 2026 10:24
Move the ProcessGroup guard out of the (overridable) __fx_repr__ into
Quantizer._value_key -- the single point every value-materialization path
(__eq__/__hash__/__fx_repr__) goes through -- so a custom __fx_repr__ can no
longer bypass it. Generalizes the old amax-only check to any field holding a
ProcessGroup. Add a test that a value quantizer carrying a live group raises.

Addresses review on NVIDIA#3152.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
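Why placing the guard in `_value_key` closes the bypass can be seen in a stand-alone sketch. `FakeProcessGroup` is a hypothetical stand-in for `torch.distributed.ProcessGroup`, and `GuardedQuantizer` with its `fx_repr` method is a toy class invented here; only the idea of funnelling `__eq__`, `__hash__` and the FX representation through one key function is taken from the commit message.

```python
class FakeProcessGroup:
    """Hypothetical stand-in for torch.distributed.ProcessGroup."""


class GuardedQuantizer:
    _VALUE_FIELDS = ("block_len",)

    def __init__(self, block_len, amax_reduction_group=None):
        self.block_len = block_len
        self.amax_reduction_group = amax_reduction_group

    def _value_key(self):
        # Single choke point: reject any attribute that holds a process group.
        for name, value in vars(self).items():
            if isinstance(value, FakeProcessGroup):
                raise TypeError(f"attribute {name!r} holds a process group")
        return tuple((f, getattr(self, f)) for f in self._VALUE_FIELDS)

    def __eq__(self, other):
        return isinstance(other, GuardedQuantizer) and self._value_key() == other._value_key()

    def __hash__(self):
        return hash(self._value_key())

    def fx_repr(self):
        # Even an overridden representation builds on the guarded key.
        fields = dict(self._value_key())
        return f"GuardedQuantizer(block_len={fields['block_len']!r})"


plain = GuardedQuantizer(32)
with_group = GuardedQuantizer(32, amax_reduction_group=FakeProcessGroup())

print(plain.fx_repr())
for operation in (lambda: hash(with_group), lambda: with_group == plain, with_group.fx_repr):
    try:
        operation()
    except TypeError as exc:
        print("rejected:", exc)
```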
@pggPL pggPL force-pushed the tensor_proto_mechanism branch from 77831be to 29e5245 Compare June 29, 2026 12:47
…assthrough

Replace the trivial pass-through fullgraph test with one that drives each
production quantizer through a minimal custom op (quantize + dequantize) under
torch.compile(fullgraph=True) and compares to eager -- so the opaque-type
registration is actually exercised inside the graph (a graph break would make
fullgraph=True raise). Op registration sits right before the test. Also drop
stale comments referencing the old __fx_repr__-side process-group guard.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
@pggPL pggPL force-pushed the tensor_proto_mechanism branch from 29e5245 to 99c1377 Compare June 29, 2026 13:10
…paque flag

- rht_matrix_random_sign_mask_t is a device-independent int derived from
  _with_random_sign_mask (the device only places a throwaway tensor); fix the
  misleading comment.
- Explain why registration uses a class attribute, not a registry set:
  is_value_opaque_quantizer is traced inside the compile graph and dynamo can
  bake a getattr constant but cannot do 'type(q) in set' on the opaque class.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
@pggPL pggPL force-pushed the tensor_proto_mechanism branch from 99c1377 to afa86ff Compare June 29, 2026 13:29
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
@pggPL pggPL force-pushed the tensor_proto_mechanism branch from afa86ff to 9e78a6c Compare June 29, 2026 13:30
pggPL and others added 5 commits June 29, 2026 15:38
is_opaque_value_type(cls) sat between the import guard and the
register_opaque_type guard, so on a partial/experimental opaque-object build it
could raise RuntimeError/TypeError and crash TE import. Move it inside the same
except so the 'registration never crashes import' promise holds for both calls.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Squashed PR #8 (tensor_proto_mechanism) onto the rebased base. Adds TensorProto
(pure-Python, torch.compile-traceable quantized-tensor allocation via
Quantizer.alloc_tensors + storage __tensor_flatten__/__tensor_unflatten__),
Linear fake fwd/bwd impls for the custom-op path, and tests.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
The cached FP8 weight is the same tensor returned as new_weight_workspace (cache miss) or passed in as weight_workspace (cache hit). A custom op may not return a tensor that aliases an input or another return, so mark those slots and reconstruct wt_save in _linear_setup_ctx instead of saving it twice. Mirrored in the fake impl so the saved-slot layout matches.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
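One way to picture the slot marking described in this commit message is the sketch below. It is a loose illustration, not the PR's `_linear_setup_ctx` logic: `pack_saved`, `unpack_saved` and the `(kind, index)` markers are hypothetical, and plain Python objects stand in for tensors. A saved object that is the very same object as an input or an output is recorded as a reference to that slot rather than stored a second time.

```python
def pack_saved(saved, inputs, outputs):
    """Replace saved objects that alias an input/output with (kind, index) markers."""
    packed = []
    for obj in saved:
        marker = None
        for kind, group in (("input", inputs), ("output", outputs)):
            for idx, candidate in enumerate(group):
                if candidate is obj:  # identity, not equality: we care about aliasing
                    marker = (kind, idx)
                    break
            if marker is not None:
                break
        packed.append(marker if marker is not None else ("own", obj))
    return packed


def unpack_saved(packed, inputs, outputs):
    """Rebuild the saved list by following markers back to the original slots."""
    groups = {"input": inputs, "output": outputs}
    return [payload if kind == "own" else groups[kind][payload] for kind, payload in packed]


# Cache hit: the cached weight arrives as an input.
x, cached_weight, y, activation = object(), object(), object(), object()
hit = pack_saved([activation, cached_weight], inputs=[x, cached_weight], outputs=[y])

# Cache miss: the freshly built weight is also returned as an output.
new_weight = object()
miss = pack_saved([activation, new_weight], inputs=[x, None], outputs=[y, new_weight])

print(hit[1], miss[1])
```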
NVFP4Quantizer._describe_buffers grouped each amax right after its scale (per-usage), diverging from NVFP4TensorStorage._FLATTEN_TENSOR_BUFFERS (amax buffers last). The order is functionally irrelevant (buffers are consumed by name in alloc_tensors and reordered in TensorProto.inner_names), but aligning it makes describe/flatten agree and fixes test_to_tensor_proto_quantized[nvfp4].

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
…upport

- TensorProto.inner_names now raises if the quantizer describes buffer(s) absent
  from the storage's _FLATTEN_TENSOR_BUFFERS, instead of silently appending them.
- Gate the nvfp4 proto-quantizer param on nvfp4_available so it skips on hardware
  without NVFP4 support rather than failing.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
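The stricter ordering check can be sketched in a few lines. `order_inner_names` is a hypothetical helper written for this illustration, and the buffer names in the usage are made up; the real `TensorProto.inner_names` may differ in signature and error type.

```python
def order_inner_names(described, flatten_buffers):
    """Order described buffer names by the storage's declared flatten order.

    described:       iterable of buffer names the quantizer says it allocates
    flatten_buffers: sequence of (attribute_name, constructor_kwarg) pairs
    """
    declared = [attr for attr, _ in flatten_buffers]
    unknown = [name for name in described if name not in declared]
    if unknown:
        # Fail loudly instead of appending names the storage cannot rebuild from.
        raise ValueError(f"buffers not declared by the storage: {unknown}")
    present = set(described)
    return [attr for attr in declared if attr in present]


declared_pairs = (
    ("_rowwise_data", "rowwise_data"),
    ("_rowwise_scale_inv", "rowwise_scale_inv"),
    ("_amax_rowwise", "amax_rowwise"),
)

# Described in a different order than declared, and only a subset is present.
print(order_inner_names(["_amax_rowwise", "_rowwise_data"], declared_pairs))
```

The result follows the declaration order regardless of the order in which the quantizer listed its buffers, which is what lets describe and flatten disagree on ordering without affecting correctness.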
@pggPL pggPL force-pushed the tensor_proto_mechanism branch from 9e78a6c to 50c11cd Compare June 29, 2026 13:46
…escribe_buffers

Access NVFP4Quantizer @staticmethods (convert_shape_for_fp4, get_columnwise_shape)
via the class instead of the instance. Under torch.compile, instance access of a
@staticmethod on a value-opaque object crashes Dynamo guard generation with
"'function' object has no attribute '__func__'" (pytorch/pytorch#182741).
Temporary workaround until the PyTorch-side fix lands.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
``_rebuild_quantizer`` calls this hook to rebuild it; the ``lru_cache`` on
:func:`get_rht_matrix` makes an already-seen (flag, device) a cheap hit.
"""
self.rht_matrix = get_rht_matrix(self._with_random_sign_mask, torch.cuda.current_device())

P1 _rebuild_derived_state passes self._with_random_sign_mask to get_rht_matrix, but __init__ passes with_rht. These are distinct parameters: with_rht controls whether the Hadamard transform is applied at all; _with_random_sign_mask controls whether random signs are used within it. When they differ (e.g. with_rht=True, with_random_sign_mask=False) the rebuilt quantizer gets a different matrix than the original. Since get_rht_matrix is LRU-cached by (flag, device), get_rht_matrix(True, d) and get_rht_matrix(False, d) return different objects, so the kernel would receive the wrong transform matrix and silently produce incorrect quantization results.

Suggested change
- self.rht_matrix = get_rht_matrix(self._with_random_sign_mask, torch.cuda.current_device())
+ self.rht_matrix = get_rht_matrix(self.with_rht, torch.cuda.current_device())
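The failure mode raised in this comment can be reproduced with a toy model. The sketch assumes, as the comment states, that the constructor and the rebuild hook pass different flags to a cached factory; `toy_get_matrix` and `ToyQuantizer` are hypothetical, and the 2x2 matrices are placeholders rather than the real transform. It makes no claim about which flag the actual `__init__` uses.

```python
from functools import lru_cache


@lru_cache(maxsize=None)
def toy_get_matrix(flag: bool, device: int):
    """Cached per (flag, device); the flag changes the returned matrix."""
    base = ((1, 1), (1, -1))
    if not flag:
        return base
    # Flip the sign of the second row when the flag is set.
    return (base[0], tuple(-v for v in base[1]))


class ToyQuantizer:
    def __init__(self, with_rht: bool, with_random_sign_mask: bool):
        self.with_rht = with_rht
        self._with_random_sign_mask = with_random_sign_mask
        # Flag used at construction time, per the review comment's reading.
        self.matrix = toy_get_matrix(with_rht, 0)

    def rebuild_with(self, field_name: str) -> None:
        """Recompute the derived matrix from the named flag field."""
        self.matrix = toy_get_matrix(getattr(self, field_name), 0)


q = ToyQuantizer(with_rht=True, with_random_sign_mask=False)
original = q.matrix

q.rebuild_with("_with_random_sign_mask")
mismatch = q.matrix != original

q.rebuild_with("with_rht")
match = q.matrix == original

print(mismatch, match)
```

With both flags equal, either rebuild path returns the cached original, which is consistent with the comment's observation that a default-configuration test would not catch the discrepancy.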
