Skip to content
Open

wip #1284

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions atom/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,7 @@ def __init__(
"",
"fp8",
"mxfp4",
"mxfp8",
]:
self.online_quant = True
online_parser = get_quant_parser("online_quant")
Expand Down
19 changes: 17 additions & 2 deletions atom/model_ops/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
)
from atom.utils import envs
from atom.utils.decorators import mark_trace
from atom.quantization.quark.utils import weight_dequant_fp8
from atom.quantization.quark.utils import weight_dequant_fp8, weight_dequant_mxfp8
from torch import nn

logger = logging.getLogger("atom")
Expand Down Expand Up @@ -414,6 +414,14 @@ def _gather_full_weight(self, weight):
"""Gather sharded weight from all TP ranks to reconstruct the full unpartitioned weight."""
if self.tp_size <= 1 or self.tp_dim is None:
return weight
# NCCL cannot all_gather E8M0 scales (MXFP8 source); gather the raw
# bytes as uint8 and reinterpret afterwards. The gather only moves
# bytes, so this is bit-exact.
if weight.dtype == dtypes.fp8_e8m0:
gathered = get_tp_group().all_gather(
weight.view(torch.uint8), dim=self.tp_dim
)
return gathered.view(dtypes.fp8_e8m0)
return get_tp_group().all_gather(weight, dim=self.tp_dim)

def _shard_quantized_weight(self, q_weight, weight_scale):
Expand Down Expand Up @@ -466,7 +474,11 @@ def online_quantize_weight(self):
f"Unsupported online quant: "
f"dtype={online_quant_dtype}, type={online_quant_type}"
)
assert self.quant_type in [QuantType.No, QuantType.per_1x128], (
assert self.quant_type in [
QuantType.No,
QuantType.per_1x128,
QuantType.per_1x32,
], (
f"Unsupported source quant_type for online quantization: "
f"{self.quant_type} (layer={self.prefix})"
)
Expand Down Expand Up @@ -504,6 +516,9 @@ def online_quantize_weight(self):
if self.quant_type == QuantType.per_1x128:
# dequant per block fp8
weight = weight_dequant_fp8(weight, weight_scale)
elif self.quant_type == QuantType.per_1x32:
# dequant MXFP8 (FP8 elements + 1x32 E8M0 shared scale)
weight = weight_dequant_mxfp8(weight, weight_scale)
q_weight, weight_scale = online_quant_func(
weight, quant_dtype=online_quant_dtype
)
Expand Down
37 changes: 25 additions & 12 deletions atom/model_ops/moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
)
from atom.plugin.vllm.moe import FusedMoEDecoratorForPluginMode
from atom.quant_spec import LayerQuantConfig, should_skip_online_quant
from atom.quantization.quark.utils import weight_dequant_fp8
from atom.quantization.quark.utils import weight_dequant_fp8, weight_dequant_mxfp8
from atom.utils import envs
from atom.utils.custom_register import direct_register_custom_op
from atom.utils.decorators import mark_trace
Expand Down Expand Up @@ -2411,11 +2411,24 @@ def _online_quant(self):
return

