45 changes: 43 additions & 2 deletions skills/adding-model-support/SKILL.md
@@ -1,7 +1,6 @@
---
name: adding-model-support
-description: Guide for adding support for new LLM or VLM models in Megatron-Bridge. Covers bridge, provider, recipe, tests, docs, and examples.
-when_to_use: User asks to add, onboard, or integrate a new model family; 'add Qwen4 support', 'onboard Llama 5', 'create a bridge for X', 'write a recipe for Y'.
+description: Guide for adding, onboarding, or integrating new LLM or VLM model families in Megatron-Bridge, including bridges, providers, recipes, tests, docs, and examples.
---

# Adding New Model Support in Megatron-Bridge
@@ -82,6 +81,38 @@ Also add or update focused tests when touching export/import quantization paths;

## Phase 2: Bridge Support

### Declare Transformers compatibility

For every new bridge, determine the earliest verified upstream Transformers release that provides
the required model APIs. Test that boundary and the preceding release; do not copy the current
project pin or guess a minimum from the model publication date.

Register by string any architecture whose class may be absent at the package floor, and
attach the verified compatibility contract:

```python
@MegatronModelBridge.register_bridge(
    source="NewModelForCausalLM",
    target=GPTModel,
    model_type="new_model",
    min_transformers_version="5.8.0",
    required_transformers_symbols=("transformers.NewModelConfig", "transformers.NewModelForCausalLM"),
)
class NewModelBridge(MegatronModelBridge):
    ...
```

- `min_transformers_version` is the earliest verified release that can support the bridge. Do not
  leave it `None` for newly added model support.
- `required_transformers_symbols` lists the exact dotted paths of the config, model, processor, or
  internal module attributes the bridge uses. Add it for newly introduced or version-sensitive APIs
  so that vendor, development, and incomplete builds fail with an actionable Bridge error instead of
  a later import error.
- Required symbols need an explicit minimum: registration raises `ValueError` if
  `required_transformers_symbols` is set while `min_transformers_version` is `None`. The version
  check runs first; the symbols are resolved lazily, only after that bridge is selected.
- Keep version-sensitive Transformers imports out of module scope. Import them
  inside a guarded provider/load/export path after bridge selection.
- Do not add native Transformers symbols for `trust_remote_code` classes that
  exist only in a model repository; validate those through the remote-code path.
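
The ordering described in the bullets above (version gate first, symbols resolved only afterwards) can be sketched as follows. This is a minimal illustration and not the Bridge implementation: `CompatError`, `parse_release`, `is_older`, `resolve_symbol`, and `check_compat` are hypothetical names, the version parser only understands the leading numeric part of a release string (so pre-release suffixes such as `.dev0` are ignored rather than ordered), and standard-library modules stand in for `transformers` so the snippet runs anywhere.

```python
import importlib


class CompatError(RuntimeError):
    """Hypothetical stand-in for a Bridge compatibility error."""


def parse_release(version: str) -> tuple[int, ...]:
    """Parse the leading numeric components ("5.8.0.dev0" -> (5, 8, 0))."""
    parts = []
    for piece in version.split("."):
        if not piece.isdigit():
            break
        parts.append(int(piece))
    return tuple(parts)


def is_older(installed: str, minimum: str) -> bool:
    """Compare releases after zero-padding, so "5.8" equals "5.8.0"."""
    a, b = parse_release(installed), parse_release(minimum)
    width = max(len(a), len(b))
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a < b


def resolve_symbol(dotted: str) -> object:
    """Import the longest importable module prefix, then walk the attributes."""
    names = dotted.split(".")
    for split in range(len(names), 0, -1):
        try:
            obj = importlib.import_module(".".join(names[:split]))
        except ImportError:
            continue
        for attr in names[split:]:
            obj = getattr(obj, attr)  # AttributeError if the symbol is missing
        return obj
    raise ImportError(f"no importable module in {dotted!r}")


def check_compat(installed: str, minimum: str, symbols: tuple[str, ...] = ()) -> None:
    """Raise CompatError if the version is too old or a symbol is absent."""
    # Step 1: version gate. Symbols are never touched if this fails.
    if is_older(installed, minimum):
        raise CompatError(f"installed {installed} is older than required {minimum}")
    # Step 2: resolve each dotted symbol only after the version gate passed.
    for dotted in symbols:
        try:
            resolve_symbol(dotted)
        except (ImportError, AttributeError) as exc:
            raise CompatError(f"required symbol {dotted!r} is unavailable") from exc
```

Keeping the two failure messages distinct is what lets a test tell an old release apart from an incomplete build at a new enough release.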

### File structure

**LLM** — Reference: Qwen2 (`src/megatron/bridge/models/qwen/qwen2_bridge.py`)
@@ -352,6 +383,16 @@ tests/functional_tests/test_groups/models/<model>/
└── test_<model>_provider.py # compare_provider_configs (optional)
```

Add compatibility boundary coverage for every new registration:

- At the declared minimum, package/registry import and bridge selection pass.
- On the immediately preceding release, package/registry import still passes,
  while selecting the incompatible bridge raises `TransformersVersionError`
  before provider construction or model loading.
- Assert required-symbol failures separately from old-version failures.
- Run version cases in fresh subprocesses or immutable environments so
  Transformers lazy modules and Bridge registration state cannot leak.
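
One way to get the isolation described in the last bullet is to run each case in a child interpreter. The sketch below assumes a hypothetical helper, `run_isolated`. In a real boundary test the snippet passed in would import the package and select the bridge, and the interpreter would belong to an environment with the target Transformers release installed. Trivial snippets stand in here so the example runs anywhere.

```python
import subprocess
import sys


def run_isolated(snippet: str, python: str = sys.executable) -> subprocess.CompletedProcess:
    """Run `snippet` in a fresh interpreter so no import or registry state is shared."""
    return subprocess.run(
        [python, "-c", snippet],
        capture_output=True,
        text=True,
        check=False,  # let the caller assert on the exit code
    )


# Stand-in for a passing case at the declared minimum.
ok = run_isolated("print('selected')")

# Stand-in for a failing case: the child raises, so its exit code is non-zero
# and the traceback is available on stderr for assertions.
bad = run_isolated("raise RuntimeError('too old')")
```

Passing a different `python` path (for example, the interpreter of a per-version virtual environment) is one way to cover both sides of the boundary. How those environments are provisioned is left open here.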

For detailed test patterns, see @skills/adding-model-support/tests-and-examples.md.

## Phase 5: Docs and Examples
10 changes: 10 additions & 0 deletions src/megatron/bridge/models/conversion/__init__.py
@@ -27,6 +27,12 @@
    ReplicatedMapping,
    RowParallelMapping,
)
from megatron.bridge.models.conversion.transformers_version import (
    TransformersVersionError,
    get_transformers_version,
    is_transformers_min_version,
    require_transformers_version,
)
from megatron.bridge.models.conversion.utils import weights_verification_table


@@ -44,4 +50,8 @@
    "RowParallelMapping",
    "AutoMapping",
    "weights_verification_table",
    "TransformersVersionError",
    "get_transformers_version",
    "is_transformers_min_version",
    "require_transformers_version",
]
13 changes: 10 additions & 3 deletions src/megatron/bridge/models/conversion/auto_bridge.py
@@ -323,8 +323,6 @@ def from_auto_config(cls, megatron_path: str, hf_model_id: str, trust_remote_cod
Raises:
FileNotFoundError: If run_config.yaml is not found in the Megatron path
"""
-from transformers import AutoConfig
-
from megatron.bridge.models.conversion.utils import conform_config_to_reference
from megatron.bridge.training.model_load_save import load_model_config

@@ -352,7 +350,7 @@ def from_auto_config(cls, megatron_path: str, hf_model_id: str, trust_remote_cod
"Loading a model with trust_remote_code=True allows arbitrary code execution "
"from the model repository. Only use this with models you trust."
)
-hf_cfg = AutoConfig.from_pretrained(hf_model_id, trust_remote_code=trust_remote_code)
+hf_cfg = safe_load_config_with_retry(hf_model_id, trust_remote_code=trust_remote_code)
# 2. Translate Megatron config -> HF, conforming to reference config
bridge = cls.from_hf_config(hf_cfg)
megatron_hf_cfg_dict = bridge._model_bridge.megatron_to_hf_config(megatron_cfg)
@@ -1933,6 +1931,15 @@ def _validate_config(cls, config: PretrainedConfig, path: str | None = None) ->
f" • src/megatron/bridge/models/qwen/qwen_2_causal_bridge.py"
) from None

        registered_source_name = arch_key if isinstance(arch_key, str) else arch_key.__name__
        bridge_class = model_bridge.get_registered_bridge_class(
            source_name=registered_source_name,
            model_type=getattr(config, "model_type", None),
        )
        if bridge_class is not None:
            action = "select this model configuration" if path is None else f"load model {path}"
            bridge_class.require_transformers_compatibility(action=action)

    def _get_model_instance(self, model: list[MegatronModelT]) -> MegatronModelT:
        model_instance = model[0]
        while hasattr(model_instance, "module"):
66 changes: 65 additions & 1 deletion src/megatron/bridge/models/conversion/model_bridge.py
@@ -62,6 +62,7 @@
from megatron.bridge.models.conversion.transformers_compat import (
    rope_theta_from_hf,
)
from megatron.bridge.models.conversion.transformers_version import require_transformers_version
from megatron.bridge.models.conversion.utils import (
    extract_sort_key,
    get_module_and_param_from_name,
@@ -82,6 +83,9 @@
MegatronModel = TypeVar("MegatronModel", bound=MegatronModule)
_BridgeImplClass = TypeVar("_BridgeImplClass", bound="MegatronModelBridge")

_BRIDGE_CLASSES_BY_SOURCE_NAME: dict[str, type["MegatronModelBridge"]] = {}
_BRIDGE_CLASSES_BY_MODEL_TYPE: dict[str, type["MegatronModelBridge"]] = {}


class MegatronWeightTuple(NamedTuple):
"""Tuple representing a Megatron model weight with its metadata."""
@@ -610,6 +614,26 @@ def hf_config_to_provider_kwargs(self, hf_config) -> dict:
    # Set by @register_bridge decorator
    SOURCE_NAME: str | None = None
    MODEL_TYPE: str | None = None
    MIN_TRANSFORMERS_VERSION: str | None = None
    REQUIRED_TRANSFORMERS_SYMBOLS: tuple[str, ...] = ()

    def __new__(cls, *args: Any, **kwargs: Any) -> "MegatronModelBridge":
        """Create a bridge only after validating its Transformers policy."""
        instance = super().__new__(cls)
        cls.require_transformers_compatibility(action="use this model bridge")
        return instance

    @classmethod
    def require_transformers_compatibility(cls, *, action: str | None = None) -> None:
        """Validate the explicit Transformers policy attached to this bridge."""
        if cls.MIN_TRANSFORMERS_VERSION is None:
            return
        require_transformers_version(
            cls.SOURCE_NAME or cls.__name__,
            cls.MIN_TRANSFORMERS_VERSION,
            symbols=cls.REQUIRED_TRANSFORMERS_SYMBOLS,
            action=action,
        )

    def provider_bridge(self, hf_pretrained: HFPreTrained) -> ModelProviderTarget:
        """Create a Megatron model provider from HuggingFace configuration.
@@ -1949,6 +1973,8 @@ def register_bridge(
target: Type[MegatronModel],
provider: Type[ModelProviderTarget] | None = None,
model_type: str | None = None,
min_transformers_version: str | None = None,
required_transformers_symbols: tuple[str, ...] = (),
) -> Callable[[_BridgeImplClass], _BridgeImplClass]:
"""Class decorator for registering bridge implementations.

@@ -1967,6 +1993,11 @@ def register_bridge(
Defaults to GPTModelProvider if not specified.
model_type (str, optional): HuggingFace model_type string (e.g., "llama").
Used for megatron_to_hf_config conversion.
min_transformers_version: Minimum Transformers version required to
select or use this bridge. ``None`` means the policy is not yet
researched and is not enforced.
required_transformers_symbols: Dotted Transformers symbols that
must be available after the minimum version check passes.

Returns:
Callable[[_BridgeImplClass], _BridgeImplClass]: Decorator function
@@ -2006,7 +2037,14 @@ class MegatronDeepseekV3Bridge(MegatronModelBridge):
class is defined.
"""

-    return create_bridge_decorator(source=source, target=target, provider=provider, model_type=model_type)
+    return create_bridge_decorator(
+        source=source,
+        target=target,
+        provider=provider,
+        model_type=model_type,
+        min_transformers_version=min_transformers_version,
+        required_transformers_symbols=required_transformers_symbols,
+    )


# Core dispatch functions
@@ -2167,6 +2205,8 @@ def create_bridge_decorator(
target: Type["MegatronModule"],
provider: Type["ModelProviderMixin"] | None = None,
model_type: str | None = None,
min_transformers_version: str | None = None,
required_transformers_symbols: tuple[str, ...] = (),
) -> Callable[[Type["MegatronModelBridge"]], Type["MegatronModelBridge"]]:
"""Create a decorator for registering bridge implementations.

@@ -2176,21 +2216,45 @@ def create_bridge_decorator(
target: Megatron model class
provider: Provider class to use for this model (e.g., DeepSeekModelProvider)
model_type: HuggingFace model_type string (e.g., "llama", "deepseek_v3")
min_transformers_version: Minimum Transformers version required by the bridge.
required_transformers_symbols: Dotted Transformers symbols required by the bridge.

Returns:
Decorator function that registers the bridge implementation
"""

    def decorator(bridge_class: Type["MegatronModelBridge"]) -> Type["MegatronModelBridge"]:
        if required_transformers_symbols and min_transformers_version is None:
            raise ValueError("required_transformers_symbols requires min_transformers_version")
        # Store source name for HF config generation
        bridge_class.SOURCE_NAME = source if isinstance(source, str) else source.__name__
        bridge_class.MIN_TRANSFORMERS_VERSION = min_transformers_version
        bridge_class.REQUIRED_TRANSFORMERS_SYMBOLS = tuple(required_transformers_symbols)
        # Store model_type for HF config generation
        if model_type is not None:
            bridge_class.MODEL_TYPE = model_type
        # Set the provider class on the bridge
        if provider is not None:
            bridge_class.PROVIDER_CLASS = provider
        _BRIDGE_CLASSES_BY_SOURCE_NAME[bridge_class.SOURCE_NAME] = bridge_class
        if model_type is not None:
            _BRIDGE_CLASSES_BY_MODEL_TYPE[model_type] = bridge_class
        register_bridge_implementation(source=source, target=target, bridge_class=bridge_class)
        return bridge_class

    return decorator


def get_registered_bridge_class(
    *,
    source_name: str | None = None,
    model_type: str | None = None,
) -> type[MegatronModelBridge] | None:
    """Return a registered bridge class by architecture name or model type."""
    if source_name is not None:
        bridge_class = _BRIDGE_CLASSES_BY_SOURCE_NAME.get(source_name)
        if bridge_class is not None:
            return bridge_class
    if model_type is not None:
        return _BRIDGE_CLASSES_BY_MODEL_TYPE.get(model_type)
    return None