Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 51 additions & 11 deletions atom/model_ops/fused_moe_triton.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
swizzle_scales as swizzle_scales_cdna4,
)
from aiter.ops.triton.moe.quant_moe import downcast_to_static_fp8
from aiter.ops.triton.quant.quant import dynamic_mxfp8_quant

from atom.model_ops.moe import MoEActivationQuant

Expand Down Expand Up @@ -305,9 +306,33 @@ def triton_kernel_fused_experts(
)
else:
# SiLU (DeepSeek): concatenated [gate | up] layout, manual activation.
# The activation precision selects the routed GEMM: MXFP4 activations
# (a4w4) when act_quant is FP4, otherwise bf16 activations (a16w4).
if act_quant == MoEActivationQuant.FP4:
# The activation precision selects the routed GEMM: FP8 activations
# (a8w4) when act_quant is FP8, MXFP4 activations (a4w4) when FP4,
# otherwise bf16 activations (a16w4).
if act_quant == MoEActivationQuant.FP8:
quant_dtype = torch.float8_e4m3fn
if get_arch() == "gfx942":
quant_dtype = torch.float8_e4m3fnuz

hidden_states, a13_mx_scale = dynamic_mxfp8_quant(
hidden_states, quant_dtype=quant_dtype
)
raw_intermediate = moe_gemm_a8w4(
hidden_states,
w1,
a13_mx_scale,
w13_scale,
None,
None,
w1_bias,
routing_data,
gather_indx=gather_indx,
gammas=gammas if apply_router_weight_on_input else None,
swizzle_mx_scale=w13_swizzle_layout,
out_dtype=torch.bfloat16,
apply_swiglu=False,
)
elif act_quant == MoEActivationQuant.FP4:
hidden_states_fp4, hidden_states_mx_scale = mxfp4_quant(hidden_states)
raw_intermediate = moe_gemm_a4w4(
hidden_states_fp4,
Expand Down Expand Up @@ -341,15 +366,30 @@ def triton_kernel_fused_experts(

raw_2d = raw_intermediate.view(M * topk, N)
intermediate_cache = intermediate_cache.view(M * topk, half_N)
fused_clamp_act_mul(
raw_2d,
out=intermediate_cache,
swiglu_limit=swiglu_limit,
activation="silu",
dtype_quant=None,
)

if act_quant == MoEActivationQuant.FP4:
if act_quant == MoEActivationQuant.FP8:
intermediate_fp8, a2_mx_scale = fused_clamp_act_mul(
raw_2d,
swiglu_limit=swiglu_limit,
activation="silu",
dtype_quant=quant_dtype,
scale_dtype_fmt="ue8m0",
quant_block_size=32,
)
output_tensor = moe_gemm_a8w4(
intermediate_fp8,
w2,
a2_mx_scale,
w2_scale,
None,
None,
w2_bias,
routing_data,
scatter_indx=scatter_indx,
gammas=None if apply_router_weight_on_input else gammas,
swizzle_mx_scale=w2_swizzle_layout,
)
elif act_quant == MoEActivationQuant.FP4:
intermediate_fp4, intermediate_mx_scale = mxfp4_quant(intermediate_cache)
output_tensor = moe_gemm_a4w4(
intermediate_fp4,
Expand Down
66 changes: 56 additions & 10 deletions atom/model_ops/moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -792,6 +792,8 @@ def __init__(self, quant_config: LayerQuantConfig, moe: FusedMoEConfig):
or (gfx.startswith("gfx95") and envs.ATOM_USE_TRITON_GEMM)
)
self.act_quant = MoEActivationQuant.from_model_config(moe.a_quant_dtype)
self.use_a8w4_prefill = envs.ATOM_USE_A8W4_MOE_PREFILL
self.use_a8w4_decode = envs.ATOM_USE_A8W4_MOE_DECODE

def create_weights(
self,
Expand Down Expand Up @@ -985,6 +987,24 @@ def process_weights_after_loading(self, layer):
layer.w2_weight_scale = w2_scale
layer.w13_swizzle_layout = w13_swizzle_layout
layer.w2_swizzle_layout = w2_swizzle_layout

if (self.use_a8w4_prefill or self.use_a8w4_decode) and self.act_quant != MoEActivationQuant.FP8:
from atom.model_ops.fused_moe_triton import (
_swizzle_scales_for_kernel,
)

w13_s = layer.w13_weight_scale.clone()
w13_s, w13_sl = _swizzle_scales_for_kernel(
w13_s, MoEActivationQuant.FP8
)
w2_s = layer.w2_weight_scale.clone()
w2_s, w2_sl = _swizzle_scales_for_kernel(
w2_s, MoEActivationQuant.FP8
)
layer.w13_weight_scale_a8w4 = w13_s
layer.w2_weight_scale_a8w4 = w2_s
layer.w13_swizzle_layout_a8w4 = w13_sl
layer.w2_swizzle_layout_a8w4 = w2_sl
return

# shuffle weight
Expand Down Expand Up @@ -1065,6 +1085,32 @@ def apply(
triton_kernel_moe_forward,
)

act_quant = self.act_quant
if self.use_a8w4_prefill or self.use_a8w4_decode:
ctx = get_forward_context()
is_prefill = ctx.context.is_prefill
if (is_prefill and self.use_a8w4_prefill) or (
not is_prefill and self.use_a8w4_decode
):
act_quant = MoEActivationQuant.FP8
else:
act_quant = MoEActivationQuant.BF16

use_a8w4_scales = (
act_quant == MoEActivationQuant.FP8
and self.act_quant != MoEActivationQuant.FP8
)
if use_a8w4_scales:
w13_scale = layer.w13_weight_scale_a8w4
w2_scale = layer.w2_weight_scale_a8w4
w13_swizzle = layer.w13_swizzle_layout_a8w4
w2_swizzle = layer.w2_swizzle_layout_a8w4
else:
w13_scale = layer.w13_weight_scale
w2_scale = layer.w2_weight_scale
w13_swizzle = layer.w13_swizzle_layout
w2_swizzle = layer.w2_swizzle_layout

# Check if the model needs custom routing that triton routing()
# does not support (grouped topk, sigmoid scoring, bias correction).
needs_custom_routing = (
Expand Down Expand Up @@ -1118,10 +1164,10 @@ def apply(
scatter_idx,
topk=n_expts_act,
activation=activation,
w13_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
w13_swizzle_layout=layer.w13_swizzle_layout,
w2_swizzle_layout=layer.w2_swizzle_layout,
w13_scale=w13_scale,
w2_scale=w2_scale,
w13_swizzle_layout=w13_swizzle,
w2_swizzle_layout=w2_swizzle,
a13_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
w1_bias=layer.w13_bias,
Expand All @@ -1130,7 +1176,7 @@ def apply(
apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=n_expts_tot,
expert_map=expert_map,
act_quant=self.act_quant,
act_quant=act_quant,
)

# Always-on shared expert(s) via a standalone dense GEMM,
Expand All @@ -1154,18 +1200,18 @@ def apply(
topk=top_k,
renormalize=renormalize,
activation=activation,
w13_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
w13_swizzle_layout=layer.w13_swizzle_layout,
w2_swizzle_layout=layer.w2_swizzle_layout,
w13_scale=w13_scale,
w2_scale=w2_scale,
w13_swizzle_layout=w13_swizzle,
w2_swizzle_layout=w2_swizzle,
a13_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
w1_bias=layer.w13_bias,
w2_bias=layer.w2_bias,
expert_map=expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=global_num_experts,
act_quant=self.act_quant,
act_quant=act_quant,
)

topk_weights, topk_ids = FusedMoE.select_experts(
Expand Down
2 changes: 2 additions & 0 deletions atom/utils/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@
os.getenv("ATOM_USE_TRITON_MLA_SHUFFLE_KV", "0") == "1"
),
"ATOM_USE_TRITON_MOE": lambda: os.getenv("ATOM_USE_TRITON_MOE", "0") == "1",
"ATOM_USE_A8W4_MOE_PREFILL": lambda: os.getenv("ATOM_USE_A8W4_MOE_PREFILL", "0") == "1",
"ATOM_USE_A8W4_MOE_DECODE": lambda: os.getenv("ATOM_USE_A8W4_MOE_DECODE", "0") == "1",
# --- Kernel Fusion Toggles ---
# fused_compress_attn: switch between Triton (default historical) and a
# flydsl drop-in for V4-Pro Compressor (Main BF16 + Indexer FP8) paths.
Expand Down
Loading