quant_func = get_hip_quant(online_quant_type)
assert source_quant_type in (QuantType.No, QuantType.per_1x128), (
assert source_quant_type in (
QuantType.No,
QuantType.per_1x128,
QuantType.per_1x32,
), (
f"Unsupported source quant_type for MoE online quantization: "
f"{source_quant_type} (layer={self.layer_name})"
)
need_dequant = source_quant_type == QuantType.per_1x128
need_dequant = source_quant_type in (
QuantType.per_1x128,
QuantType.per_1x32,
)

def _dequant_func(w: torch.Tensor, sc: torch.Tensor) -> torch.Tensor:
# per_1x128 -> deepseek-style square-block FP8; per_1x32 -> MXFP8.
if source_quant_type == QuantType.per_1x32:
return weight_dequant_mxfp8(w.contiguous(), sc.contiguous())
return weight_dequant_fp8(w.contiguous(), sc.contiguous())

# Determine whether each weight needs all_gather to match offline quantization.
# w13 (column parallel): (E, (2*intermediate/tp, hidden)) — TP dim 0
Expand Down Expand Up @@ -2472,13 +2485,13 @@ def check_need_allgather():
if need_dequant:
w13_scale = old_w13_scale[expert_id]
s1_size = w13_scale.shape[0] // 2
w1_bf16 = weight_dequant_fp8(
w13_local[:w1_size].contiguous(),
w13_scale[:s1_size].contiguous(),
w1_bf16 = _dequant_func(
w13_local[:w1_size],
w13_scale[:s1_size],
)
w3_bf16 = weight_dequant_fp8(
w13_local[w1_size:].contiguous(),
w13_scale[s1_size:].contiguous(),
w3_bf16 = _dequant_func(
w13_local[w1_size:],
w13_scale[s1_size:],
)
else:
w1_bf16 = w13_local[:w1_size]
Expand Down Expand Up @@ -2515,9 +2528,9 @@ def check_need_allgather():
# w2 ptpc_fp8 [e, m, n]->[e, m, n//tp]->[e, m, 1]
w2_local = old_w2_data[expert_id]
if need_dequant:
w2_local = weight_dequant_fp8(
w2_local.contiguous(),
old_w2_scale[expert_id].contiguous(),
w2_local = _dequant_func(
w2_local,
old_w2_scale[expert_id],
)
if need_gather_w2:
w2_full = tp_group.all_gather(w2_local, dim=1)
Expand Down
8 changes: 8 additions & 0 deletions atom/quant_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,14 @@ def parse(self, hf_quant_config: dict) -> ParsedQuantConfig:
QuantType.per_1x128,
):
quant_type = QuantType.per_1x32
# Mxfp8 ``[1, K]`` block to per_1x32.
weight_block_size = hf_quant_config.get("weight_block_size")
if (
isinstance(weight_block_size, (list, tuple))
and len(weight_block_size) == 2
and weight_block_size[0] == 1
):
quant_type = QuantType.per_1x32
is_dynamic = hf_quant_config.get("is_dynamic", True)
# Each quantizer uses a different key for excluded layers:
# Quark -> "exclude", compressed-tensors -> "ignore",
Expand Down
28 changes: 28 additions & 0 deletions atom/quantization/quark/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,3 +97,31 @@ def grid(meta: dict[str, int]) -> tuple[int, int]:

_weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size)
return y


# Optional E8M0 dtype: only available on newer torch builds.
_E8M0_DTYPE = getattr(torch, "float8_e8m0fnu", None)


def weight_dequant_mxfp8(
    x: torch.Tensor, s: torch.Tensor, block_size: int = 32
) -> torch.Tensor:
    """Dequantize an MXFP8 weight to the default float dtype.

    Every run of ``block_size`` consecutive elements along the last axis of
    ``x`` shares one power-of-two scale taken from ``s``.

    Args:
        x: 2-D quantized weight of shape ``(M, K)``. Its dtype is not checked
            here; it only has to be castable to ``float32``.
        s: 2-D shared scales of shape ``(M, K // block_size)``. Either a
            tensor of the torch E8M0 dtype (when this torch build has it), or
            raw E8M0 exponent codes (0..255) held in any other dtype such as
            ``uint8``.
        block_size: Number of consecutive elements along ``K`` sharing one
            scale. Must be positive and divide ``K``.

    Returns:
        A ``(M, K)`` tensor in ``torch.get_default_dtype()`` holding
        ``x * scale``, with the multiplication done in ``float32``.

    Raises:
        AssertionError: If the inputs are not 2-D, ``block_size`` is not
            positive, ``K`` is not a multiple of ``block_size``, or ``s``
            does not have shape ``(M, K // block_size)``.
    """
    assert x.dim() == 2 and s.dim() == 2, "Input tensors must have 2 dimensions"
    assert block_size > 0, f"block_size={block_size} must be positive"
    M, K = x.shape
    assert K % block_size == 0, f"K={K} not divisible by block_size={block_size}"
    n_blocks = K // block_size
    assert s.shape == (M, n_blocks), f"scale shape {tuple(s.shape)} != {(M, n_blocks)}"

    if _E8M0_DTYPE is not None and s.dtype == _E8M0_DTYPE:
        # E8M0 dtype decodes straight to the 2**(e-127) multiplier.
        scale = s.to(torch.float32)
    else:
        # Raw E8M0 integer codes stored as uint8 / float: apply the bias of
        # 127 by hand.
        codes = s.to(torch.float32)
        scale = torch.exp2(codes - 127.0)
        # Code 255 is the E8M0 NaN encoding. Without this, exp2(128)
        # overflows to +inf in float32 and this branch would disagree with
        # the native-dtype branch above for the same bits.
        scale = scale.masked_fill(codes == 255.0, float("nan"))

    out_dtype = torch.get_default_dtype()
    # Expose the block axis so each scale broadcasts over its own block.
    y = x.to(torch.float32).reshape(M, n_blocks, block_size)
    y = y * scale.unsqueeze(-1)
    return y.reshape(M, K).to(out_dtype)