diff --git a/atom/model_ops/fused_moe_triton.py b/atom/model_ops/fused_moe_triton.py index eea8da8548..6a700622d2 100644 --- a/atom/model_ops/fused_moe_triton.py +++ b/atom/model_ops/fused_moe_triton.py @@ -40,6 +40,7 @@ swizzle_scales as swizzle_scales_cdna4, ) from aiter.ops.triton.moe.quant_moe import downcast_to_static_fp8 + from aiter.ops.triton.quant.quant import dynamic_mxfp8_quant from atom.model_ops.moe import MoEActivationQuant @@ -305,9 +306,33 @@ def triton_kernel_fused_experts( ) else: # SiLU (DeepSeek): concatenated [gate | up] layout, manual activation. - # The activation precision selects the routed GEMM: MXFP4 activations - # (a4w4) when act_quant is FP4, otherwise bf16 activations (a16w4). - if act_quant == MoEActivationQuant.FP4: + # The activation precision selects the routed GEMM: FP8 activations + # (a8w4) when act_quant is FP8, MXFP4 activations (a4w4) when FP4, + # otherwise bf16 activations (a16w4). + if act_quant == MoEActivationQuant.FP8: + quant_dtype = torch.float8_e4m3fn + if get_arch() == "gfx942": + quant_dtype = torch.float8_e4m3fnuz + + hidden_states, a13_mx_scale = dynamic_mxfp8_quant( + hidden_states, quant_dtype=quant_dtype + ) + raw_intermediate = moe_gemm_a8w4( + hidden_states, + w1, + a13_mx_scale, + w13_scale, + None, + None, + w1_bias, + routing_data, + gather_indx=gather_indx, + gammas=gammas if apply_router_weight_on_input else None, + swizzle_mx_scale=w13_swizzle_layout, + out_dtype=torch.bfloat16, + apply_swiglu=False, + ) + elif act_quant == MoEActivationQuant.FP4: hidden_states_fp4, hidden_states_mx_scale = mxfp4_quant(hidden_states) raw_intermediate = moe_gemm_a4w4( hidden_states_fp4, @@ -341,15 +366,30 @@ def triton_kernel_fused_experts( raw_2d = raw_intermediate.view(M * topk, N) intermediate_cache = intermediate_cache.view(M * topk, half_N) - fused_clamp_act_mul( - raw_2d, - out=intermediate_cache, - swiglu_limit=swiglu_limit, - activation="silu", - dtype_quant=None, - ) - if act_quant == MoEActivationQuant.FP4: + if act_quant == MoEActivationQuant.FP8: + intermediate_fp8, a2_mx_scale = fused_clamp_act_mul( + raw_2d, + swiglu_limit=swiglu_limit, + activation="silu", + dtype_quant=quant_dtype, + scale_dtype_fmt="ue8m0", + quant_block_size=32, + ) + output_tensor = moe_gemm_a8w4( + intermediate_fp8, + w2, + a2_mx_scale, + w2_scale, + None, + None, + w2_bias, + routing_data, + scatter_indx=scatter_indx, + gammas=None if apply_router_weight_on_input else gammas, + swizzle_mx_scale=w2_swizzle_layout, + ) + elif act_quant == MoEActivationQuant.FP4: intermediate_fp4, intermediate_mx_scale = mxfp4_quant(intermediate_cache) output_tensor = moe_gemm_a4w4( intermediate_fp4, diff --git a/atom/model_ops/moe.py b/atom/model_ops/moe.py index ad45981e36..43dcba4394 100644 --- a/atom/model_ops/moe.py +++ b/atom/model_ops/moe.py @@ -792,6 +792,8 @@ def __init__(self, quant_config: LayerQuantConfig, moe: FusedMoEConfig): or (gfx.startswith("gfx95") and envs.ATOM_USE_TRITON_GEMM) ) self.act_quant = MoEActivationQuant.from_model_config(moe.a_quant_dtype) + self.use_a8w4_prefill = envs.ATOM_USE_A8W4_MOE_PREFILL + self.use_a8w4_decode = envs.ATOM_USE_A8W4_MOE_DECODE def create_weights( self, @@ -985,6 +987,24 @@ def process_weights_after_loading(self, layer): layer.w2_weight_scale = w2_scale layer.w13_swizzle_layout = w13_swizzle_layout layer.w2_swizzle_layout = w2_swizzle_layout + + if (self.use_a8w4_prefill or self.use_a8w4_decode) and self.act_quant != MoEActivationQuant.FP8: + from atom.model_ops.fused_moe_triton import ( + _swizzle_scales_for_kernel, + ) + + w13_s = layer.w13_weight_scale.clone() + w13_s, w13_sl = _swizzle_scales_for_kernel( + w13_s, MoEActivationQuant.FP8 + ) + w2_s = layer.w2_weight_scale.clone() + w2_s, w2_sl = _swizzle_scales_for_kernel( + w2_s, MoEActivationQuant.FP8 + ) + layer.w13_weight_scale_a8w4 = w13_s + layer.w2_weight_scale_a8w4 = w2_s + layer.w13_swizzle_layout_a8w4 = w13_sl + layer.w2_swizzle_layout_a8w4 = w2_sl return # shuffle weight @@ -1065,6 +1085,32 @@ def apply( triton_kernel_moe_forward, ) + act_quant = self.act_quant + if self.use_a8w4_prefill or self.use_a8w4_decode: + ctx = get_forward_context() + is_prefill = ctx.context.is_prefill + if (is_prefill and self.use_a8w4_prefill) or ( + not is_prefill and self.use_a8w4_decode + ): + act_quant = MoEActivationQuant.FP8 + else: + act_quant = MoEActivationQuant.BF16 + + use_a8w4_scales = ( + act_quant == MoEActivationQuant.FP8 + and self.act_quant != MoEActivationQuant.FP8 + ) + if use_a8w4_scales: + w13_scale = layer.w13_weight_scale_a8w4 + w2_scale = layer.w2_weight_scale_a8w4 + w13_swizzle = layer.w13_swizzle_layout_a8w4 + w2_swizzle = layer.w2_swizzle_layout_a8w4 + else: + w13_scale = layer.w13_weight_scale + w2_scale = layer.w2_weight_scale + w13_swizzle = layer.w13_swizzle_layout + w2_swizzle = layer.w2_swizzle_layout + # Check if the model needs custom routing that triton routing() # does not support (grouped topk, sigmoid scoring, bias correction). needs_custom_routing = ( @@ -1118,10 +1164,10 @@ def apply( scatter_idx, topk=n_expts_act, activation=activation, - w13_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, - w13_swizzle_layout=layer.w13_swizzle_layout, - w2_swizzle_layout=layer.w2_swizzle_layout, + w13_scale=w13_scale, + w2_scale=w2_scale, + w13_swizzle_layout=w13_swizzle, + w2_swizzle_layout=w2_swizzle, a13_scale=layer.w13_input_scale, a2_scale=layer.w2_input_scale, w1_bias=layer.w13_bias, @@ -1130,7 +1176,7 @@ def apply( apply_router_weight_on_input=layer.apply_router_weight_on_input, global_num_experts=n_expts_tot, expert_map=expert_map, - act_quant=self.act_quant, + act_quant=act_quant, ) # Always-on shared expert(s) via a standalone dense GEMM, @@ -1154,10 +1200,10 @@ def apply( topk=top_k, renormalize=renormalize, activation=activation, - w13_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, - w13_swizzle_layout=layer.w13_swizzle_layout, - w2_swizzle_layout=layer.w2_swizzle_layout, + w13_scale=w13_scale, + w2_scale=w2_scale, + w13_swizzle_layout=w13_swizzle, + w2_swizzle_layout=w2_swizzle, a13_scale=layer.w13_input_scale, a2_scale=layer.w2_input_scale, w1_bias=layer.w13_bias, @@ -1165,7 +1211,7 @@ def apply( expert_map=expert_map, apply_router_weight_on_input=layer.apply_router_weight_on_input, global_num_experts=global_num_experts, - act_quant=self.act_quant, + act_quant=act_quant, ) topk_weights, topk_ids = FusedMoE.select_experts( diff --git a/atom/utils/envs.py b/atom/utils/envs.py index 87b00cba74..10578847e7 100644 --- a/atom/utils/envs.py +++ b/atom/utils/envs.py @@ -42,6 +42,8 @@ os.getenv("ATOM_USE_TRITON_MLA_SHUFFLE_KV", "0") == "1" ), "ATOM_USE_TRITON_MOE": lambda: os.getenv("ATOM_USE_TRITON_MOE", "0") == "1", + "ATOM_USE_A8W4_MOE_PREFILL": lambda: os.getenv("ATOM_USE_A8W4_MOE_PREFILL", "0") == "1", + "ATOM_USE_A8W4_MOE_DECODE": lambda: os.getenv("ATOM_USE_A8W4_MOE_DECODE", "0") == "1", # --- Kernel Fusion Toggles --- # fused_compress_attn: switch between Triton (default historical) and a # flydsl drop-in for V4-Pro Compressor (Main BF16 + Indexer FP8) paths.