diff --git a/src/megatron/bridge/models/conversion/mapping_registry.py b/src/megatron/bridge/models/conversion/mapping_registry.py index f4d89b9dc9..83f09630e9 100644 --- a/src/megatron/bridge/models/conversion/mapping_registry.py +++ b/src/megatron/bridge/models/conversion/mapping_registry.py @@ -17,7 +17,7 @@ from typing import List, Optional from megatron.bridge.models.conversion.param_mapping import AutoMapping, MegatronParamMapping -from megatron.bridge.models.conversion.quant_mapping import convert_to_amax_map +from megatron.bridge.models.conversion.quant_mapping import convert_to_amax_map, derive_kv_bmm_amax_map class MegatronMappingRegistry: @@ -120,6 +120,7 @@ def _add_quantization_mappings(self) -> None: original_mappings = list(self.mappings) self.mappings.extend(convert_to_amax_map(original_mappings, ".weight_quantizer._amax")) self.mappings.extend(convert_to_amax_map(original_mappings, ".input_quantizer._amax")) + self.mappings.extend(derive_kv_bmm_amax_map(original_mappings)) def _convert_pattern_to_regex(self, pattern: str) -> str: """Convert a pattern with wildcards to regex pattern. diff --git a/src/megatron/bridge/models/conversion/quant_mapping.py b/src/megatron/bridge/models/conversion/quant_mapping.py index 1cada16ec4..61ec917bb9 100644 --- a/src/megatron/bridge/models/conversion/quant_mapping.py +++ b/src/megatron/bridge/models/conversion/quant_mapping.py @@ -16,7 +16,7 @@ import torch -from megatron.bridge.models.conversion.param_mapping import MegatronParamMapping, ReplicatedMapping +from megatron.bridge.models.conversion.param_mapping import MegatronParamMapping, QKVMapping, ReplicatedMapping class AmaxMapping(ReplicatedMapping): @@ -231,6 +231,83 @@ def _convert_hf_weight_names(hf_param: str | dict[str, str], mapped_name: str) - return [] +_QKV_PROJECTION_NAMES = {"q": "q_proj", "k": "k_proj", "v": "v_proj"} +# Speculative-decoding draft models and MTP layers are not supported by the +# KV-cache amax refit path yet, so do not derive mappings for their QKV blocks. +_SKIPPED_QKV_PATH_SEGMENTS = frozenset( + { + "draft", + "draft_layers", + "draft_model_layer", + "mtp", + "mtp_layers", + "mtp_model_layer", + } +) + + +def _has_skipped_qkv_path_segment(path: str) -> bool: + return any(segment in _SKIPPED_QKV_PATH_SEGMENTS for segment in path.split(".")) + + +def _derive_qkv_megatron_parent(megatron_param: str) -> str | None: + suffix = ".self_attention.linear_qkv.weight" + if not megatron_param.endswith(suffix): + return None + return megatron_param.removesuffix(".linear_qkv.weight") + ".core_attention" + + +def _derive_qkv_hf_parent(hf_params: dict[str, str]) -> str | None: + parents = [] + for key, expected_proj_name in _QKV_PROJECTION_NAMES.items(): + hf_name = hf_params.get(key) + if not isinstance(hf_name, str): + return None + parts = hf_name.split(".") + if len(parts) < 3 or parts[-1] != "weight" or parts[-2] != expected_proj_name: + return None + parents.append(".".join(parts[:-2])) + if len(set(parents)) != 1: + return None + return parents[0] + + +def derive_kv_bmm_amax_map(mappings: list[MegatronParamMapping]) -> list[MegatronParamMapping]: + """Derive K/V BMM quantizer amax mappings from eligible fused-QKV mappings.""" + derived_mappings = [] + + for mapping in mappings: + if not isinstance(mapping, QKVMapping): + continue + if mapping.allow_hf_name_mismatch: + # Shared/tied-KV bridges may intentionally omit an HF projection. + continue + if _has_skipped_qkv_path_segment(mapping.megatron_param): + continue + if any(_has_skipped_qkv_path_segment(path) for path in mapping.hf_param.values()): + continue + + megatron_parent = _derive_qkv_megatron_parent(mapping.megatron_param) + hf_parent = _derive_qkv_hf_parent(mapping.hf_param) + if megatron_parent is None or hf_parent is None: + continue + + derived_mappings.extend( + [ + AmaxMapping( + megatron_param=f"{megatron_parent}.k_bmm_quantizer._amax", + hf_param=f"{hf_parent}.k_bmm_quantizer._amax", + ), + AmaxMapping( + megatron_param=f"{megatron_parent}.v_bmm_quantizer._amax", + hf_param=f"{hf_parent}.v_bmm_quantizer._amax", + ), + ] + ) + + return derived_mappings + + def convert_to_amax_map( mappings: list[MegatronParamMapping], mapped_name: str = ".weight_quantizer._amax" ) -> list[MegatronParamMapping]: diff --git a/tests/unit_tests/models/test_quant_mapping.py b/tests/unit_tests/models/test_quant_mapping.py index ed0f571f81..185f7a26b9 100644 --- a/tests/unit_tests/models/test_quant_mapping.py +++ b/tests/unit_tests/models/test_quant_mapping.py @@ -23,6 +23,7 @@ from megatron.bridge.models.conversion.mapping_registry import MegatronMappingRegistry from megatron.bridge.models.conversion.param_mapping import ( AutoMapping, + ConcatenatedQKVMapping, GatedMLPMapping, QKVMapping, ReplicatedMapping, @@ -31,6 +32,7 @@ AmaxFanoutMapping, AmaxMapping, MoeAmaxFanoutMapping, + derive_kv_bmm_amax_map, ) @@ -91,6 +93,128 @@ def test_resolve_replaces_wildcards(self): assert set(resolved.hf_targets) == expected +class TestDeriveKvBmmAmaxMap: + def test_derives_kv_bmm_amax_mappings_from_qkv_mapping(self): + mapping = QKVMapping( + megatron_param="decoder.layers.*.self_attention.linear_qkv.weight", + q="model.layers.*.self_attn.q_proj.weight", + k="model.layers.*.self_attn.k_proj.weight", + v="model.layers.*.self_attn.v_proj.weight", + ) + + result = derive_kv_bmm_amax_map([mapping]) + + assert [(m.megatron_param, m.hf_param) for m in result] == [ + ( + "decoder.layers.*.self_attention.core_attention.k_bmm_quantizer._amax", + "model.layers.*.self_attn.k_bmm_quantizer._amax", + ), + ( + "decoder.layers.*.self_attention.core_attention.v_bmm_quantizer._amax", + "model.layers.*.self_attn.v_bmm_quantizer._amax", + ), + ] + assert all(isinstance(mapping, AmaxMapping) for mapping in result) + + def test_preserves_wildcards_and_language_model_prefixes(self): + mapping = QKVMapping( + megatron_param="language_model.decoder.layers.*.self_attention.linear_qkv.weight", + q="language_model.model.layers.*.self_attn.q_proj.weight", + k="language_model.model.layers.*.self_attn.k_proj.weight", + v="language_model.model.layers.*.self_attn.v_proj.weight", + ) + + result = derive_kv_bmm_amax_map([mapping]) + + assert [(m.megatron_param, m.hf_param) for m in result] == [ + ( + "language_model.decoder.layers.*.self_attention.core_attention.k_bmm_quantizer._amax", + "language_model.model.layers.*.self_attn.k_bmm_quantizer._amax", + ), + ( + "language_model.decoder.layers.*.self_attention.core_attention.v_bmm_quantizer._amax", + "language_model.model.layers.*.self_attn.v_bmm_quantizer._amax", + ), + ] + + @pytest.mark.parametrize( + ("mapping", "reason"), + [ + pytest.param( + QKVMapping( + megatron_param="decoder.layers.*.self_attention.linear_qkv.bias", + q="model.layers.*.self_attn.q_proj.bias", + k="model.layers.*.self_attn.k_proj.bias", + v="model.layers.*.self_attn.v_proj.bias", + ), + "bias", + id="bias-mapping", + ), + pytest.param( + QKVMapping( + megatron_param="decoder.layers.*.self_attention.linear_qkv.weight", + q="model.layers.*.self_attn.q_proj.weight", + k="language_model.model.layers.*.self_attn.k_proj.weight", + v="model.layers.*.self_attn.v_proj.weight", + ), + "non-common-parent", + id="non-common-parent", + ), + pytest.param( + QKVMapping( + megatron_param="decoder.layers.*.self_attention.linear_qkv.weight", + q="model.layers.*.self_attn.q_proj.weight", + k="model.layers.*.self_attn.k_proj.kernel", + v="model.layers.*.self_attn.v_proj.weight", + ), + "malformed-projection", + id="malformed-projection", + ), + pytest.param( + QKVMapping( + megatron_param="mtp.layers.*.self_attention.linear_qkv.weight", + q="model.mtp_layers.*.self_attn.q_proj.weight", + k="model.mtp_layers.*.self_attn.k_proj.weight", + v="model.mtp_layers.*.self_attn.v_proj.weight", + ), + "mtp-path", + id="mtp-path", + ), + pytest.param( + QKVMapping( + megatron_param="draft.layers.*.self_attention.linear_qkv.weight", + q="model.draft_layers.*.self_attn.q_proj.weight", + k="model.draft_layers.*.self_attn.k_proj.weight", + v="model.draft_layers.*.self_attn.v_proj.weight", + ), + "draft-path", + id="draft-path", + ), + pytest.param( + ConcatenatedQKVMapping( + megatron_param="vision.blocks.*.self_attention.linear_qkv.weight", + hf_param="vision_model.layers.*.self_attn.qkv.weight", + ), + "vision-concatenated-qkv", + id="vision-concatenated-qkv", + ), + ], + ) + def test_skips_disallowed_qkv_shapes(self, mapping, reason): + assert derive_kv_bmm_amax_map([mapping]) == [], reason + + def test_skips_mappings_that_allow_missing_hf_projections(self): + mapping = QKVMapping( + megatron_param="decoder.layers.*.self_attention.linear_qkv.weight", + q="model.layers.*.self_attn.q_proj.weight", + k="model.layers.*.self_attn.k_proj.weight", + v="model.layers.*.self_attn.v_proj.weight", + ) + mapping.allow_hf_name_mismatch = True + + assert derive_kv_bmm_amax_map([mapping]) == [] + + class TestQuantMappingRegistryIntegration: """Test quantization mappings inside MegatronMappingRegistry with a Llama-like bridge.""" @@ -138,6 +262,10 @@ def test_quant_mappings_disabled_by_default(self, llama_like_mappings): with patch.dict(os.environ, {"ENABLE_BRIDGE_QUANT_MAPPING": "0"}, clear=False): registry = MegatronMappingRegistry(*llama_like_mappings) assert not any(isinstance(m, AmaxMapping) for m in registry.get_all_mappings()) + assert ( + registry.megatron_to_hf_lookup("decoder.layers.0.self_attention.core_attention.k_bmm_quantizer._amax") + is None + ) def test_quant_mappings_count(self, registry): """weight_quantizer and input_quantizer amax mappings are added in equal numbers.""" @@ -252,6 +380,88 @@ def test_layer_index_independence(self, registry): for proj in ["q", "k", "v"]: assert f"model.layers.{layer_idx}.self_attn.{proj}_proj.weight_quantizer._amax" in m.hf_targets + @pytest.mark.parametrize( + "megatron_amax, expected_hf_amax", + [ + ( + "decoder.layers.0.self_attention.core_attention.k_bmm_quantizer._amax", + "model.layers.0.self_attn.k_bmm_quantizer._amax", + ), + ( + "decoder.layers.0.self_attention.core_attention.v_bmm_quantizer._amax", + "model.layers.0.self_attn.v_bmm_quantizer._amax", + ), + ], + ) + def test_kv_bmm_amax_forward_lookup(self, registry, megatron_amax, expected_hf_amax): + mapping = registry.megatron_to_hf_lookup(megatron_amax) + + assert mapping is not None + assert isinstance(mapping, AmaxMapping) + assert mapping.megatron_param == megatron_amax + assert mapping.hf_param == expected_hf_amax + + @pytest.mark.parametrize( + "hf_amax, expected_megatron_amax", + [ + ( + "model.layers.7.self_attn.k_bmm_quantizer._amax", + "decoder.layers.7.self_attention.core_attention.k_bmm_quantizer._amax", + ), + ( + "model.layers.7.self_attn.v_bmm_quantizer._amax", + "decoder.layers.7.self_attention.core_attention.v_bmm_quantizer._amax", + ), + ], + ) + def test_kv_bmm_amax_reverse_lookup(self, registry, hf_amax, expected_megatron_amax): + mapping = registry.hf_to_megatron_lookup(hf_amax) + + assert mapping is not None + assert isinstance(mapping, AmaxMapping) + assert mapping.megatron_param == expected_megatron_amax + assert mapping.hf_param == hf_amax + + def test_kv_bmm_amax_coexists_with_weight_and_input_quantizer_mappings(self, registry): + assert ( + registry.megatron_to_hf_lookup("decoder.layers.0.self_attention.core_attention.k_bmm_quantizer._amax") + is not None + ) + assert ( + registry.megatron_to_hf_lookup("decoder.layers.0.self_attention.linear_qkv.weight_quantizer._amax") + is not None + ) + assert ( + registry.megatron_to_hf_lookup("decoder.layers.0.self_attention.linear_qkv.input_quantizer._amax") + is not None + ) + + +class TestKvBmmQuantMappingPrefixes: + def test_registry_preserves_prefixes_and_wildcards(self): + mappings = [ + QKVMapping( + megatron_param="language_model.decoder.layers.*.self_attention.linear_qkv.weight", + q="language_model.model.layers.*.self_attn.q_proj.weight", + k="language_model.model.layers.*.self_attn.k_proj.weight", + v="language_model.model.layers.*.self_attn.v_proj.weight", + ) + ] + + with patch.dict(os.environ, {"ENABLE_BRIDGE_QUANT_MAPPING": "1"}, clear=False): + registry = MegatronMappingRegistry(*mappings) + + mapping = registry.megatron_to_hf_lookup( + "language_model.decoder.layers.9.self_attention.core_attention.k_bmm_quantizer._amax" + ) + + assert mapping is not None + assert ( + mapping.megatron_param + == "language_model.decoder.layers.9.self_attention.core_attention.k_bmm_quantizer._amax" + ) + assert mapping.hf_param == "language_model.model.layers.9.self_attn.k_bmm_quantizer._amax" + class TestMoeAmaxFanoutMapping: """Tests for grouped-MoE expert amax fanout via all_gather across the EP group."""