Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/megatron/bridge/models/conversion/mapping_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from typing import List, Optional

from megatron.bridge.models.conversion.param_mapping import AutoMapping, MegatronParamMapping
from megatron.bridge.models.conversion.quant_mapping import convert_to_amax_map
from megatron.bridge.models.conversion.quant_mapping import convert_to_amax_map, derive_kv_bmm_amax_map


class MegatronMappingRegistry:
Expand Down Expand Up @@ -120,6 +120,7 @@ def _add_quantization_mappings(self) -> None:
original_mappings = list(self.mappings)
self.mappings.extend(convert_to_amax_map(original_mappings, ".weight_quantizer._amax"))
self.mappings.extend(convert_to_amax_map(original_mappings, ".input_quantizer._amax"))
self.mappings.extend(derive_kv_bmm_amax_map(original_mappings))

def _convert_pattern_to_regex(self, pattern: str) -> str:
"""Convert a pattern with wildcards to regex pattern.
Expand Down
79 changes: 78 additions & 1 deletion src/megatron/bridge/models/conversion/quant_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import torch

from megatron.bridge.models.conversion.param_mapping import MegatronParamMapping, ReplicatedMapping
from megatron.bridge.models.conversion.param_mapping import MegatronParamMapping, QKVMapping, ReplicatedMapping


class AmaxMapping(ReplicatedMapping):
Expand Down Expand Up @@ -231,6 +231,83 @@ def _convert_hf_weight_names(hf_param: str | dict[str, str], mapped_name: str) -
return []


_QKV_PROJECTION_NAMES = {"q": "q_proj", "k": "k_proj", "v": "v_proj"}
# Speculative-decoding draft models and MTP layers are not supported by the
# KV-cache amax refit path yet, so do not derive mappings for their QKV blocks.
_SKIPPED_QKV_PATH_SEGMENTS = frozenset(
{
"draft",
"draft_layers",
"draft_model_layer",
"mtp",
"mtp_layers",
"mtp_model_layer",
}
)


def _has_skipped_qkv_path_segment(path: str) -> bool:
return any(segment in _SKIPPED_QKV_PATH_SEGMENTS for segment in path.split("."))


def _derive_qkv_megatron_parent(megatron_param: str) -> str | None:
suffix = ".self_attention.linear_qkv.weight"
if not megatron_param.endswith(suffix):
return None
return megatron_param.removesuffix(".linear_qkv.weight") + ".core_attention"


def _derive_qkv_hf_parent(hf_params: dict[str, str]) -> str | None:
parents = []
for key, expected_proj_name in _QKV_PROJECTION_NAMES.items():
hf_name = hf_params.get(key)
if not isinstance(hf_name, str):
return None
parts = hf_name.split(".")
if len(parts) < 3 or parts[-1] != "weight" or parts[-2] != expected_proj_name:
return None
parents.append(".".join(parts[:-2]))
if len(set(parents)) != 1:
return None
return parents[0]


def derive_kv_bmm_amax_map(mappings: list[MegatronParamMapping]) -> list[MegatronParamMapping]:
"""Derive K/V BMM quantizer amax mappings from eligible fused-QKV mappings."""
derived_mappings = []

for mapping in mappings:
if not isinstance(mapping, QKVMapping):
continue
if mapping.allow_hf_name_mismatch:
# Shared/tied-KV bridges may intentionally omit an HF projection.
continue
if _has_skipped_qkv_path_segment(mapping.megatron_param):
continue
if any(_has_skipped_qkv_path_segment(path) for path in mapping.hf_param.values()):
continue

megatron_parent = _derive_qkv_megatron_parent(mapping.megatron_param)
hf_parent = _derive_qkv_hf_parent(mapping.hf_param)
if megatron_parent is None or hf_parent is None:
continue

derived_mappings.extend(
[
AmaxMapping(
megatron_param=f"{megatron_parent}.k_bmm_quantizer._amax",
hf_param=f"{hf_parent}.k_bmm_quantizer._amax",
),
AmaxMapping(
megatron_param=f"{megatron_parent}.v_bmm_quantizer._amax",
hf_param=f"{hf_parent}.v_bmm_quantizer._amax",
),
]
)

return derived_mappings


def convert_to_amax_map(
mappings: list[MegatronParamMapping], mapped_name: str = ".weight_quantizer._amax"
) -> list[MegatronParamMapping]:
Expand Down
210 changes: 210 additions & 0 deletions tests/unit_tests/models/test_quant_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from megatron.bridge.models.conversion.mapping_registry import MegatronMappingRegistry
from megatron.bridge.models.conversion.param_mapping import (
AutoMapping,
ConcatenatedQKVMapping,
GatedMLPMapping,
QKVMapping,
ReplicatedMapping,
Expand All @@ -31,6 +32,7 @@
AmaxFanoutMapping,
AmaxMapping,
MoeAmaxFanoutMapping,
derive_kv_bmm_amax_map,
)


Expand Down Expand Up @@ -91,6 +93,128 @@ def test_resolve_replaces_wildcards(self):
assert set(resolved.hf_targets) == expected


class TestDeriveKvBmmAmaxMap:
def test_derives_kv_bmm_amax_mappings_from_qkv_mapping(self):
mapping = QKVMapping(
megatron_param="decoder.layers.*.self_attention.linear_qkv.weight",
q="model.layers.*.self_attn.q_proj.weight",
k="model.layers.*.self_attn.k_proj.weight",
v="model.layers.*.self_attn.v_proj.weight",
)

result = derive_kv_bmm_amax_map([mapping])

assert [(m.megatron_param, m.hf_param) for m in result] == [
(
"decoder.layers.*.self_attention.core_attention.k_bmm_quantizer._amax",
"model.layers.*.self_attn.k_bmm_quantizer._amax",
),
(
"decoder.layers.*.self_attention.core_attention.v_bmm_quantizer._amax",
"model.layers.*.self_attn.v_bmm_quantizer._amax",
),
]
assert all(isinstance(mapping, AmaxMapping) for mapping in result)

def test_preserves_wildcards_and_language_model_prefixes(self):
mapping = QKVMapping(
megatron_param="language_model.decoder.layers.*.self_attention.linear_qkv.weight",
q="language_model.model.layers.*.self_attn.q_proj.weight",
k="language_model.model.layers.*.self_attn.k_proj.weight",
v="language_model.model.layers.*.self_attn.v_proj.weight",
)

result = derive_kv_bmm_amax_map([mapping])

assert [(m.megatron_param, m.hf_param) for m in result] == [
(
"language_model.decoder.layers.*.self_attention.core_attention.k_bmm_quantizer._amax",
"language_model.model.layers.*.self_attn.k_bmm_quantizer._amax",
),
(
"language_model.decoder.layers.*.self_attention.core_attention.v_bmm_quantizer._amax",
"language_model.model.layers.*.self_attn.v_bmm_quantizer._amax",
),
]

@pytest.mark.parametrize(
("mapping", "reason"),
[
pytest.param(
QKVMapping(
megatron_param="decoder.layers.*.self_attention.linear_qkv.bias",
q="model.layers.*.self_attn.q_proj.bias",
k="model.layers.*.self_attn.k_proj.bias",
v="model.layers.*.self_attn.v_proj.bias",
),
"bias",
id="bias-mapping",
),
pytest.param(
QKVMapping(
megatron_param="decoder.layers.*.self_attention.linear_qkv.weight",
q="model.layers.*.self_attn.q_proj.weight",
k="language_model.model.layers.*.self_attn.k_proj.weight",
v="model.layers.*.self_attn.v_proj.weight",
),
"non-common-parent",
id="non-common-parent",
),
pytest.param(
QKVMapping(
megatron_param="decoder.layers.*.self_attention.linear_qkv.weight",
q="model.layers.*.self_attn.q_proj.weight",
k="model.layers.*.self_attn.k_proj.kernel",
v="model.layers.*.self_attn.v_proj.weight",
),
"malformed-projection",
id="malformed-projection",
),
pytest.param(
QKVMapping(
megatron_param="mtp.layers.*.self_attention.linear_qkv.weight",
q="model.mtp_layers.*.self_attn.q_proj.weight",
k="model.mtp_layers.*.self_attn.k_proj.weight",
v="model.mtp_layers.*.self_attn.v_proj.weight",
),
"mtp-path",
id="mtp-path",
),
pytest.param(
QKVMapping(
megatron_param="draft.layers.*.self_attention.linear_qkv.weight",
q="model.draft_layers.*.self_attn.q_proj.weight",
k="model.draft_layers.*.self_attn.k_proj.weight",
v="model.draft_layers.*.self_attn.v_proj.weight",
),
"draft-path",
id="draft-path",
),
pytest.param(
ConcatenatedQKVMapping(
megatron_param="vision.blocks.*.self_attention.linear_qkv.weight",
hf_param="vision_model.layers.*.self_attn.qkv.weight",
),
"vision-concatenated-qkv",
id="vision-concatenated-qkv",
),
],
)
def test_skips_disallowed_qkv_shapes(self, mapping, reason):
assert derive_kv_bmm_amax_map([mapping]) == [], reason

def test_skips_mappings_that_allow_missing_hf_projections(self):
mapping = QKVMapping(
megatron_param="decoder.layers.*.self_attention.linear_qkv.weight",
q="model.layers.*.self_attn.q_proj.weight",
k="model.layers.*.self_attn.k_proj.weight",
v="model.layers.*.self_attn.v_proj.weight",
)
mapping.allow_hf_name_mismatch = True

assert derive_kv_bmm_amax_map([mapping]) == []


class TestQuantMappingRegistryIntegration:
"""Test quantization mappings inside MegatronMappingRegistry with a Llama-like bridge."""

Expand Down Expand Up @@ -138,6 +262,10 @@ def test_quant_mappings_disabled_by_default(self, llama_like_mappings):
with patch.dict(os.environ, {"ENABLE_BRIDGE_QUANT_MAPPING": "0"}, clear=False):
registry = MegatronMappingRegistry(*llama_like_mappings)
assert not any(isinstance(m, AmaxMapping) for m in registry.get_all_mappings())
assert (
registry.megatron_to_hf_lookup("decoder.layers.0.self_attention.core_attention.k_bmm_quantizer._amax")
is None
)

def test_quant_mappings_count(self, registry):
"""weight_quantizer and input_quantizer amax mappings are added in equal numbers."""
Expand Down Expand Up @@ -252,6 +380,88 @@ def test_layer_index_independence(self, registry):
for proj in ["q", "k", "v"]:
assert f"model.layers.{layer_idx}.self_attn.{proj}_proj.weight_quantizer._amax" in m.hf_targets

@pytest.mark.parametrize(
"megatron_amax, expected_hf_amax",
[
(
"decoder.layers.0.self_attention.core_attention.k_bmm_quantizer._amax",
"model.layers.0.self_attn.k_bmm_quantizer._amax",
),
(
"decoder.layers.0.self_attention.core_attention.v_bmm_quantizer._amax",
"model.layers.0.self_attn.v_bmm_quantizer._amax",
),
],
)
def test_kv_bmm_amax_forward_lookup(self, registry, megatron_amax, expected_hf_amax):
mapping = registry.megatron_to_hf_lookup(megatron_amax)

assert mapping is not None
assert isinstance(mapping, AmaxMapping)
assert mapping.megatron_param == megatron_amax
assert mapping.hf_param == expected_hf_amax

@pytest.mark.parametrize(
"hf_amax, expected_megatron_amax",
[
(
"model.layers.7.self_attn.k_bmm_quantizer._amax",
"decoder.layers.7.self_attention.core_attention.k_bmm_quantizer._amax",
),
(
"model.layers.7.self_attn.v_bmm_quantizer._amax",
"decoder.layers.7.self_attention.core_attention.v_bmm_quantizer._amax",
),
],
)
def test_kv_bmm_amax_reverse_lookup(self, registry, hf_amax, expected_megatron_amax):
mapping = registry.hf_to_megatron_lookup(hf_amax)

assert mapping is not None
assert isinstance(mapping, AmaxMapping)
assert mapping.megatron_param == expected_megatron_amax
assert mapping.hf_param == hf_amax

def test_kv_bmm_amax_coexists_with_weight_and_input_quantizer_mappings(self, registry):
assert (
registry.megatron_to_hf_lookup("decoder.layers.0.self_attention.core_attention.k_bmm_quantizer._amax")
is not None
)
assert (
registry.megatron_to_hf_lookup("decoder.layers.0.self_attention.linear_qkv.weight_quantizer._amax")
is not None
)
assert (
registry.megatron_to_hf_lookup("decoder.layers.0.self_attention.linear_qkv.input_quantizer._amax")
is not None
)


class TestKvBmmQuantMappingPrefixes:
def test_registry_preserves_prefixes_and_wildcards(self):
mappings = [
QKVMapping(
megatron_param="language_model.decoder.layers.*.self_attention.linear_qkv.weight",
q="language_model.model.layers.*.self_attn.q_proj.weight",
k="language_model.model.layers.*.self_attn.k_proj.weight",
v="language_model.model.layers.*.self_attn.v_proj.weight",
)
]

with patch.dict(os.environ, {"ENABLE_BRIDGE_QUANT_MAPPING": "1"}, clear=False):
registry = MegatronMappingRegistry(*mappings)

mapping = registry.megatron_to_hf_lookup(
"language_model.decoder.layers.9.self_attention.core_attention.k_bmm_quantizer._amax"
)

assert mapping is not None
assert (
mapping.megatron_param
== "language_model.decoder.layers.9.self_attention.core_attention.k_bmm_quantizer._amax"
)
assert mapping.hf_param == "language_model.model.layers.9.self_attn.k_bmm_quantizer._amax"


class TestMoeAmaxFanoutMapping:
"""Tests for grouped-MoE expert amax fanout via all_gather across the EP group."""
Expand Down
Loading