From 0bc353f267fb528ba0b3043d172fee28daa62d7f Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Thu, 18 Jun 2026 09:38:14 +0000 Subject: [PATCH] wip Signed-off-by: Haoyang Li --- atom/config.py | 1 + atom/model_ops/linear.py | 19 ++++++++++++++-- atom/model_ops/moe.py | 37 +++++++++++++++++++++----------- atom/quant_spec.py | 8 +++++++ atom/quantization/quark/utils.py | 28 ++++++++++++++++++++++++ 5 files changed, 79 insertions(+), 14 deletions(-) diff --git a/atom/config.py b/atom/config.py index 657095ab55..7095ed1141 100644 --- a/atom/config.py +++ b/atom/config.py @@ -309,6 +309,7 @@ def __init__( "", "fp8", "mxfp4", + "mxfp8", ]: self.online_quant = True online_parser = get_quant_parser("online_quant") diff --git a/atom/model_ops/linear.py b/atom/model_ops/linear.py index c0103f1735..550ffa855f 100644 --- a/atom/model_ops/linear.py +++ b/atom/model_ops/linear.py @@ -32,7 +32,7 @@ ) from atom.utils import envs from atom.utils.decorators import mark_trace -from atom.quantization.quark.utils import weight_dequant_fp8 +from atom.quantization.quark.utils import weight_dequant_fp8, weight_dequant_mxfp8 from torch import nn logger = logging.getLogger("atom") @@ -414,6 +414,14 @@ def _gather_full_weight(self, weight): """Gather sharded weight from all TP ranks to reconstruct the full unpartitioned weight.""" if self.tp_size <= 1 or self.tp_dim is None: return weight + # NCCL cannot all_gather E8M0 scales (MXFP8 source); gather the raw + # bytes as uint8 and reinterpret afterwards. The gather only moves + # bytes, so this is bit-exact. + if weight.dtype == dtypes.fp8_e8m0: + gathered = get_tp_group().all_gather( + weight.view(torch.uint8), dim=self.tp_dim + ) + return gathered.view(dtypes.fp8_e8m0) return get_tp_group().all_gather(weight, dim=self.tp_dim) def _shard_quantized_weight(self, q_weight, weight_scale): @@ -466,7 +474,11 @@ def online_quantize_weight(self): f"Unsupported online quant: " f"dtype={online_quant_dtype}, type={online_quant_type}" ) - assert self.quant_type in [QuantType.No, QuantType.per_1x128], ( + assert self.quant_type in [ + QuantType.No, + QuantType.per_1x128, + QuantType.per_1x32, + ], ( f"Unsupported source quant_type for online quantization: " f"{self.quant_type} (layer={self.prefix})" ) @@ -504,6 +516,9 @@ def online_quantize_weight(self): if self.quant_type == QuantType.per_1x128: # dequant per block fp8 weight = weight_dequant_fp8(weight, weight_scale) + elif self.quant_type == QuantType.per_1x32: + # dequant MXFP8 (FP8 elements + 1x32 E8M0 shared scale) + weight = weight_dequant_mxfp8(weight, weight_scale) q_weight, weight_scale = online_quant_func( weight, quant_dtype=online_quant_dtype ) diff --git a/atom/model_ops/moe.py b/atom/model_ops/moe.py index a710e8fad4..122da02b83 100644 --- a/atom/model_ops/moe.py +++ b/atom/model_ops/moe.py @@ -52,7 +52,7 @@ ) from atom.plugin.vllm.moe import FusedMoEDecoratorForPluginMode from atom.quant_spec import LayerQuantConfig, should_skip_online_quant -from atom.quantization.quark.utils import weight_dequant_fp8 +from atom.quantization.quark.utils import weight_dequant_fp8, weight_dequant_mxfp8 from atom.utils import envs from atom.utils.custom_register import direct_register_custom_op from atom.utils.decorators import mark_trace @@ -2411,11 +2411,24 @@ def _online_quant(self): return quant_func = get_hip_quant(online_quant_type) - assert source_quant_type in (QuantType.No, QuantType.per_1x128), ( + assert source_quant_type in ( + QuantType.No, + QuantType.per_1x128, + QuantType.per_1x32, + ), ( f"Unsupported source quant_type for MoE online quantization: " f"{source_quant_type} (layer={self.layer_name})" ) - need_dequant = source_quant_type == QuantType.per_1x128 + need_dequant = source_quant_type in ( + QuantType.per_1x128, + QuantType.per_1x32, + ) + + def _dequant_func(w: torch.Tensor, sc: torch.Tensor) -> torch.Tensor: + # per_1x128 -> deepseek-style square-block FP8; per_1x32 -> MXFP8. + if source_quant_type == QuantType.per_1x32: + return weight_dequant_mxfp8(w.contiguous(), sc.contiguous()) + return weight_dequant_fp8(w.contiguous(), sc.contiguous()) # Determine whether each weight needs all_gather to match offline quantization. # w13 (column parallel): (E, (2*intermediate/tp, hidden)) — TP dim 0 @@ -2472,13 +2485,13 @@ def check_need_allgather(): if need_dequant: w13_scale = old_w13_scale[expert_id] s1_size = w13_scale.shape[0] // 2 - w1_bf16 = weight_dequant_fp8( - w13_local[:w1_size].contiguous(), - w13_scale[:s1_size].contiguous(), + w1_bf16 = _dequant_func( + w13_local[:w1_size], + w13_scale[:s1_size], ) - w3_bf16 = weight_dequant_fp8( - w13_local[w1_size:].contiguous(), - w13_scale[s1_size:].contiguous(), + w3_bf16 = _dequant_func( + w13_local[w1_size:], + w13_scale[s1_size:], ) else: w1_bf16 = w13_local[:w1_size] @@ -2515,9 +2528,9 @@ def check_need_allgather(): # w2 ptpc_fp8 [e, m, n]->[e, m, n//tp]->[e, m, 1] w2_local = old_w2_data[expert_id] if need_dequant: - w2_local = weight_dequant_fp8( - w2_local.contiguous(), - old_w2_scale[expert_id].contiguous(), + w2_local = _dequant_func( + w2_local, + old_w2_scale[expert_id], ) if need_gather_w2: w2_full = tp_group.all_gather(w2_local, dim=1) diff --git a/atom/quant_spec.py b/atom/quant_spec.py index 2396c75714..85dc23a3b9 100644 --- a/atom/quant_spec.py +++ b/atom/quant_spec.py @@ -320,6 +320,14 @@ def parse(self, hf_quant_config: dict) -> ParsedQuantConfig: QuantType.per_1x128, ): quant_type = QuantType.per_1x32 + # Mxfp8 ``[1, K]`` block to per_1x32. + weight_block_size = hf_quant_config.get("weight_block_size") + if ( + isinstance(weight_block_size, (list, tuple)) + and len(weight_block_size) == 2 + and weight_block_size[0] == 1 + ): + quant_type = QuantType.per_1x32 is_dynamic = hf_quant_config.get("is_dynamic", True) # Each quantizer uses a different key for excluded layers: # Quark -> "exclude", compressed-tensors -> "ignore", diff --git a/atom/quantization/quark/utils.py b/atom/quantization/quark/utils.py index e711c63152..b9f17beebd 100644 --- a/atom/quantization/quark/utils.py +++ b/atom/quantization/quark/utils.py @@ -97,3 +97,31 @@ def grid(meta: dict[str, int]) -> tuple[int, int]: _weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size) return y + + +# Optional E8M0 dtype: only available on newer torch builds. +_E8M0_DTYPE = getattr(torch, "float8_e8m0fnu", None) + + +def weight_dequant_mxfp8( + x: torch.Tensor, s: torch.Tensor, block_size: int = 32 +) -> torch.Tensor: + """Dequantize an MXFP8 weight to the default float dtype. + """ + assert x.dim() == 2 and s.dim() == 2, "Input tensors must have 2 dimensions" + M, K = x.shape + assert K % block_size == 0, f"K={K} not divisible by block_size={block_size}" + n_blocks = K // block_size + assert s.shape == (M, n_blocks), f"scale shape {tuple(s.shape)} != {(M, n_blocks)}" + + if _E8M0_DTYPE is not None and s.dtype == _E8M0_DTYPE: + # E8M0 dtype decodes straight to the 2**(e-127) multiplier. + scale = s.to(torch.float32) + else: + # Raw E8M0 integer codes stored as uint8 / float. + scale = torch.exp2(s.to(torch.float32) - 127.0) + + out_dtype = torch.get_default_dtype() + y = x.to(torch.float32).reshape(M, n_blocks, block_size) + y = y * scale.unsqueeze(-1) + return y.reshape(M, K).to(out_dtype)