diff --git a/megatron/core/distributed/distributed_data_parallel.py b/megatron/core/distributed/distributed_data_parallel.py index e313113a448..2e29b17f6ff 100644 --- a/megatron/core/distributed/distributed_data_parallel.py +++ b/megatron/core/distributed/distributed_data_parallel.py @@ -408,7 +408,10 @@ def disable_forward_pre_hook(self, param_sync: bool = True): # Force synchronize parameters. if param_sync: - self.start_param_sync(force_sync=True) + # Hook-disable paths (eval/checkpointing/shutdown) synchronize params as an + # explicit state update, not as differentiable forward compute. + with torch.no_grad(): + self.start_param_sync(force_sync=True) def _make_forward_pre_hook(self): """ @@ -529,6 +532,11 @@ def start_param_sync(self, *unused, force_sync: bool = False, force_dispatch: bo for bucket_group in self.bucket_groups + self.expert_parallel_bucket_groups: self._start_bucket_group_param_sync(bucket_group, force_sync=force_sync) + def reset_param_sync_dispatch_state(self): + """Mark DDP param all-gathers as not dispatched for the next forward pre-hook.""" + for bucket_group in self.bucket_groups + self.expert_parallel_bucket_groups: + bucket_group.param_gather_dispatched = False + def start_grad_sync(self, *unused): """ Initiates grad sync (all-reduce or reduce-scatter) communication operations diff --git a/megatron/core/distributed/param_and_grad_buffer.py b/megatron/core/distributed/param_and_grad_buffer.py index 9051fb9f47e..dcb687a3648 100644 --- a/megatron/core/distributed/param_and_grad_buffer.py +++ b/megatron/core/distributed/param_and_grad_buffer.py @@ -20,10 +20,21 @@ from megatron.core.rerun_state_machine import get_rerun_state_machine from megatron.core.utils import log_single_rank -from ..fp4_utils import get_nvfp4_rowwise_packed_shape, is_nvfp4tensor +from ..fp4_utils import ( + get_nvfp4_rowwise_packed_shape, + is_grouped_nvfp4tensor, + is_nvfp4tensor, + modify_grouped_nvfp4_rowwise_storage, + modify_nvfp4_rowwise_storage, +) from ..fp8_utils import ( + copy_tensor_to_quantized_param, is_float8tensor, + is_grouped_mxfp8tensor, + is_grouped_tensor, + is_grouped_tensor_with_quantized_storage, is_mxfp8tensor, + modify_grouped_tensor_rowwise_storage, modify_underlying_storage, post_all_gather_processing, ) @@ -67,6 +78,17 @@ def shard_buffer(buffer: torch.Tensor, data_parallel_world_size: int): return sharded_buffer +def _param_uses_quantized_storage(param: torch.nn.Parameter) -> bool: + """Return whether the parameter owns TE quantized storage instead of plain tensor storage.""" + # In TE2, is_float8tensor() checks QuantizedTensor, so it includes plain MXFP8Tensor. + # Grouped quantized params are GroupedTensor wrappers, so check their backing storage. + return ( + is_float8tensor(param) + or is_nvfp4tensor(param) + or is_grouped_tensor_with_quantized_storage(param) + ) + + class _ParamAndGradBucket: """ Bucket to keep track of a subset of the model's parameters and gradients. @@ -280,17 +302,18 @@ def _post_param_sync(self): """Run post-processing after param all-gather completes.""" if self.ddp_config.reuse_grad_buf_for_mxfp8_param_ag: for bucket in self.buckets: - is_bf16_weight_bucket = False + has_non_quantized_weight = False for param in bucket.params: - # Skip copying since bf16 weights in the mxfp8 model - # are already mapped to param.data. - if not is_float8tensor(param): - is_bf16_weight_bucket = True + # Non-quantized weights are already mapped to param.data. Skip + # mixed buckets because zeroing bucket.param_data would also + # clear those model weights. + if not _param_uses_quantized_storage(param): + has_non_quantized_weight = True break param_start, param_end = bucket.param_to_index[param] param_slice = bucket.param_data.view(-1)[param_start:param_end] - param.data.copy_(param_slice.view(param.data.shape)) - if is_bf16_weight_bucket: + copy_tensor_to_quantized_param(param, param_slice) + if has_non_quantized_weight: continue # All-gathered params are not needed after being copied to param.data. # Zero out the param buffer (shared with grad buffer) for gradient accumulation. @@ -303,7 +326,7 @@ def _post_param_sync(self): quantized_params = [] for bucket in self.buckets: for param in bucket.params: - if is_float8tensor(param) or is_nvfp4tensor(param): + if _param_uses_quantized_storage(param): quantized_params.append(param) if len(quantized_params) > 0: post_all_gather_processing(quantized_params) @@ -852,7 +875,7 @@ def group_params_for_buffers( assert param.requires_grad param_dtype = param.dtype - if is_float8tensor(param) or is_nvfp4tensor(param): + if _param_uses_quantized_storage(param): param_dtype = torch.uint8 grad_dtype = torch.float if grad_reduce_in_fp32 else param.dtype is_expert_parallel = not getattr(param, 'allreduce', True) @@ -1034,7 +1057,9 @@ def __init__( # The packed index map is derived from param_index_map by iterating through # the already-computed layout and halving numel for NVFP4 tensors. # - self.has_nvfp4_params = any(is_nvfp4tensor(p) for p in self.params) + self.has_nvfp4_params = any( + is_nvfp4tensor(p) or is_grouped_nvfp4tensor(p) for p in self.params + ) self.nvfp4_packed_param_index_map = None self.nvfp4_packed_bucket_indices = None if self.has_nvfp4_params: @@ -1091,7 +1116,7 @@ def __init__( # The buffer is mapped to weight gradients whose dtype is either bf16 or FP32. # It can be temporarily reused by param AG. if self.ddp_config.use_distributed_optimizer and any( - is_mxfp8tensor(p) for p in self.params + is_mxfp8tensor(p) or is_grouped_mxfp8tensor(p) for p in self.params ): self.shared_buffer = torch.zeros( self.numel, @@ -1169,50 +1194,129 @@ def _create_bucket(bucket_id, bucket_params, bucket_params_with_extra_main_grads nvfp4_packed_param_start_index = None if self.has_nvfp4_params: nvfp4_packed_param_start_index, _, _ = self.nvfp4_packed_param_index_map[param] - # For MXFP8 param: - # we only need to map bf16 weights (layernorm, embedding, etc) to the buffer. - if not self.ddp_config.reuse_grad_buf_for_mxfp8_param_ag or not is_mxfp8tensor(param): + # This branch remaps the parameter storage into persistent DDP param_data buffer. + # + # Enter when: + # - `reuse_grad_buf_for_mxfp8_param_ag` is off: param AG has a persistent + # param_data buffer instead of sharing storage with grad_data, + # so every parameter must be backed by param_data. + # - param is not quantized: BF16/FP16/plain params still need persistent + # param_data even when quantized params use grad_data as temporary AG storage. + # + # Skip only when both are true: AG reuses grad_data and the param is quantized. + # In that case AG writes into grad_data, then _post_param_sync copies the + # gathered values back into TE quantized storage. + # + # Remap cases below: + # non-grouped TE NVFP4 tensor -> remap packed rowwise bytes + # non-grouped TE quantized -> remap TE quantized storage + # regular torch.Tensor param -> replace param.data with param_data view + # TE GroupedTensor + NVFP4 -> remap packed rowwise bytes + # TE GroupedTensor + MXFP8 -> unsupported here; require grad-buffer AG reuse + # TE GroupedTensor + BF16/FP16 -> remap grouped rowwise_data + if ( + not self.ddp_config.reuse_grad_buf_for_mxfp8_param_ag + or not _param_uses_quantized_storage(param) + ): if self.param_data is not None: - if is_nvfp4tensor(param): - # Remap the NVFP4 tensor's internal rowwise uint8 storage so it - # points into the contiguous DDP param buffer. This enables the - # all-gather to communicate packed NVFP4 bytes directly. - from ..fp4_utils import modify_nvfp4_rowwise_storage - - packed_shape = get_nvfp4_rowwise_packed_shape(param.data.shape) - rowwise_bytes_view = self._get( - packed_shape, - nvfp4_packed_param_start_index, - buffer_type=BufferType.PARAM, - ) - modify_nvfp4_rowwise_storage(param, rowwise_bytes_view) - elif is_float8tensor(param): - new_param_data = self._get( - param.data.shape, - ( - nvfp4_packed_param_start_index - if self.has_nvfp4_params - else param_start_index - ), - buffer_type=BufferType.PARAM, - ) - modify_underlying_storage(param, new_param_data) + if not is_grouped_tensor(param): + # Plain NVFP4: remap packed rowwise bytes only. + if is_nvfp4tensor(param): + packed_shape = get_nvfp4_rowwise_packed_shape(param.data.shape) + rowwise_bytes_view = self._get( + packed_shape, + nvfp4_packed_param_start_index, + buffer_type=BufferType.PARAM, + ) + modify_nvfp4_rowwise_storage(param, rowwise_bytes_view) + # In TE2, is_float8tensor() checks QuantizedTensor, including MXFP8. + # NVFP4 is handled by the branch above. + elif is_float8tensor(param): + # NVFP4 packs two FP4 values per byte, so param_data uses + # packed-byte offsets instead of logical element offsets. + new_param_data = self._get( + param.data.shape, + ( + nvfp4_packed_param_start_index + if self.has_nvfp4_params + else param_start_index + ), + buffer_type=BufferType.PARAM, + ) + modify_underlying_storage(param, new_param_data) + # Plain torch param: replace param.data with DDP buffer view. + else: + # NVFP4 packs two FP4 values per byte, so param_data uses + # packed-byte offsets instead of logical element offsets. + new_param_data = self._get( + param.data.shape, + ( + nvfp4_packed_param_start_index + if self.has_nvfp4_params + else param_start_index + ), + buffer_type=BufferType.PARAM, + ) + old_param_data = param.data + param.data = new_param_data + assert old_param_data._base is None + # Copy tensor values (from initialization or checkpoint). + param.data.detach().copy_(old_param_data) + del old_param_data else: - new_param_data = self._get( - param.data.shape, - ( - nvfp4_packed_param_start_index - if self.has_nvfp4_params - else param_start_index - ), - buffer_type=BufferType.PARAM, - ) - old_param_data = param.data - param.data = new_param_data - assert old_param_data._base is None - # Copy tensor values (from initialization or checkpoint). - param.data.detach().copy_(old_param_data) - del old_param_data + # GroupedTensor: preserve wrapper/metadata; remap backing storage only. + # Grouped NVFP4: only rowwise bytes live in DDP param_data. + if is_grouped_nvfp4tensor(param): + packed_shape = get_nvfp4_rowwise_packed_shape(param.data.shape) + rowwise_bytes_view = self._get( + packed_shape, + nvfp4_packed_param_start_index, + buffer_type=BufferType.PARAM, + ) + modify_grouped_nvfp4_rowwise_storage(param, rowwise_bytes_view) + rowwise_data = getattr(param, "rowwise_data", None) + if ( + rowwise_data is None + or rowwise_data.data_ptr() != rowwise_bytes_view.view(-1).data_ptr() + ): + raise RuntimeError( + "Failed to remap grouped NVFP4 rowwise storage into DDP " + "param_data." + ) + # Grouped MXFP8: do not remap grouped quantized storage into param_data. + # Use grad-buffer AG reuse and copy gathered values back after AG. + elif is_grouped_mxfp8tensor(param): + raise RuntimeError( + "Single grouped MXFP8 params require " + "--reuse-grad-buf-for-mxfp8-param-ag." + ) + elif is_grouped_tensor_with_quantized_storage(param): + raise RuntimeError( + "Unsupported single grouped quantized parameter recipe." + ) + # Grouped BF16/FP16: remap full rowwise_data. + else: + # NVFP4 packs two FP4 values per byte, so param_data uses + # packed-byte offsets instead of logical element offsets. + new_param_data = self._get( + param.data.shape, + ( + nvfp4_packed_param_start_index + if self.has_nvfp4_params + else param_start_index + ), + buffer_type=BufferType.PARAM, + ) + modify_grouped_tensor_rowwise_storage(param, new_param_data) + rowwise_data = getattr(param, "rowwise_data", None) + if ( + rowwise_data is None + or rowwise_data.data_ptr() != new_param_data.view(-1).data_ptr() + ): + raise RuntimeError( + "Failed to remap high-precision TE GroupedTensor parameter " + "storage into DDP param_data." + ) # Grad buffer always uses full-numel offsets from param_index_map. param.main_grad = self._get( @@ -1345,7 +1449,7 @@ def _pad_end_of_bucket(bucket_end_index: int) -> int: cur_bucket_id = bucket_id # NVFP4 tensors use half the numel in the packed param buffer. - if is_nvfp4tensor(param): + if is_nvfp4tensor(param) or is_grouped_nvfp4tensor(param): assert ( param_numel % 2 == 0 ), f"NVFP4 requires even numel for packing, got {param_numel}" diff --git a/megatron/core/fp4_utils.py b/megatron/core/fp4_utils.py index 45e57285a8d..a31ba7630c8 100644 --- a/megatron/core/fp4_utils.py +++ b/megatron/core/fp4_utils.py @@ -7,7 +7,11 @@ import torch from megatron.core.enums import Fp4Recipe -from megatron.core.fp8_utils import _get_custom_recipe +from megatron.core.fp8_utils import ( + _get_custom_recipe, + _get_grouped_quantized_recipe, + _unwrap_parameter_data, +) from megatron.core.transformer.transformer_config import TransformerConfig from megatron.core.utils import is_te_min_version @@ -55,6 +59,14 @@ def is_nvfp4tensor(tensor: torch.Tensor) -> bool: return HAVE_TE_FP4_TENSOR_CLASS and isinstance(tensor, FP4_TENSOR_CLASS) +def is_grouped_nvfp4tensor(tensor: torch.Tensor) -> bool: + """Check if a TE GroupedTensor stores NVFP4 member tensors.""" + if not HAVE_TE_FP4_TENSOR_CLASS: + return False + recipe = _get_grouped_quantized_recipe(tensor) + return recipe is not None and hasattr(recipe, "nvfp4") and recipe.nvfp4() + + def get_nvfp4_rowwise_packed_shape(shape: torch.Size) -> torch.Size: """Return packed byte shape for NVFP4 rowwise storage (last dim // 2).""" if len(shape) == 0: @@ -85,6 +97,41 @@ def modify_nvfp4_rowwise_storage(fp4_tensor: torch.Tensor, new_rowwise_data: tor del old_rowwise +def modify_grouped_nvfp4_rowwise_storage( + grouped_tensor: torch.Tensor, new_rowwise_data: torch.Tensor +) -> None: + """Replace grouped NVFP4 rowwise data with a new uint8 storage view. + + The name intentionally mirrors `modify_nvfp4_rowwise_storage`: only the + packed rowwise byte buffer is remapped into the DDP buffer. The grouped + scale, amax, and columnwise buffers remain owned by the original tensor. + """ + tensor = _unwrap_parameter_data(grouped_tensor) + if not is_grouped_nvfp4tensor(tensor): + raise ValueError("modify_grouped_nvfp4_rowwise_storage expects grouped NVFP4 storage") + + old_rowwise = getattr(tensor, "rowwise_data", None) + if old_rowwise is None: + raise RuntimeError("Grouped NVFP4 tensor is missing rowwise data to replace") + + new_rowwise_data = new_rowwise_data.view(-1) + if old_rowwise.numel() != new_rowwise_data.numel(): + raise ValueError( + "Grouped NVFP4 rowwise storage size mismatch: " + f"old numel={old_rowwise.numel()}, new numel={new_rowwise_data.numel()}" + ) + assert ( + old_rowwise.dtype == new_rowwise_data.dtype == torch.uint8 + ), "Grouped NVFP4 rowwise storage must be uint8" + + new_rowwise_data.detach().copy_(old_rowwise.view(-1)) + tensor.rowwise_data = new_rowwise_data + # Member views capture data pointers. Refresh them after swapping rowwise storage while + # preserving the existing scale/amax/columnwise grouped buffers. + tensor.quantized_tensors = tensor.split_into_quantized_tensors() + del old_rowwise + + def quantize_nvfp4_param_shard( model_params, main_params, start_offsets, data_parallel_group, fsdp_shard_model_params=None ): diff --git a/megatron/core/fp8_utils.py b/megatron/core/fp8_utils.py index c9335e3b9f8..5411b676d83 100644 --- a/megatron/core/fp8_utils.py +++ b/megatron/core/fp8_utils.py @@ -62,6 +62,14 @@ # MXFP8Tensor not found HAVE_TE_MXFP8TENSOR = False +try: + from transformer_engine.pytorch.tensor.grouped_tensor import GroupedTensor + + HAVE_TE_GROUPED_TENSOR_CLASS = True +except (ImportError, ModuleNotFoundError): + GroupedTensor = None + HAVE_TE_GROUPED_TENSOR_CLASS = False + if HAVE_TE: from megatron.core.extensions.transformer_engine import ( TEColumnParallelLinear, @@ -93,6 +101,24 @@ te_post_all_gather_processing = None +def _unwrap_parameter_data(tensor: torch.Tensor) -> torch.Tensor: + """Return underlying tensor data when PyTorch wraps a tensor subclass as a Parameter.""" + if HAVE_TE_GROUPED_TENSOR_CLASS and isinstance(tensor, GroupedTensor): + # TE GroupedTensor stores its real payload in Python-side metadata fields + # such as rowwise_data/scale_inv. PyTorch marks tensor-subclass parameters + # as Parameters, so tensor.data would create a detached wrapper copy. Return + # the live wrapper so storage metadata mutations update the module parameter. + return tensor + return tensor.data if isinstance(tensor, torch.nn.Parameter) else tensor + + +def _is_instance_or_param_data(tensor: torch.Tensor, tensor_class: type) -> bool: + """Check a tensor subclass, including when wrapped by torch.nn.Parameter.""" + return isinstance(tensor, tensor_class) or isinstance( + _unwrap_parameter_data(tensor), tensor_class + ) + + def is_float8tensor(tensor: torch.Tensor) -> bool: """Check if a tensor is a Transformer Engine Float8Tensor. @@ -102,12 +128,136 @@ def is_float8tensor(tensor: torch.Tensor) -> bool: are both inherited from QuantizedTensor. So, for TE1.x, FP8_TENSOR_CLASS is Float8Tensor, and for TE2.x, FP8_TENSOR_CLASS is QuantizedTensor. """ - return HAVE_TE_FP8_TENSOR_CLASS and isinstance(tensor, FP8_TENSOR_CLASS) + return HAVE_TE_FP8_TENSOR_CLASS and _is_instance_or_param_data(tensor, FP8_TENSOR_CLASS) def is_mxfp8tensor(tensor: torch.Tensor) -> bool: """Check if a tensor is a Transformer Engine MXFP8Tensor""" - return HAVE_TE_MXFP8TENSOR and isinstance(tensor, MXFP8Tensor) + return HAVE_TE_MXFP8TENSOR and _is_instance_or_param_data(tensor, MXFP8Tensor) + + +def is_grouped_tensor(tensor: torch.Tensor) -> bool: + """Check if a tensor is a Transformer Engine GroupedTensor.""" + return HAVE_TE_GROUPED_TENSOR_CLASS and _is_instance_or_param_data(tensor, GroupedTensor) + + +def is_grouped_tensor_with_quantized_storage(tensor: torch.Tensor) -> bool: + """Check if a Transformer Engine GroupedTensor owns quantized primary storage.""" + tensor = _unwrap_parameter_data(tensor) + if not is_grouped_tensor(tensor): + return False + rowwise_data = getattr(tensor, "rowwise_data", None) + return rowwise_data is not None and rowwise_data.dtype == torch.uint8 + + +def _get_grouped_quantized_recipe(tensor: torch.Tensor): + """Return TE recipe for grouped quantized storage, or None if unavailable.""" + tensor = _unwrap_parameter_data(tensor) + if not is_grouped_tensor_with_quantized_storage(tensor): + return None + + quantizer = getattr(tensor, "quantizer", None) + if quantizer is None or not hasattr(quantizer, "_get_compatible_recipe"): + return None + return quantizer._get_compatible_recipe() + + +def is_grouped_mxfp8tensor(tensor: torch.Tensor) -> bool: + """Check if a TE GroupedTensor stores MXFP8 member tensors.""" + if not HAVE_TE_MXFP8TENSOR: + return False + recipe = _get_grouped_quantized_recipe(tensor) + return recipe is not None and hasattr(recipe, "mxfp8") and recipe.mxfp8() + + +def get_grouped_quantized_members( + tensor: torch.Tensor, *, create_if_missing: bool = False +) -> List[torch.Tensor]: + """Return cached per-member views for a grouped quantized tensor.""" + grouped_tensor = _unwrap_parameter_data(tensor) + if not is_grouped_tensor_with_quantized_storage(grouped_tensor): + raise ValueError("get_grouped_quantized_members expects grouped quantized storage.") + + quantized_members = getattr(grouped_tensor, "quantized_tensors", None) + if quantized_members is None: + if not create_if_missing: + raise RuntimeError( + "Grouped quantized parameter is missing cached member tensors. " + "Create them outside the training critical path." + ) + quantized_members = grouped_tensor.split_into_quantized_tensors() + grouped_tensor.quantized_tensors = quantized_members + return quantized_members + + +def copy_tensor_to_quantized_param(param: torch.Tensor, src: torch.Tensor) -> None: + """Copy high-precision values into TE quantized parameter storage.""" + dst = _unwrap_parameter_data(param) + + if is_grouped_tensor_with_quantized_storage(dst): + if src.numel() != dst.numel(): + raise ValueError( + "Grouped quantized parameter copy size mismatch: " + f"src numel={src.numel()}, dst numel={dst.numel()}" + ) + if not dst.all_same_shape(): + raise NotImplementedError( + "Copying into grouped quantized parameters requires uniform member shapes." + ) + + # Grouped quantized tensors cannot use GroupedTensor.copy_ here because + # the generic grouped path can rebuild member tensors through + # split_into_quantized_tensors(), which is not graph safe. Update cached + # member tensors in place instead. + quantized_members = get_grouped_quantized_members(dst) + src_members = src.view(dst.shape).unbind(dim=0) + if len(src_members) != len(quantized_members): + raise RuntimeError( + "Grouped quantized parameter member count mismatch: " + f"src members={len(src_members)}, dst members={len(quantized_members)}" + ) + + for src_member, dst_member in zip(src_members, quantized_members): + dst.quantizer.update_quantized(src_member, dst_member) + return + + # Plain TE quantized tensors override copy_ to requantize into their + # backing storage. + dst.copy_(src.view(dst.shape)) + + +def modify_grouped_tensor_rowwise_storage(tensor: torch.Tensor, new_storage: torch.Tensor) -> None: + """Replace a high-precision Transformer Engine GroupedTensor's rowwise storage.""" + tensor = _unwrap_parameter_data(tensor) + if not is_grouped_tensor(tensor): + raise ValueError("modify_grouped_tensor_rowwise_storage expects a GroupedTensor.") + if is_grouped_tensor_with_quantized_storage(tensor): + raise ValueError( + "modify_grouped_tensor_rowwise_storage only supports high-precision GroupedTensor " + "storage. Quantized grouped storage also owns scale buffers." + ) + + old_rowwise_data = getattr(tensor, "rowwise_data", None) + if old_rowwise_data is None: + raise RuntimeError("GroupedTensor is missing rowwise_data.") + + new_storage = new_storage.view(-1) + if old_rowwise_data.numel() != new_storage.numel(): + raise ValueError( + "GroupedTensor backing storage size mismatch: " + f"old numel={old_rowwise_data.numel()}, new numel={new_storage.numel()}" + ) + if old_rowwise_data.dtype != new_storage.dtype: + raise ValueError( + "GroupedTensor backing storage dtype mismatch: " + f"old dtype={old_rowwise_data.dtype}, new dtype={new_storage.dtype}" + ) + + new_storage.detach().copy_(old_rowwise_data) + tensor.rowwise_data = new_storage + tensor.columnwise_data = None + tensor.quantized_tensors = None + del old_rowwise_data def dequantize_fp8_tensor(fp8_tensor: torch.Tensor) -> torch.Tensor: @@ -502,8 +652,18 @@ def post_all_gather_processing(model_params): - tensorwise: may need to create a transposed view to match backend GEMM. - blockwise: create column-wise storage. """ + if not isinstance(model_params, list): + model_params = [model_params] + + expanded_model_params = [] + for param in model_params: + if is_grouped_tensor_with_quantized_storage(param): + expanded_model_params.extend(get_grouped_quantized_members(param)) + else: + expanded_model_params.append(param) + if te_post_all_gather_processing is not None: - te_post_all_gather_processing(model_params) + te_post_all_gather_processing(expanded_model_params) else: # If the TE version is old and does not have post_all_gather_processing function, this is # a no-op, and the transpose/columnwise data will be created in the next forward pass. diff --git a/megatron/core/optimizer/distrib_optimizer.py b/megatron/core/optimizer/distrib_optimizer.py index 9e030a6b17f..456bcefd48a 100644 --- a/megatron/core/optimizer/distrib_optimizer.py +++ b/megatron/core/optimizer/distrib_optimizer.py @@ -52,8 +52,14 @@ group_params_for_buffers, partition_buckets, ) -from ..fp4_utils import is_nvfp4tensor, quantize_nvfp4_param_shard -from ..fp8_utils import dequantize_fp8_tensor, is_float8tensor, quantize_param_shard +from ..fp4_utils import is_grouped_nvfp4tensor, is_nvfp4tensor, quantize_nvfp4_param_shard +from ..fp8_utils import ( + dequantize_fp8_tensor, + get_grouped_quantized_members, + is_float8tensor, + is_grouped_tensor_with_quantized_storage, + quantize_param_shard, +) from ..transformer.fsdp_dtensor_checkpoint import handle_experts_in_state_dict from ..transformer.module import MegatronModule from .grad_scaler import MegatronGradScaler @@ -1071,19 +1077,31 @@ def _get_main_param_and_optimizer_states(self, model_param): @staticmethod def _is_grouped_quantized_tensor(tensor: torch.Tensor) -> bool: """Check if tensor is a TE GroupedTensor using quantized storage.""" - return ( - hasattr(tensor, "split_into_quantized_tensors") - and callable(tensor.split_into_quantized_tensors) - and getattr(tensor, "quantizer", None) is not None - ) + return is_grouped_tensor_with_quantized_storage(tensor) @classmethod def _is_distopt_quantized_param(cls, tensor: torch.Tensor) -> bool: """Check if tensor should follow quantized parameter path in dist optimizer.""" return is_float8tensor(tensor) or cls._is_grouped_quantized_tensor(tensor) + @classmethod + def _get_grouped_quantized_members(cls, tensor: torch.Tensor) -> List[torch.Tensor]: + """Return cached member tensors from a grouped quantized parameter.""" + return get_grouped_quantized_members(tensor, create_if_missing=True) + + @classmethod + def _is_grouped_nvfp4_param(cls, tensor: torch.Tensor) -> bool: + """Check if a grouped quantized parameter stores NVFP4 member tensors.""" + return is_grouped_nvfp4tensor(tensor) + + @classmethod + def _is_fp8_param_for_param_gather(cls, tensor: torch.Tensor) -> bool: + """Check if a quantized param should use the FP8/MXFP8 param-gather cast path.""" + return cls._is_distopt_quantized_param(tensor) and not cls._is_grouped_nvfp4_param(tensor) + + @classmethod def _expand_quantized_param_shard_for_cast( - self, + cls, model_param: torch.Tensor, shard_main_param: Optional[torch.Tensor], start_offset: Optional[int], @@ -1094,12 +1112,10 @@ def _expand_quantized_param_shard_for_cast( master slice to per-member offset ranges, while preserving deterministic ordering across DP ranks. """ - if not self._is_grouped_quantized_tensor(model_param): + if not cls._is_grouped_quantized_tensor(model_param): return [model_param], [shard_main_param], [start_offset] - quantized_members = model_param.quantized_tensors - if quantized_members is None: - quantized_members = model_param.split_into_quantized_tensors() + quantized_members = cls._get_grouped_quantized_members(model_param) shard_start = 0 if start_offset is None else start_offset shard_size = 0 if shard_main_param is None else shard_main_param.numel() @@ -2592,7 +2608,7 @@ def _get_fp8_params_and_shard_fp32_from_fp8(self): idx = 0 for buffer in buffers: for param in buffer.params: - if self._is_distopt_quantized_param(param): + if self._is_fp8_param_for_param_gather(param): fp8_params.append(param) shard_fp32_from_fp8.append(None) shard_offsets_in_fp8.append(None) @@ -2607,7 +2623,7 @@ def get_shard_fp32_from_fp8(shard_main_groups, model_groups): """ for shard_main_group, model_group in zip(shard_main_groups, model_groups): for shard_main_param, model_param in zip(shard_main_group, model_group): - if self._is_distopt_quantized_param(model_param): + if self._is_fp8_param_for_param_gather(model_param): param_range_map = self._get_model_param_range_map(model_param) param_range = param_range_map["param"] assert param_range.size == shard_main_param.nelement() @@ -2642,6 +2658,13 @@ def _get_nvfp4_params_and_shard_fp32_from_nvfp4(self): shard_offsets_in_nvfp4.append(None) nvfp4_param_to_idx_map[param] = idx idx += 1 + elif self._is_grouped_nvfp4_param(param): + members = self._get_grouped_quantized_members(param) + nvfp4_params.extend(members) + shard_fp32_from_nvfp4.extend([None] * len(members)) + shard_offsets_in_nvfp4.extend([None] * len(members)) + nvfp4_param_to_idx_map[param] = list(range(idx, idx + len(members))) + idx += len(members) def _get_shard_fp32_from_nvfp4(shard_main_groups, model_groups): """Populate shard_fp32_from_nvfp4 and shard_offsets_in_nvfp4 for NVFP4 params.""" @@ -2654,6 +2677,28 @@ def _get_shard_fp32_from_nvfp4(shard_main_groups, model_groups): idx = nvfp4_param_to_idx_map[model_param] shard_fp32_from_nvfp4[idx] = shard_main_param shard_offsets_in_nvfp4[idx] = param_range.start + elif self._is_grouped_nvfp4_param(model_param): + param_range_map = self._get_model_param_range_map(model_param) + param_range = param_range_map["param"] + assert param_range.size == shard_main_param.nelement() + ( + expanded_model_params, + expanded_shard_main_params, + expanded_start_offsets, + ) = self._expand_quantized_param_shard_for_cast( + model_param, shard_main_param, param_range.start + ) + indices = nvfp4_param_to_idx_map[model_param] + assert len(indices) == len(expanded_model_params) + for idx, member, member_master, member_offset in zip( + indices, + expanded_model_params, + expanded_shard_main_params, + expanded_start_offsets, + ): + assert nvfp4_params[idx] is member + shard_fp32_from_nvfp4[idx] = member_master + shard_offsets_in_nvfp4[idx] = member_offset _get_shard_fp32_from_nvfp4(self.shard_fp32_from_float16_groups, self.model_float16_groups) _get_shard_fp32_from_nvfp4(self.shard_fp32_groups, self.model_fp32_groups) @@ -2836,6 +2881,12 @@ def _copy_main_params_to_param_buffer(self): shard_param_buffer.copy_(shard_main_param) + # Staging params into the DDP param buffer invalidates any prior "already + # dispatched" state. The next forward pre-hook must run post-sync cleanup, + # especially when MXFP8 reuses grad_data as the param AG buffer. + for model_chunk in self.model_chunks: + model_chunk.reset_param_sync_dispatch_state() + @staticmethod def _normalize_state_dict_for_grouped_params(state_dict_flat, model_chunk): """Normalize state dict keys when grouped/indexed parameter formats differ. diff --git a/megatron/core/transformer/moe/experts.py b/megatron/core/transformer/moe/experts.py index bde1737fce0..05d97269b1c 100644 --- a/megatron/core/transformer/moe/experts.py +++ b/megatron/core/transformer/moe/experts.py @@ -360,27 +360,61 @@ def _is_fused_impl_supported(self) -> bool: if not (use_glu_fusion or use_srelu_fusion): return False if self.config.activation_func == F.silu: - return True - if self.config.activation_func == quick_gelu: + pass + elif self.config.activation_func == quick_gelu: try: from transformer_engine.pytorch.ops import ScaledClampedQGeGLU # noqa: F401 except ImportError: return False - return True - if self.config.activation_func == squared_relu: + elif self.config.activation_func == squared_relu: try: from transformer_engine.pytorch.ops import ScaledSReLU # noqa: F401 except ImportError: return False - return True + else: + return False - return False + # Check TE CuTe DSL fused kernel conditions (must match TE's + # fuse_grouped_mlp_ops matching logic). + import os + + if use_glu_fusion and int(os.environ.get("NVTE_CUTEDSL_FUSED_GROUPED_MLP", "0")) <= 0: + return False + return True def _make_fused_ops(self) -> torch.nn.Module: """Construct fused module for FC1, activation, and FC2.""" assert HAVE_TE, "_make_fused_ops requires Transformer Engine." + def register_grouped_linear_params( + op: torch.nn.Module, + linear: torch.nn.Module, + single_grouped_weight: bool, + single_grouped_bias: bool, + ) -> None: + """Register real GroupedLinear params on a meta TE op shell.""" + if single_grouped_weight: + op.register_parameter("weight", linear.get_parameter("weight")) + for idx in range(linear.num_gemms): + op.register_parameter(f"weight{idx}", None) + else: + op.register_parameter("weight", None) + for idx in range(linear.num_gemms): + op.register_parameter(f"weight{idx}", linear.get_parameter(f"weight{idx}")) + + if not linear.use_bias: + return + + if single_grouped_bias: + op.register_parameter("bias", linear.get_parameter("bias")) + for idx in range(linear.num_gemms): + op.register_parameter(f"bias{idx}", None) + else: + op.register_parameter("bias", None) + for idx in range(linear.num_gemms): + op.register_parameter(f"bias{idx}", linear.get_parameter(f"bias{idx}")) + # Container for fusible ops ops = te.pytorch.ops.Sequential() @@ -421,17 +455,11 @@ def _make_fused_ops(self) -> torch.nn.Module: delay_wgrad_compute=fc1_delay_wgrad_compute, ) - # Copy the weights from GroupedLinear module to GroupedLinear op. - if fc1_single_grouped_weight: - setattr(op, "weight", getattr(self.linear_fc1, "weight")) - - for idx in range(self.linear_fc1.num_gemms): - if not fc1_single_grouped_weight: - setattr(op, f"weight{idx}", getattr(self.linear_fc1, f"weight{idx}")) - if self.linear_fc1.use_bias and not fc1_single_grouped_bias: - setattr(op, f"bias{idx}", getattr(self.linear_fc1, f"bias{idx}")) - if self.linear_fc1.use_bias and fc1_single_grouped_bias: - setattr(op, "bias", getattr(self.linear_fc1, "bias")) + # In single grouped mode, clear stale per-expert meta params so TE does not reset + # the op and replace the shared DDP parameter with a fresh one lacking main_grad. + register_grouped_linear_params( + op, self.linear_fc1, fc1_single_grouped_weight, fc1_single_grouped_bias + ) ops.append(op) # Activation and post-multiply probs (SwiGLU, clamped quick-GeGLU, or SReLU) @@ -510,17 +538,11 @@ def _make_fused_ops(self) -> torch.nn.Module: delay_wgrad_compute=fc2_delay_wgrad_compute, ) - # Copy the weights from GroupedLinear module to GroupedLinear op. - if fc2_single_grouped_weight: - setattr(op, "weight", getattr(self.linear_fc2, "weight")) - - for idx in range(self.linear_fc2.num_gemms): - if not fc2_single_grouped_weight: - setattr(op, f"weight{idx}", getattr(self.linear_fc2, f"weight{idx}")) - if self.linear_fc2.use_bias and not fc2_single_grouped_bias: - setattr(op, f"bias{idx}", getattr(self.linear_fc2, f"bias{idx}")) - if self.linear_fc2.use_bias and fc2_single_grouped_bias: - setattr(op, "bias", getattr(self.linear_fc2, "bias")) + # In single grouped mode, clear stale per-expert meta params so TE does not reset + # the op and replace the shared DDP parameter with a fresh one lacking main_grad. + register_grouped_linear_params( + op, self.linear_fc2, fc2_single_grouped_weight, fc2_single_grouped_bias + ) ops.append(op) # Emulate submodule pre-forward hooks diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index 812470a73f4..67d171e5745 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -1480,16 +1480,24 @@ def __post_init__(self): f"transformer-engine>=2.14.0, but your version is {get_te_version()}." ) if self.moe_single_grouped_weight: - # The dist-optimizer's quantized-param shard path on the single-grouped-weight - # storage is only validated for fp8 mode with the mxfp8 recipe today; other - # combinations have a known numerical issue tracked in upstream PR - # NVIDIA/Megatron-LM#4621. Reject at construction time so users don't silently - # train on a broken numerical path. (moe_single_grouped_bias is not gated: - # biases aren't quantized, so they don't enter the buggy code path.) - if self.fp4 or not self.fp8 or self.fp8_recipe != Fp8Recipe.mxfp8: + # Single grouped weights are supported for high-precision primary weights + # (BF16/FP16), MXFP8 primary weights, and NVFP4 primary weights. + # Other quantized primary-weight paths need grouped partial-cast support + # before they are safe to enable. + if (self.fp8 and self.fp8_recipe != Fp8Recipe.mxfp8) or ( + self.fp4 and self.fp4_recipe != Fp4Recipe.nvfp4 + ): raise ValueError( - "moe_single_grouped_weight is currently supported only with fp8 mode " - "and fp8_recipe='mxfp8'." + "moe_single_grouped_weight is currently supported with high-precision " + "primary weights, fp8_recipe='mxfp8', or fp4_recipe='nvfp4'." + ) + if not self.use_transformer_engine_op_fuser: + warnings.warn( + "moe_single_grouped_weight=True without " + "use_transformer_engine_op_fuser=True is functionally supported but not " + "performance-optimized. The non-op-fuser TE GroupedLinear path may split " + "the grouped weight into per-expert tensors for GEMM; enable " + "--use-transformer-engine-op-fuser for the fast grouped-weight path." ) if self.moe_single_grouped_bias and not self.add_bias_linear: raise ValueError("moe_single_grouped_bias requires add_bias_linear=True.") diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index 9764bb5f0b6..883bd5b383e 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -1083,6 +1083,17 @@ def validate_args(args, defaults={}): args.use_distributed_optimizer = True # Optimizer step MXFP8 buffer operation that is not relevant or supported for Megatron-FSDP. args.reuse_grad_buf_for_mxfp8_param_ag = False + if args.moe_single_grouped_weight or args.moe_single_grouped_bias: + # Megatron-FSDP currently remaps module parameters through plain Tensor and TE + # Float8Tensor/MXFP8Tensor storage paths. TE GroupedTensor parameters need their + # grouped backing storage remapped instead; quantized grouped tensors also need + # grouped scale/amax handling. DDP has a separate GroupedTensor-aware path. + raise ValueError( + "Megatron-FSDP does not currently support moe_single_grouped_weight or " + "moe_single_grouped_bias. Disable single grouped MoE parameters or use the " + "regular DDP/distributed optimizer path until Megatron-FSDP supports TE " + "GroupedTensor param buffers." + ) # Optimizer compatibility check. assert args.optimizer in ('sgd', 'adam'), \ f"Megatron-FSDP does not support the {args.optimizer} optimizer yet." diff --git a/tests/unit_tests/transformer/moe/test_moe_single_grouped_weight_numerics.py b/tests/unit_tests/transformer/moe/test_moe_single_grouped_weight_numerics.py new file mode 100644 index 00000000000..a3cb20295a0 --- /dev/null +++ b/tests/unit_tests/transformer/moe/test_moe_single_grouped_weight_numerics.py @@ -0,0 +1,726 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + +import gc +import inspect +import os +import sys +import traceback + +import pytest +import torch + +from megatron.core.enums import ModelType +from megatron.core.fp4_utils import is_grouped_nvfp4tensor +from megatron.core.fp8_utils import ( + is_grouped_mxfp8tensor, + is_grouped_tensor, + is_grouped_tensor_with_quantized_storage, +) +from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec +from megatron.core.models.gpt.gpt_model import GPTModel +from megatron.core.num_microbatches_calculator import destroy_num_microbatches_calculator +from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed +from megatron.core.utils import is_te_min_version +from megatron.training.arguments import core_transformer_config_from_args, parse_args, validate_args +from megatron.training.checkpointing import load_checkpoint, save_checkpoint +from megatron.training.global_vars import ( + destroy_global_vars, + get_args, + set_args, + set_global_variables, +) +from megatron.training.training import force_param_sync, setup_model_and_optimizer +from megatron.training.utils import get_device_arch_version +from tests.unit_tests.dist_checkpointing import TempNamedDir +from tests.unit_tests.test_utilities import Utils + +try: + from transformer_engine.pytorch.fp8 import check_fp8_support, check_nvfp4_support + + _FP8_AVAILABLE, _NO_FP8_REASON = check_fp8_support() + _NVFP4_AVAILABLE, _NO_NVFP4_REASON = check_nvfp4_support() +except ImportError: + _FP8_AVAILABLE = False + _NO_FP8_REASON = "Transformer Engine FP8 support is unavailable" + _NVFP4_AVAILABLE = False + _NO_NVFP4_REASON = "Transformer Engine NVFP4 support is unavailable" + + +_SEED = 1234 +_BLACKWELL_AVAILABLE = torch.cuda.is_available() and get_device_arch_version() >= 10 +try: + from transformer_engine.pytorch import GroupedLinear as TEGroupedLinear + + _TE_GROUPED_LINEAR_SUPPORTS_SINGLE_PARAM = ( + "single_grouped_weight" in inspect.signature(TEGroupedLinear.__init__).parameters + ) +except (ImportError, AttributeError): + _TE_GROUPED_LINEAR_SUPPORTS_SINGLE_PARAM = False + +pytestmark = [ + pytest.mark.internal, + pytest.mark.skipif( + not is_te_min_version("2.14.0"), + reason="moe_single_grouped_weight requires Transformer Engine >= 2.14.0", + ), + pytest.mark.skipif( + not _TE_GROUPED_LINEAR_SUPPORTS_SINGLE_PARAM, + reason="Installed TE GroupedLinear does not expose single_grouped_weight", + ), +] + + +def _skip_if_unsupported(precision: str) -> None: + if Utils.world_size < 2: + pytest.skip("distributed optimizer parity test requires torchrun with at least 2 ranks") + + if precision in ("mxfp8", "nvfp4") and not _BLACKWELL_AVAILABLE: + pytest.skip(f"{precision} single grouped weight parity requires Blackwell (SM >= 10)") + if precision == "mxfp8" and not _FP8_AVAILABLE: + pytest.skip(_NO_FP8_REASON) + if precision == "nvfp4" and not _NVFP4_AVAILABLE: + pytest.skip(_NO_NVFP4_REASON) + + +class TestMoESingleGroupedWeightNumerics: + """Numerical parity tests for MoE single grouped weights under DistOpt.""" + + seq_length = 128 + micro_batch_size = 2 + num_train_steps = 4 + + def setup_method(self, method): + self._old_single_param_env = os.environ.get("NVTE_GROUPED_LINEAR_SINGLE_PARAM") + self._old_cutedsl_fused_grouped_mlp_env = os.environ.get("NVTE_CUTEDSL_FUSED_GROUPED_MLP") + os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" + os.environ["NVTE_GROUPED_LINEAR_SINGLE_PARAM"] = "1" + os.environ["NVTE_CUTEDSL_FUSED_GROUPED_MLP"] = "1" + + def teardown_method(self, method): + try: + self._cleanup() + finally: + if self._old_single_param_env is None: + os.environ.pop("NVTE_GROUPED_LINEAR_SINGLE_PARAM", None) + else: + os.environ["NVTE_GROUPED_LINEAR_SINGLE_PARAM"] = self._old_single_param_env + if self._old_cutedsl_fused_grouped_mlp_env is None: + os.environ.pop("NVTE_CUTEDSL_FUSED_GROUPED_MLP", None) + else: + os.environ["NVTE_CUTEDSL_FUSED_GROUPED_MLP"] = ( + self._old_cutedsl_fused_grouped_mlp_env + ) + + def _cleanup(self): + Utils.destroy_model_parallel() + destroy_global_vars() + destroy_num_microbatches_calculator() + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + def model_provider( + self, pre_process=True, post_process=True, config=None, pg_collection=None, vp_stage=None + ): + model_parallel_cuda_manual_seed(_SEED) + args = get_args() + if config is None: + config = core_transformer_config_from_args(args) + transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec( + num_experts=args.num_experts, moe_grouped_gemm=args.moe_grouped_gemm + ) + return GPTModel( + config=config, + transformer_layer_spec=transformer_layer_spec, + vocab_size=args.vocal_size, + max_sequence_length=args.max_position_embeddings, + pre_process=pre_process, + post_process=post_process, + fp16_lm_cross_entropy=args.fp16_lm_cross_entropy, + parallel_output=True, + share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights, + position_embedding_type=args.position_embedding_type, + rotary_percent=args.rotary_percent, + pg_collection=pg_collection, + vp_stage=vp_stage, + ) + + def create_test_args( + self, + precision: str, + primary_param_gather: bool, + single_weight: bool, + gradient_accumulation_fusion: bool, + use_transformer_engine_op_fuser: bool, + overlap_param_gather: bool = False, + overlap_grad_reduce: bool = False, + grad_reduce_in_fp32: bool = False, + ): + self._cleanup() + + sys.argv = ["test_moe_single_grouped_weight_numerics.py"] + args = parse_args() + args.num_layers = 1 + args.vocal_size = 1024 + args.hidden_size = 256 + args.ffn_hidden_size = 256 + args.num_attention_heads = 8 + args.max_position_embeddings = self.seq_length + args.seq_length = self.seq_length + args.micro_batch_size = self.micro_batch_size + args.global_batch_size = self.micro_batch_size * Utils.world_size + args.create_attention_mask_in_dataloader = True + args.tensor_model_parallel_size = 1 + args.pipeline_model_parallel_size = 1 + args.context_parallel_size = 1 + args.expert_model_parallel_size = 1 + args.train_iters = self.num_train_steps + args.lr = 3e-5 + args.bf16 = True + args.attention_backend = "unfused" + args.add_bias_linear = False + args.hidden_dropout = 0.0 + args.attention_dropout = 0.0 + args.swiglu = True + args.gradient_accumulation_fusion = gradient_accumulation_fusion + args.use_distributed_optimizer = True + args.use_transformer_engine_op_fuser = use_transformer_engine_op_fuser + args.overlap_param_gather = overlap_param_gather + args.overlap_grad_reduce = overlap_grad_reduce + args.accumulate_allreduce_grads_in_fp32 = grad_reduce_in_fp32 + args.ddp_bucket_size = 40960 + + args.num_experts = 2 + args.moe_layer_freq = 1 + args.moe_grouped_gemm = True + args.moe_single_grouped_weight = single_weight + args.moe_token_dispatcher_type = "alltoall" + args.moe_router_topk = 1 + args.moe_router_pre_softmax = True + args.moe_router_load_balancing_type = "none" + args.moe_aux_loss_coeff = 0.0 + args.moe_ffn_hidden_size = 256 + args.moe_mlp_glu_interleave_size = 32 + + if precision == "mxfp8": + args.fp8 = "e4m3" + args.fp8_recipe = "mxfp8" + args.fp8_param_gather = primary_param_gather + args.reuse_grad_buf_for_mxfp8_param_ag = primary_param_gather + elif precision == "nvfp4": + args.fp4 = "e2m1" + args.fp4_recipe = "nvfp4" + args.fp4_param_gather = primary_param_gather + elif precision != "bf16": + raise ValueError(f"Unknown precision test case: {precision}") + + validate_args(args) + set_global_variables(args, False) + return args + + def get_batch(self): + data = torch.arange(self.seq_length, dtype=torch.int64, device="cuda") + input_ids = data.repeat((self.micro_batch_size, 1)) + labels = (data + 1).repeat((self.micro_batch_size, 1)) + position_ids = data.repeat((self.micro_batch_size, 1)) + attention_mask = torch.ones( + (self.micro_batch_size, 1, self.seq_length, self.seq_length), dtype=bool, device="cuda" + ) + loss_mask = torch.ones( + (self.micro_batch_size, self.seq_length), dtype=torch.float32, device="cuda" + ) + return input_ids, labels, position_ids, attention_mask, loss_mask + + def assert_storage_path_is_exercised( + self, model, precision: str, primary_param_gather: bool, single_weight: bool + ): + params = list(model.named_parameters()) + if not single_weight: + assert not any(is_grouped_tensor(param) for _, param in params) + return + + grouped_params = [param for _, param in params if is_grouped_tensor(param)] + assert grouped_params, "Expected at least one TE GroupedTensor MoE parameter" + + if not primary_param_gather or precision == "bf16": + assert any( + not is_grouped_tensor_with_quantized_storage(param) for param in grouped_params + ), "Expected high-precision grouped primary weights" + return + + if precision == "mxfp8": + assert any(is_grouped_mxfp8tensor(param) for param in grouped_params) + elif precision == "nvfp4": + assert any(is_grouped_nvfp4tensor(param) for param in grouped_params) + + @staticmethod + def iter_distopt_buffers(optimizer): + optimizers = getattr(optimizer, "chained_optimizers", [optimizer]) + for optim_instance in optimizers: + for buffer in getattr(optim_instance, "buffers", []): + yield buffer + + def assert_grouped_params_remapped_to_ddp_param_data(self, optimizer, precision: str): + """Grouped BF16/NVFP4 params must point at the live DDP param_data slice.""" + num_checked = 0 + for buffer in self.iter_distopt_buffers(optimizer): + for bucket in buffer.buckets: + if bucket.param_data is None: + continue + for param in bucket.params: + if not is_grouped_tensor(param): + continue + + rowwise_data = getattr(param, "rowwise_data", None) + assert rowwise_data is not None, "GroupedTensor is missing rowwise_data" + + if precision == "bf16": + if is_grouped_tensor_with_quantized_storage(param): + continue + start, end = bucket.param_to_index[param] + expected = bucket.param_data.view(-1)[start:end] + elif precision == "nvfp4": + if not is_grouped_nvfp4tensor(param): + continue + packed_start, packed_end, bucket_id = buffer.nvfp4_packed_param_index_map[ + param + ] + assert bucket_id == bucket.bucket_id + bucket_start, _ = buffer.nvfp4_packed_bucket_indices[bucket_id] + expected = bucket.param_data.view(-1)[ + packed_start - bucket_start : packed_end - bucket_start + ] + else: + raise ValueError(f"Unsupported remap precision: {precision}") + + rowwise_flat = rowwise_data.view(-1) + assert rowwise_flat.numel() == expected.numel() + assert rowwise_flat.dtype == expected.dtype + assert rowwise_flat.data_ptr() == expected.data_ptr(), ( + "Live grouped parameter rowwise_data is not mapped to the DDP " + f"param_data slice for precision={precision}" + ) + num_checked += 1 + + assert num_checked > 0, f"Did not find any {precision} grouped params to verify" + + def assert_execution_path_is_exercised( + self, model, use_transformer_engine_op_fuser: bool, after_forward: bool = False + ): + grouped_mlps = [ + module for module in model.modules() if module.__class__.__name__ == "TEGroupedMLP" + ] + assert grouped_mlps, "Expected at least one TEGroupedMLP module" + assert all( + module._with_fused_impl == use_transformer_engine_op_fuser for module in grouped_mlps + ), "Unexpected TEGroupedMLP execution path" + if after_forward: + assert all( + (module._fused_ops is not None) == use_transformer_engine_op_fuser + for module in grouped_mlps + ), "Unexpected TEGroupedMLP fused-op construction state" + + def run_training_case( + self, + precision: str, + primary_param_gather: bool, + single_weight: bool, + gradient_accumulation_fusion: bool, + use_transformer_engine_op_fuser: bool, + ): + args = self.create_test_args( + precision, + primary_param_gather, + single_weight, + gradient_accumulation_fusion, + use_transformer_engine_op_fuser, + ) + set_args(args) + torch.manual_seed(_SEED) + Utils.initialize_model_parallel( + tensor_model_parallel_size=1, expert_model_parallel_size=args.expert_model_parallel_size + ) + + batch = self.get_batch() + model, optimizer, _ = setup_model_and_optimizer( + self.model_provider, ModelType.encoder_or_decoder + ) + assert len(model) == 1 + self.assert_storage_path_is_exercised( + model[0], precision, primary_param_gather, single_weight + ) + self.assert_execution_path_is_exercised(model[0], use_transformer_engine_op_fuser) + + losses = [] + for _ in range(self.num_train_steps): + model[0].zero_grad_buffer() + optimizer.zero_grad() + model[0].set_is_first_microbatch() + output = model[0].forward( + input_ids=batch[0], + labels=batch[1], + position_ids=batch[2], + attention_mask=batch[3], + loss_mask=batch[4], + ) + loss = output.mean() + assert torch.isfinite(loss) + loss.backward() + + if args.overlap_grad_reduce: + model[0].finish_grad_sync() + + update_successful, _, _ = optimizer.step() + assert update_successful + losses.append(loss.detach().float().cpu()) + + self.assert_execution_path_is_exercised( + model[0], use_transformer_engine_op_fuser, after_forward=True + ) + return torch.stack(losses) + + def run_one_mxfp8_overlap_train_step(self, args, model, optimizer, batch): + model[0].zero_grad_buffer() + optimizer.zero_grad() + if args.reuse_grad_buf_for_mxfp8_param_ag and args.overlap_param_gather: + optimizer.prepare_model_params_for_param_sync() + model[0].set_is_first_microbatch() + output = model[0].forward( + input_ids=batch[0], + labels=batch[1], + position_ids=batch[2], + attention_mask=batch[3], + loss_mask=batch[4], + ) + loss = output.mean() + assert torch.isfinite(loss) + loss.backward() + + if args.overlap_grad_reduce: + model[0].finish_grad_sync() + + update_successful, _, _ = optimizer.step() + assert update_successful + return loss.detach().float().cpu() + + def run_mxfp8_eval_step(self, args, model, optimizer, batch): + if args.reuse_grad_buf_for_mxfp8_param_ag and args.overlap_param_gather: + optimizer.prepare_model_params_for_param_sync() + + model[0].disable_forward_pre_hook(param_sync=True) + model[0].eval() + with torch.no_grad(): + output = model[0].forward( + input_ids=batch[0], + labels=batch[1], + position_ids=batch[2], + attention_mask=batch[3], + loss_mask=batch[4], + ) + assert torch.isfinite(output.mean()) + model[0].train() + model[0].enable_forward_pre_hook() + + def setup_mxfp8_overlap_case(self, single_weight: bool, checkpoint_dir=None): + args = self.create_test_args( + precision="mxfp8", + primary_param_gather=True, + single_weight=single_weight, + gradient_accumulation_fusion=True, + use_transformer_engine_op_fuser=True, + overlap_param_gather=True, + overlap_grad_reduce=True, + grad_reduce_in_fp32=True, + ) + if checkpoint_dir is not None: + args.save = checkpoint_dir + args.load = checkpoint_dir + args.ckpt_format = "torch_dist" + args.use_dist_ckpt = True + args.auto_detect_ckpt_format = False + args.async_save = False + args.ckpt_assume_constant_structure = False + args.ckpt_load_validate_sharding_integrity = True + args.dist_ckpt_strictness = "assume_ok_unexpected" + args.no_save_optim = True + args.no_load_optim = True + args.no_save_rng = True + args.no_load_rng = True + args.load_main_params_from_ckpt = True + set_args(args) + torch.manual_seed(_SEED) + Utils.initialize_model_parallel( + tensor_model_parallel_size=1, expert_model_parallel_size=args.expert_model_parallel_size + ) + + model, optimizer, opt_param_scheduler = setup_model_and_optimizer( + self.model_provider, ModelType.encoder_or_decoder + ) + assert len(model) == 1 + self.assert_storage_path_is_exercised(model[0], "mxfp8", True, single_weight) + self.assert_execution_path_is_exercised(model[0], True) + + batch = self.get_batch() + return args, model, optimizer, opt_param_scheduler, batch + + def run_mxfp8_training_losses_with_optional_eval( + self, eval_after_step: int | None, single_weight: bool = True + ): + args, model, optimizer, _, batch = self.setup_mxfp8_overlap_case( + single_weight=single_weight + ) + losses = [] + for step in range(4): + if eval_after_step is not None and step == eval_after_step: + self.run_mxfp8_eval_step(args, model, optimizer, batch) + losses.append(self.run_one_mxfp8_overlap_train_step(args, model, optimizer, batch)) + return torch.stack(losses) + + def run_mxfp8_training_losses_with_optional_checkpoint( + self, checkpoint_dir, checkpoint_before_step: int | None + ): + args, model, optimizer, opt_param_scheduler, batch = self.setup_mxfp8_overlap_case( + single_weight=True, checkpoint_dir=checkpoint_dir + ) + losses = [] + for step in range(4): + if checkpoint_before_step is not None and step == checkpoint_before_step: + force_param_sync(model, optimizer=optimizer) + save_checkpoint(step, model, optimizer, opt_param_scheduler, 0) + if torch.distributed.is_initialized(): + torch.distributed.barrier() + losses.append(self.run_one_mxfp8_overlap_train_step(args, model, optimizer, batch)) + return torch.stack(losses) + + def run_mxfp8_checkpoint_save_load_next_loss( + self, checkpoint_dir, save_single_weight: bool, load_single_weight: bool + ): + args, model, optimizer, opt_param_scheduler, batch = self.setup_mxfp8_overlap_case( + single_weight=save_single_weight, checkpoint_dir=checkpoint_dir + ) + + for _ in range(2): + self.run_one_mxfp8_overlap_train_step(args, model, optimizer, batch) + force_param_sync(model, optimizer=optimizer) + save_checkpoint(2, model, optimizer, opt_param_scheduler, 0) + if torch.distributed.is_initialized(): + torch.distributed.barrier() + + self._cleanup() + + args, model, optimizer, opt_param_scheduler, batch = self.setup_mxfp8_overlap_case( + single_weight=load_single_weight, checkpoint_dir=checkpoint_dir + ) + loaded_iteration, _ = load_checkpoint(model, optimizer, opt_param_scheduler, strict=True) + assert loaded_iteration == 2 + return self.run_one_mxfp8_overlap_train_step(args, model, optimizer, batch) + + @staticmethod + def assert_loss_parity(precision: str, single_weight_losses, discrete_weight_losses): + if precision == "bf16": + atol = rtol = 5e-3 + else: + atol = rtol = 2e-2 + torch.testing.assert_close( + single_weight_losses, discrete_weight_losses, atol=atol, rtol=rtol + ) + + @staticmethod + def assert_all_ranks_passed(local_passed: bool, local_error: str) -> None: + if not torch.distributed.is_available() or not torch.distributed.is_initialized(): + if not local_passed: + pytest.fail(local_error) + return + + pass_flag = torch.tensor( + [1 if local_passed else 0], dtype=torch.int32, device=torch.cuda.current_device() + ) + torch.distributed.all_reduce(pass_flag, op=torch.distributed.ReduceOp.MIN) + if pass_flag.item() == 1: + return + + rank = torch.distributed.get_rank() + if local_passed: + pytest.fail("At least one distributed rank failed this parity case.") + pytest.fail(f"Rank {rank} failed this parity case:\n{local_error}") + + def run_parity_case( + self, + precision: str, + primary_param_gather: bool, + gradient_accumulation_fusion: bool, + use_transformer_engine_op_fuser: bool, + ) -> None: + local_passed = True + local_error = "" + try: + single_losses = self.run_training_case( + precision=precision, + primary_param_gather=primary_param_gather, + single_weight=True, + gradient_accumulation_fusion=gradient_accumulation_fusion, + use_transformer_engine_op_fuser=use_transformer_engine_op_fuser, + ) + discrete_losses = self.run_training_case( + precision=precision, + primary_param_gather=primary_param_gather, + single_weight=False, + gradient_accumulation_fusion=gradient_accumulation_fusion, + use_transformer_engine_op_fuser=use_transformer_engine_op_fuser, + ) + self.assert_loss_parity(precision, single_losses, discrete_losses) + except Exception: + local_passed = False + local_error = traceback.format_exc() + + self.assert_all_ranks_passed(local_passed, local_error) + + def run_remap_case(self, precision: str) -> None: + local_passed = True + local_error = "" + try: + primary_param_gather = precision == "nvfp4" + args = self.create_test_args( + precision=precision, + primary_param_gather=primary_param_gather, + single_weight=True, + gradient_accumulation_fusion=True, + use_transformer_engine_op_fuser=True, + ) + set_args(args) + torch.manual_seed(_SEED) + Utils.initialize_model_parallel( + tensor_model_parallel_size=1, + expert_model_parallel_size=args.expert_model_parallel_size, + ) + + model, optimizer, _ = setup_model_and_optimizer( + self.model_provider, ModelType.encoder_or_decoder + ) + assert len(model) == 1 + self.assert_storage_path_is_exercised( + model[0], precision, primary_param_gather, single_weight=True + ) + self.assert_grouped_params_remapped_to_ddp_param_data(optimizer, precision) + except Exception: + local_passed = False + local_error = traceback.format_exc() + + self.assert_all_ranks_passed(local_passed, local_error) + + @pytest.mark.parametrize("precision", ["bf16", "nvfp4"]) + def test_single_grouped_weight_ddp_param_data_remap_data_ptr(self, precision): + """BF16/NVFP4 single grouped weights must alias DDP param_data after buffer setup.""" + _skip_if_unsupported(precision) + self.run_remap_case(precision) + + def test_single_grouped_mxfp8_train_eval_train_matches_train_only(self): + """Eval should not change subsequent MXFP8 single grouped weight training losses.""" + _skip_if_unsupported("mxfp8") + local_passed = True + local_error = "" + try: + train_only_losses = self.run_mxfp8_training_losses_with_optional_eval( + eval_after_step=None + ) + train_eval_train_losses = self.run_mxfp8_training_losses_with_optional_eval( + eval_after_step=2 + ) + torch.testing.assert_close( + train_eval_train_losses, train_only_losses, atol=1e-4, rtol=1e-4 + ) + except Exception: + local_passed = False + local_error = traceback.format_exc() + + self.assert_all_ranks_passed(local_passed, local_error) + + @pytest.mark.parametrize( + "checkpoint_case, save_single_weight, load_single_weight", + [ + # Save-only: checkpointing should not perturb live training state. + pytest.param("save_only", True, None, id="save-only-single"), + # Layout interchange: torch_dist saves grouped MoE weights as per-expert keys. + pytest.param("save_load", True, False, id="save-single-load-discrete"), + # Reverse interchange: per-expert checkpoint keys must fold into one grouped param. + pytest.param("save_load", False, True, id="save-discrete-load-single"), + ], + ) + def test_mxfp8_single_weight_torch_dist_checkpoint_matches_discrete_baseline( + self, tmp_path_dist_ckpt, checkpoint_case, save_single_weight, load_single_weight + ): + """torch_dist checkpoint save/load should preserve MXFP8 discrete baseline numerics.""" + _skip_if_unsupported("mxfp8") + local_passed = True + local_error = "" + try: + discrete_train_only_losses = self.run_mxfp8_training_losses_with_optional_eval( + eval_after_step=None, single_weight=False + ) + with TempNamedDir( + tmp_path_dist_ckpt / "test_mxfp8_single_weight_torch_dist_checkpoint", sync=True + ) as checkpoint_dir: + if checkpoint_case == "save_only": + # This catches forced-param-sync/checkpoint side effects without reload. + checkpoint_losses = self.run_mxfp8_training_losses_with_optional_checkpoint( + checkpoint_dir=checkpoint_dir, checkpoint_before_step=2 + ) + self.assert_loss_parity("mxfp8", checkpoint_losses, discrete_train_only_losses) + else: + # This catches checkpoint key/layout conversion bugs across single/discrete. + loaded_next_loss = self.run_mxfp8_checkpoint_save_load_next_loss( + checkpoint_dir, + save_single_weight=save_single_weight, + load_single_weight=load_single_weight, + ) + torch.testing.assert_close( + loaded_next_loss, discrete_train_only_losses[2], atol=2e-2, rtol=2e-2 + ) + except Exception: + local_passed = False + local_error = traceback.format_exc() + + self.assert_all_ranks_passed(local_passed, local_error) + + @pytest.mark.parametrize("precision", ["bf16", "mxfp8", "nvfp4"]) + @pytest.mark.parametrize("gradient_accumulation_fusion", [False, True]) + def test_single_grouped_weight_parity_with_primary_param_gather( + self, precision, gradient_accumulation_fusion + ): + """Compare single vs discrete MoE weights with primary param gather enabled if applicable.""" + _skip_if_unsupported(precision) + self.run_parity_case( + precision=precision, + primary_param_gather=True, + gradient_accumulation_fusion=gradient_accumulation_fusion, + use_transformer_engine_op_fuser=True, + ) + + @pytest.mark.parametrize("precision", ["bf16", "mxfp8", "nvfp4"]) + @pytest.mark.parametrize("gradient_accumulation_fusion", [False, True]) + def test_single_grouped_weight_parity_without_primary_param_gather( + self, precision, gradient_accumulation_fusion + ): + """Compare single vs discrete MoE weights when primary weights stay BF16.""" + _skip_if_unsupported(precision) + self.run_parity_case( + precision=precision, + primary_param_gather=False, + gradient_accumulation_fusion=gradient_accumulation_fusion, + use_transformer_engine_op_fuser=True, + ) + + @pytest.mark.parametrize("precision", ["bf16", "mxfp8", "nvfp4"]) + @pytest.mark.parametrize("primary_param_gather", [False, True]) + @pytest.mark.parametrize("gradient_accumulation_fusion", [False, True]) + def test_single_grouped_weight_parity_module_grouped_linear( + self, precision, primary_param_gather, gradient_accumulation_fusion + ): + """Compare single vs discrete MoE weights through TE module.GroupedLinear.""" + _skip_if_unsupported(precision) + self.run_parity_case( + precision=precision, + primary_param_gather=primary_param_gather, + gradient_accumulation_fusion=gradient_accumulation_fusion, + use_transformer_engine_op_fuser=False, + )