Skip to content
Open
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
196 changes: 141 additions & 55 deletions megatron/core/distributed/param_and_grad_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,21 @@
from megatron.core.rerun_state_machine import get_rerun_state_machine
from megatron.core.utils import log_single_rank

from ..fp4_utils import get_nvfp4_rowwise_packed_shape, is_nvfp4tensor
from ..fp4_utils import (
get_nvfp4_rowwise_packed_shape,
is_grouped_nvfp4tensor,
is_nvfp4tensor,
modify_grouped_nvfp4_rowwise_storage,
modify_nvfp4_rowwise_storage,
)
from ..fp8_utils import (
copy_tensor_to_quantized_param,
is_float8tensor,
is_grouped_mxfp8tensor,
is_grouped_tensor,
is_grouped_tensor_with_quantized_storage,
is_mxfp8tensor,
modify_grouped_tensor_rowwise_storage,
modify_underlying_storage,
post_all_gather_processing,
)
Expand Down Expand Up @@ -67,6 +78,17 @@ def shard_buffer(buffer: torch.Tensor, data_parallel_world_size: int):
return sharded_buffer


def _param_uses_quantized_storage(param: torch.nn.Parameter) -> bool:
"""Return whether the parameter owns TE quantized storage instead of plain tensor storage."""
# In TE2, is_float8tensor() checks QuantizedTensor, so it includes plain MXFP8Tensor.
# Grouped quantized params are GroupedTensor wrappers, so check their backing storage.
return (
is_float8tensor(param)
or is_nvfp4tensor(param)
or is_grouped_tensor_with_quantized_storage(param)
)


class _ParamAndGradBucket:
"""
Bucket to keep track of a subset of the model's parameters and gradients.
Expand Down Expand Up @@ -283,17 +305,18 @@ def _post_param_sync(self):
"""Run post-processing after param all-gather completes."""
if self.ddp_config.reuse_grad_buf_for_mxfp8_param_ag:
for bucket in self.buckets:
is_bf16_weight_bucket = False
has_non_quantized_weight = False
for param in bucket.params:
# Skip copying since bf16 weights in the mxfp8 model
# are already mapped to param.data.
if not is_float8tensor(param):
is_bf16_weight_bucket = True
# Non-quantized weights are already mapped to param.data. Skip
# mixed buckets because zeroing bucket.param_data would also
# clear those model weights.
if not _param_uses_quantized_storage(param):
has_non_quantized_weight = True
break
param_start, param_end = bucket.param_to_index[param]
param_slice = bucket.param_data.view(-1)[param_start:param_end]
param.data.copy_(param_slice.view(param.data.shape))
if is_bf16_weight_bucket:
copy_tensor_to_quantized_param(param, param_slice)
if has_non_quantized_weight:
continue
# All-gathered params are not needed after being copied to param.data.
# Zero out the param buffer (shared with grad buffer) for gradient accumulation.
Expand All @@ -306,7 +329,7 @@ def _post_param_sync(self):
quantized_params = []
for bucket in self.buckets:
for param in bucket.params:
if is_float8tensor(param) or is_nvfp4tensor(param):
if _param_uses_quantized_storage(param):
quantized_params.append(param)
if len(quantized_params) > 0:
post_all_gather_processing(quantized_params)
Expand Down Expand Up @@ -856,7 +879,7 @@ def group_params_for_buffers(
assert param.requires_grad

param_dtype = param.dtype
if is_float8tensor(param) or is_nvfp4tensor(param):
if _param_uses_quantized_storage(param):
param_dtype = torch.uint8
grad_dtype = torch.float if grad_reduce_in_fp32 else param.dtype
is_expert_parallel = not getattr(param, 'allreduce', True)
Expand Down Expand Up @@ -1038,7 +1061,9 @@ def __init__(
# The packed index map is derived from param_index_map by iterating through
# the already-computed layout and halving numel for NVFP4 tensors.
#
self.has_nvfp4_params = any(is_nvfp4tensor(p) for p in self.params)
self.has_nvfp4_params = any(
is_nvfp4tensor(p) or is_grouped_nvfp4tensor(p) for p in self.params
)
self.nvfp4_packed_param_index_map = None
self.nvfp4_packed_bucket_indices = None
if self.has_nvfp4_params:
Expand Down Expand Up @@ -1095,7 +1120,7 @@ def __init__(
# The buffer is mapped to weight gradients whose dtype is either bf16 or FP32.
# It can be temporarily reused by param AG.
if self.ddp_config.use_distributed_optimizer and any(
is_mxfp8tensor(p) for p in self.params
is_mxfp8tensor(p) or is_grouped_mxfp8tensor(p) for p in self.params
):
self.shared_buffer = torch.zeros(
self.numel,
Expand Down Expand Up @@ -1173,50 +1198,111 @@ def _create_bucket(bucket_id, bucket_params, bucket_params_with_extra_main_grads
nvfp4_packed_param_start_index = None
if self.has_nvfp4_params:
nvfp4_packed_param_start_index, _, _ = self.nvfp4_packed_param_index_map[param]
# For MXFP8 param:
# we only need to map bf16 weights (layernorm, embedding, etc) to the buffer.
if not self.ddp_config.reuse_grad_buf_for_mxfp8_param_ag or not is_mxfp8tensor(param):
# This branch remaps the parameter storage into persistent DDP param_data buffer.
#
# Enter when:
# - `reuse_grad_buf_for_mxfp8_param_ag` is off: param AG has a persistent
# param_data buffer instead of sharing storage with grad_data,
# so every parameter must be backed by param_data.
# - param is not quantized: BF16/FP16/plain params still need persistent
# param_data even when quantized params use grad_data as temporary AG storage.
#
# Skip only when both are true: AG reuses grad_data and the param is quantized.
# In that case AG writes into grad_data, then _post_param_sync copies the
# gathered values back into TE quantized storage.
#
# Remap cases below:
# non-grouped TE NVFP4 tensor -> remap packed rowwise bytes
# non-grouped TE quantized -> remap TE quantized storage
# regular torch.Tensor param -> replace param.data with param_data view
# TE GroupedTensor + NVFP4 -> remap packed rowwise bytes
# TE GroupedTensor + MXFP8 -> unsupported here; require grad-buffer AG reuse
# TE GroupedTensor + BF16/FP16 -> remap grouped rowwise_data
if (
not self.ddp_config.reuse_grad_buf_for_mxfp8_param_ag
or not _param_uses_quantized_storage(param)
):
if self.param_data is not None:
if is_nvfp4tensor(param):
# Remap the NVFP4 tensor's internal rowwise uint8 storage so it
# points into the contiguous DDP param buffer. This enables the
# all-gather to communicate packed NVFP4 bytes directly.
from ..fp4_utils import modify_nvfp4_rowwise_storage

packed_shape = get_nvfp4_rowwise_packed_shape(param.data.shape)
rowwise_bytes_view = self._get(
packed_shape,
nvfp4_packed_param_start_index,
buffer_type=BufferType.PARAM,
)
modify_nvfp4_rowwise_storage(param, rowwise_bytes_view)
elif is_float8tensor(param):
new_param_data = self._get(
param.data.shape,
(
nvfp4_packed_param_start_index
if self.has_nvfp4_params
else param_start_index
),
buffer_type=BufferType.PARAM,
)
modify_underlying_storage(param, new_param_data)
if not is_grouped_tensor(param):
# Plain NVFP4: remap packed rowwise bytes only.
if is_nvfp4tensor(param):
packed_shape = get_nvfp4_rowwise_packed_shape(param.data.shape)
rowwise_bytes_view = self._get(
packed_shape,
nvfp4_packed_param_start_index,
buffer_type=BufferType.PARAM,
)
modify_nvfp4_rowwise_storage(param, rowwise_bytes_view)
# In TE2, is_float8tensor() checks QuantizedTensor, including MXFP8.
# NVFP4 is handled by the branch above.
elif is_float8tensor(param):
# NVFP4 packs two FP4 values per byte, so param_data uses
# packed-byte offsets instead of logical element offsets.
new_param_data = self._get(
param.data.shape,
(
nvfp4_packed_param_start_index
if self.has_nvfp4_params
else param_start_index
),
buffer_type=BufferType.PARAM,
)
modify_underlying_storage(param, new_param_data)
# Plain torch param: replace param.data with DDP buffer view.
else:
# NVFP4 packs two FP4 values per byte, so param_data uses
# packed-byte offsets instead of logical element offsets.
new_param_data = self._get(
param.data.shape,
(
nvfp4_packed_param_start_index
if self.has_nvfp4_params
else param_start_index
),
buffer_type=BufferType.PARAM,
)
old_param_data = param.data
param.data = new_param_data
assert old_param_data._base is None
# Copy tensor values (from initialization or checkpoint).
param.data.detach().copy_(old_param_data)
del old_param_data
else:
new_param_data = self._get(
param.data.shape,
(
nvfp4_packed_param_start_index
if self.has_nvfp4_params
else param_start_index
),
buffer_type=BufferType.PARAM,
)
old_param_data = param.data
param.data = new_param_data
assert old_param_data._base is None
# Copy tensor values (from initialization or checkpoint).
param.data.detach().copy_(old_param_data)
del old_param_data
# GroupedTensor: preserve wrapper/metadata; remap backing storage only.
# Grouped NVFP4: only rowwise bytes live in DDP param_data.
if is_grouped_nvfp4tensor(param):
packed_shape = get_nvfp4_rowwise_packed_shape(param.data.shape)
rowwise_bytes_view = self._get(
packed_shape,
nvfp4_packed_param_start_index,
buffer_type=BufferType.PARAM,
)
modify_grouped_nvfp4_rowwise_storage(param, rowwise_bytes_view)
# Grouped MXFP8: do not remap grouped quantized storage into param_data.
# Use grad-buffer AG reuse and copy gathered values back after AG.
elif is_grouped_mxfp8tensor(param):
raise RuntimeError(
"Single grouped MXFP8 params require "
"--reuse-grad-buf-for-mxfp8-param-ag."
)
elif is_grouped_tensor_with_quantized_storage(param):
raise RuntimeError(
"Unsupported single grouped quantized parameter recipe."
)
# Grouped BF16/FP16: remap full rowwise_data.
else:
# NVFP4 packs two FP4 values per byte, so param_data uses
# packed-byte offsets instead of logical element offsets.
new_param_data = self._get(
param.data.shape,
(
nvfp4_packed_param_start_index
if self.has_nvfp4_params
else param_start_index
),
buffer_type=BufferType.PARAM,
)
modify_grouped_tensor_rowwise_storage(param, new_param_data)

# Grad buffer always uses full-numel offsets from param_index_map.
param.main_grad = self._get(
Expand Down Expand Up @@ -1349,7 +1435,7 @@ def _pad_end_of_bucket(bucket_end_index: int) -> int:
cur_bucket_id = bucket_id

# NVFP4 tensors use half the numel in the packed param buffer.
if is_nvfp4tensor(param):
if is_nvfp4tensor(param) or is_grouped_nvfp4tensor(param):
assert (
param_numel % 2 == 0
), f"NVFP4 requires even numel for packing, got {param_numel}"
Expand Down
52 changes: 51 additions & 1 deletion megatron/core/fp4_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,12 @@
import torch

from megatron.core.enums import Fp4Recipe
from megatron.core.fp8_utils import _get_custom_recipe
from megatron.core.fp8_utils import (
_get_custom_recipe,
_get_grouped_quantized_recipe,
get_grouped_quantized_members,
is_grouped_tensor_with_quantized_storage,
)
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.utils import is_te_min_version

Expand Down Expand Up @@ -55,6 +60,14 @@ def is_nvfp4tensor(tensor: torch.Tensor) -> bool:
return HAVE_TE_FP4_TENSOR_CLASS and isinstance(tensor, FP4_TENSOR_CLASS)


def is_grouped_nvfp4tensor(tensor: torch.Tensor) -> bool:
"""Check if a TE GroupedTensor stores NVFP4 member tensors."""
if not HAVE_TE_FP4_TENSOR_CLASS:
return False
recipe = _get_grouped_quantized_recipe(tensor)
return recipe is not None and hasattr(recipe, "nvfp4") and recipe.nvfp4()


def get_nvfp4_rowwise_packed_shape(shape: torch.Size) -> torch.Size:
"""Return packed byte shape for NVFP4 rowwise storage (last dim // 2)."""
if len(shape) == 0:
Expand Down Expand Up @@ -85,6 +98,43 @@ def modify_nvfp4_rowwise_storage(fp4_tensor: torch.Tensor, new_rowwise_data: tor
del old_rowwise


def modify_grouped_nvfp4_rowwise_storage(
grouped_tensor: torch.Tensor, new_rowwise_data: torch.Tensor
) -> None:
"""Replace grouped NVFP4 rowwise data with a new uint8 storage view.

The name intentionally mirrors `modify_nvfp4_rowwise_storage`: only the
packed rowwise byte buffer is remapped into the DDP buffer. The grouped
scale, amax, and columnwise buffers remain owned by the original tensor.
"""
tensor = (
grouped_tensor.data if isinstance(grouped_tensor, torch.nn.Parameter) else grouped_tensor
)
if not is_grouped_nvfp4tensor(tensor):
raise ValueError("modify_grouped_nvfp4_rowwise_storage expects grouped NVFP4 storage")

old_rowwise = getattr(tensor, "rowwise_data", None)
if old_rowwise is None:
raise RuntimeError("Grouped NVFP4 tensor is missing rowwise data to replace")

new_rowwise_data = new_rowwise_data.view(-1)
if old_rowwise.numel() != new_rowwise_data.numel():
raise ValueError(
"Grouped NVFP4 rowwise storage size mismatch: "
f"old numel={old_rowwise.numel()}, new numel={new_rowwise_data.numel()}"
)
assert (
old_rowwise.dtype == new_rowwise_data.dtype == torch.uint8
), "Grouped NVFP4 rowwise storage must be uint8"

new_rowwise_data.detach().copy_(old_rowwise.view(-1))
tensor.rowwise_data = new_rowwise_data
# Member views capture data pointers. Refresh them after swapping rowwise storage while
# preserving the existing scale/amax/columnwise grouped buffers.
tensor.quantized_tensors = tensor.split_into_quantized_tensors()
del old_rowwise
Comment thread
kunlunl marked this conversation as resolved.


def quantize_nvfp4_param_shard(
model_params, main_params, start_offsets, data_parallel_group, fsdp_shard_model_params=None
):
Expand Down
Loading