From d3ace5a3db792d5961878171181f84d08148d626 Mon Sep 17 00:00:00 2001 From: Chen Cui Date: Wed, 1 Jul 2026 12:53:20 -0700 Subject: [PATCH 1/3] feat(model): add lazy transformers compatibility guards Signed-off-by: Chen Cui --- .../bridge/models/conversion/__init__.py | 10 + .../bridge/models/conversion/auto_bridge.py | 13 +- .../bridge/models/conversion/model_bridge.py | 66 +++++- .../models/conversion/transformers_version.py | 178 +++++++++++++++ .../bridge/models/glm_moe_dsa/glm5_bridge.py | 8 +- .../hf_pretrained/safe_config_loader.py | 35 +++ .../bridge/models/qwen/qwen35_bridge.py | 17 +- .../bridge/models/qwen_vl/qwen35_vl_bridge.py | 6 + .../models/qwen_vl/qwen35_vl_provider.py | 79 +++---- .../test_transformers_version_guard.py | 89 ++++++++ .../models/qwen_vl/test_qwen35_vl_bridge.py | 18 +- .../models/test_transformers_version.py | 206 ++++++++++++++++++ 12 files changed, 662 insertions(+), 63 deletions(-) create mode 100644 src/megatron/bridge/models/conversion/transformers_version.py create mode 100644 tests/unit_tests/models/hf_pretrained/test_transformers_version_guard.py create mode 100644 tests/unit_tests/models/test_transformers_version.py diff --git a/src/megatron/bridge/models/conversion/__init__.py b/src/megatron/bridge/models/conversion/__init__.py index 7b4b493dec..9edf8249aa 100644 --- a/src/megatron/bridge/models/conversion/__init__.py +++ b/src/megatron/bridge/models/conversion/__init__.py @@ -27,6 +27,12 @@ ReplicatedMapping, RowParallelMapping, ) +from megatron.bridge.models.conversion.transformers_version import ( + TransformersVersionError, + get_transformers_version, + is_transformers_min_version, + require_transformers_version, +) from megatron.bridge.models.conversion.utils import weights_verification_table @@ -44,4 +50,8 @@ "RowParallelMapping", "AutoMapping", "weights_verification_table", + "TransformersVersionError", + "get_transformers_version", + "is_transformers_min_version", + "require_transformers_version", ] diff --git 
a/src/megatron/bridge/models/conversion/auto_bridge.py b/src/megatron/bridge/models/conversion/auto_bridge.py index 317f92fbc5..308490af59 100644 --- a/src/megatron/bridge/models/conversion/auto_bridge.py +++ b/src/megatron/bridge/models/conversion/auto_bridge.py @@ -323,8 +323,6 @@ def from_auto_config(cls, megatron_path: str, hf_model_id: str, trust_remote_cod Raises: FileNotFoundError: If run_config.yaml is not found in the Megatron path """ - from transformers import AutoConfig - from megatron.bridge.models.conversion.utils import conform_config_to_reference from megatron.bridge.training.model_load_save import load_model_config @@ -352,7 +350,7 @@ def from_auto_config(cls, megatron_path: str, hf_model_id: str, trust_remote_cod "Loading a model with trust_remote_code=True allows arbitrary code execution " "from the model repository. Only use this with models you trust." ) - hf_cfg = AutoConfig.from_pretrained(hf_model_id, trust_remote_code=trust_remote_code) + hf_cfg = safe_load_config_with_retry(hf_model_id, trust_remote_code=trust_remote_code) # 2. 
Translate Megatron config -> HF, conforming to reference config bridge = cls.from_hf_config(hf_cfg) megatron_hf_cfg_dict = bridge._model_bridge.megatron_to_hf_config(megatron_cfg) @@ -1933,6 +1931,15 @@ def _validate_config(cls, config: PretrainedConfig, path: str | None = None) -> f" • src/megatron/bridge/models/qwen/qwen_2_causal_bridge.py" ) from None + registered_source_name = arch_key if isinstance(arch_key, str) else arch_key.__name__ + bridge_class = model_bridge.get_registered_bridge_class( + source_name=registered_source_name, + model_type=getattr(config, "model_type", None), + ) + if bridge_class is not None: + action = "select this model configuration" if path is None else f"load model {path}" + bridge_class.require_transformers_compatibility(action=action) + def _get_model_instance(self, model: list[MegatronModelT]) -> MegatronModelT: model_instance = model[0] while hasattr(model_instance, "module"): diff --git a/src/megatron/bridge/models/conversion/model_bridge.py b/src/megatron/bridge/models/conversion/model_bridge.py index b05f16384b..2592b03f74 100644 --- a/src/megatron/bridge/models/conversion/model_bridge.py +++ b/src/megatron/bridge/models/conversion/model_bridge.py @@ -62,6 +62,7 @@ from megatron.bridge.models.conversion.transformers_compat import ( rope_theta_from_hf, ) +from megatron.bridge.models.conversion.transformers_version import require_transformers_version from megatron.bridge.models.conversion.utils import ( extract_sort_key, get_module_and_param_from_name, @@ -82,6 +83,9 @@ MegatronModel = TypeVar("MegatronModel", bound=MegatronModule) _BridgeImplClass = TypeVar("_BridgeImplClass", bound="MegatronModelBridge") +_BRIDGE_CLASSES_BY_SOURCE_NAME: dict[str, type["MegatronModelBridge"]] = {} +_BRIDGE_CLASSES_BY_MODEL_TYPE: dict[str, type["MegatronModelBridge"]] = {} + class MegatronWeightTuple(NamedTuple): """Tuple representing a Megatron model weight with its metadata.""" @@ -610,6 +614,26 @@ def hf_config_to_provider_kwargs(self, 
hf_config) -> dict: # Set by @register_bridge decorator SOURCE_NAME: str | None = None MODEL_TYPE: str | None = None + MIN_TRANSFORMERS_VERSION: str | None = None + REQUIRED_TRANSFORMERS_SYMBOLS: tuple[str, ...] = () + + def __new__(cls, *args: Any, **kwargs: Any) -> "MegatronModelBridge": + """Create a bridge only after validating its Transformers policy.""" + instance = super().__new__(cls) + cls.require_transformers_compatibility(action="use this model bridge") + return instance + + @classmethod + def require_transformers_compatibility(cls, *, action: str | None = None) -> None: + """Validate the explicit Transformers policy attached to this bridge.""" + if cls.MIN_TRANSFORMERS_VERSION is None: + return + require_transformers_version( + cls.SOURCE_NAME or cls.__name__, + cls.MIN_TRANSFORMERS_VERSION, + symbols=cls.REQUIRED_TRANSFORMERS_SYMBOLS, + action=action, + ) def provider_bridge(self, hf_pretrained: HFPreTrained) -> ModelProviderTarget: """Create a Megatron model provider from HuggingFace configuration. @@ -1949,6 +1973,8 @@ def register_bridge( target: Type[MegatronModel], provider: Type[ModelProviderTarget] | None = None, model_type: str | None = None, + min_transformers_version: str | None = None, + required_transformers_symbols: tuple[str, ...] = (), ) -> Callable[[_BridgeImplClass], _BridgeImplClass]: """Class decorator for registering bridge implementations. @@ -1967,6 +1993,11 @@ def register_bridge( Defaults to GPTModelProvider if not specified. model_type (str, optional): HuggingFace model_type string (e.g., "llama"). Used for megatron_to_hf_config conversion. + min_transformers_version: Minimum Transformers version required to + select or use this bridge. ``None`` means the policy is not yet + researched and is not enforced. + required_transformers_symbols: Dotted Transformers symbols that + must be available after the minimum version check passes. 
Returns: Callable[[_BridgeImplClass], _BridgeImplClass]: Decorator function @@ -2006,7 +2037,14 @@ class MegatronDeepseekV3Bridge(MegatronModelBridge): class is defined. """ - return create_bridge_decorator(source=source, target=target, provider=provider, model_type=model_type) + return create_bridge_decorator( + source=source, + target=target, + provider=provider, + model_type=model_type, + min_transformers_version=min_transformers_version, + required_transformers_symbols=required_transformers_symbols, + ) # Core dispatch functions @@ -2167,6 +2205,8 @@ def create_bridge_decorator( target: Type["MegatronModule"], provider: Type["ModelProviderMixin"] | None = None, model_type: str | None = None, + min_transformers_version: str | None = None, + required_transformers_symbols: tuple[str, ...] = (), ) -> Callable[[Type["MegatronModelBridge"]], Type["MegatronModelBridge"]]: """Create a decorator for registering bridge implementations. @@ -2176,21 +2216,45 @@ def create_bridge_decorator( target: Megatron model class provider: Provider class to use for this model (e.g., DeepSeekModelProvider) model_type: HuggingFace model_type string (e.g., "llama", "deepseek_v3") + min_transformers_version: Minimum Transformers version required by the bridge. + required_transformers_symbols: Dotted Transformers symbols required by the bridge. 
Returns: Decorator function that registers the bridge implementation """ def decorator(bridge_class: Type["MegatronModelBridge"]) -> Type["MegatronModelBridge"]: + if required_transformers_symbols and min_transformers_version is None: + raise ValueError("required_transformers_symbols requires min_transformers_version") # Store source name for HF config generation bridge_class.SOURCE_NAME = source if isinstance(source, str) else source.__name__ + bridge_class.MIN_TRANSFORMERS_VERSION = min_transformers_version + bridge_class.REQUIRED_TRANSFORMERS_SYMBOLS = tuple(required_transformers_symbols) # Store model_type for HF config generation if model_type is not None: bridge_class.MODEL_TYPE = model_type # Set the provider class on the bridge if provider is not None: bridge_class.PROVIDER_CLASS = provider + _BRIDGE_CLASSES_BY_SOURCE_NAME[bridge_class.SOURCE_NAME] = bridge_class + if model_type is not None: + _BRIDGE_CLASSES_BY_MODEL_TYPE[model_type] = bridge_class register_bridge_implementation(source=source, target=target, bridge_class=bridge_class) return bridge_class return decorator + + +def get_registered_bridge_class( + *, + source_name: str | None = None, + model_type: str | None = None, +) -> type[MegatronModelBridge] | None: + """Return a registered bridge class by architecture name or model type.""" + if source_name is not None: + bridge_class = _BRIDGE_CLASSES_BY_SOURCE_NAME.get(source_name) + if bridge_class is not None: + return bridge_class + if model_type is not None: + return _BRIDGE_CLASSES_BY_MODEL_TYPE.get(model_type) + return None diff --git a/src/megatron/bridge/models/conversion/transformers_version.py b/src/megatron/bridge/models/conversion/transformers_version.py new file mode 100644 index 0000000000..13bd210190 --- /dev/null +++ b/src/megatron/bridge/models/conversion/transformers_version.py @@ -0,0 +1,178 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Lazy, model-specific compatibility checks for Hugging Face Transformers.""" + +from __future__ import annotations + +import importlib +import importlib.metadata +import importlib.util + +from packaging.version import InvalidVersion, Version + + +class TransformersVersionError(RuntimeError): + """Raised when a selected model is incompatible with installed Transformers. + + Args: + model_name: User-facing model or architecture name. + installed_version: Installed Transformers version. + required_version: Minimum Transformers version required by the model. + missing_symbols: Required Transformers symbols that are unavailable. + action: Operation that triggered the compatibility check. + """ + + def __init__( + self, + model_name: str, + installed_version: Version, + required_version: Version, + *, + missing_symbols: tuple[str, ...] 
= (), + action: str | None = None, + ) -> None: + self.model_name = model_name + self.installed_version = installed_version + self.required_version = required_version + self.missing_symbols = missing_symbols + self.action = action + + operation = action or "use this model" + details = [ + f"Cannot {operation} for {model_name} with Transformers {installed_version}.", + f"{model_name} requires Transformers>={required_version}.", + ] + if missing_symbols: + details.append(f"Missing required symbol(s): {', '.join(missing_symbols)}.") + details.append( + f"Install or upgrade to a compatible Transformers version (>={required_version}) before retrying." + ) + super().__init__(" ".join(details)) + + +def get_transformers_version() -> Version: + """Return the installed Transformers version as a PEP 440 version.""" + try: + installed_version = importlib.metadata.version("transformers") + except importlib.metadata.PackageNotFoundError as error: + raise RuntimeError("Transformers is not installed.") from error + + try: + return Version(installed_version) + except InvalidVersion as error: + raise RuntimeError(f"Installed Transformers has an invalid version: {installed_version!r}.") from error + + +def _parse_required_version(version: str) -> Version: + try: + return Version(version) + except InvalidVersion as error: + raise ValueError(f"Invalid minimum Transformers version: {version!r}.") from error + + +def is_transformers_min_version(version: str) -> bool: + """Return whether installed Transformers satisfies ``version``.""" + return get_transformers_version() >= _parse_required_version(version) + + +def _module_not_found_targets(module_name: str, error: ModuleNotFoundError) -> bool: + missing_name = error.name + return missing_name is not None and (missing_name == module_name or module_name.startswith(f"{missing_name}.")) + + +def _has_transformers_symbol(symbol_path: str) -> bool: + """Resolve a dotted module/attribute path without hiding unrelated import failures.""" + if 
symbol_path != "transformers" and not symbol_path.startswith("transformers."): + raise ValueError(f"Transformers symbol paths must start with 'transformers': {symbol_path!r}.") + + parts = symbol_path.split(".") + module_name = None + attribute_parts: list[str] = [] + for split_index in range(len(parts), 0, -1): + candidate = ".".join(parts[:split_index]) + try: + spec = importlib.util.find_spec(candidate) + except ModuleNotFoundError as error: + if not _module_not_found_targets(candidate, error): + raise + continue + except (AttributeError, ValueError): + continue + if spec is not None: + module_name = candidate + attribute_parts = parts[split_index:] + break + + if module_name is None: + return False + + try: + resolved: object = importlib.import_module(module_name) + except ModuleNotFoundError as error: + if _module_not_found_targets(module_name, error): + return False + raise + + for attribute_name in attribute_parts: + try: + resolved = getattr(resolved, attribute_name) + except AttributeError: + return False + return True + + +def require_transformers_version( + model_name: str, + min_version: str, + *, + symbols: tuple[str, ...] = (), + action: str | None = None, +) -> None: + """Require a Transformers version and optional symbols for one model. + + Symbol imports are attempted only after the minimum version is satisfied, so + selecting a model on an older installation fails without importing its + version-sensitive Transformers modules. + + Args: + model_name: User-facing model or architecture name. + min_version: Minimum compatible Transformers version. + symbols: Dotted Transformers module or attribute paths required by the model. + action: Operation that triggered the compatibility check. + + Raises: + TransformersVersionError: If the version is too old or a required symbol is missing. + ValueError: If ``min_version`` or a symbol path is invalid. + ImportError: If resolving a symbol fails because of an unrelated dependency error. 
+ """ + installed_version = get_transformers_version() + required_version = _parse_required_version(min_version) + if installed_version < required_version: + raise TransformersVersionError( + model_name, + installed_version, + required_version, + action=action, + ) + + missing_symbols = tuple(symbol for symbol in symbols if not _has_transformers_symbol(symbol)) + if missing_symbols: + raise TransformersVersionError( + model_name, + installed_version, + required_version, + missing_symbols=missing_symbols, + action=action, + ) diff --git a/src/megatron/bridge/models/glm_moe_dsa/glm5_bridge.py b/src/megatron/bridge/models/glm_moe_dsa/glm5_bridge.py index f5b4ab24a2..ce9b79e9a7 100644 --- a/src/megatron/bridge/models/glm_moe_dsa/glm5_bridge.py +++ b/src/megatron/bridge/models/glm_moe_dsa/glm5_bridge.py @@ -15,7 +15,6 @@ import logging from megatron.core.models.gpt.gpt_model import GPTModel -from transformers import GlmMoeDsaForCausalLM from megatron.bridge.models.conversion.mapping_registry import MegatronMappingRegistry from megatron.bridge.models.conversion.model_bridge import MegatronModelBridge @@ -32,7 +31,12 @@ @MegatronModelBridge.register_bridge( - source=GlmMoeDsaForCausalLM, target=GPTModel, provider=MLAModelProvider, model_type="glm_moe_dsa" + source="GlmMoeDsaForCausalLM", + target=GPTModel, + provider=MLAModelProvider, + model_type="glm_moe_dsa", + min_transformers_version="5.2.0", + required_transformers_symbols=("transformers.GlmMoeDsaForCausalLM",), ) class GLM5Bridge(MegatronModelBridge): """ diff --git a/src/megatron/bridge/models/hf_pretrained/safe_config_loader.py b/src/megatron/bridge/models/hf_pretrained/safe_config_loader.py index 5eb6186975..0c82d008f4 100644 --- a/src/megatron/bridge/models/hf_pretrained/safe_config_loader.py +++ b/src/megatron/bridge/models/hf_pretrained/safe_config_loader.py @@ -22,6 +22,7 @@ import hashlib import os +import re import time from pathlib import Path from typing import Union @@ -33,6 +34,39 @@ import 
megatron.bridge.models.conversion.transformers_compat # noqa: F401 # patches removed HF utils +_UNRECOGNIZED_MODEL_TYPE_PATTERNS = ( + re.compile(r"model type [`'\"](?P<model_type>[^`'\"]+)[`'\"]", re.IGNORECASE), + re.compile(r"model_type [`'\"](?P<model_type>[^`'\"]+)[`'\"]", re.IGNORECASE), +) + + +def _raise_for_known_transformers_architecture_error(error: Exception) -> None: + """Translate a known guarded model-type error into Bridge compatibility guidance.""" + error_message = str(error) + normalized_message = error_message.lower() + if not ( + "does not recognize this architecture" in normalized_message or "unrecognized model type" in normalized_message + ): + return + + model_type = None + for pattern in _UNRECOGNIZED_MODEL_TYPE_PATTERNS: + match = pattern.search(error_message) + if match is not None: + model_type = match.group("model_type") + break + if model_type is None: + return + + # Import lazily to avoid a safe_config_loader -> model_bridge import cycle. + from megatron.bridge.models.conversion.model_bridge import get_registered_bridge_class + + bridge_class = get_registered_bridge_class(model_type=model_type) + if bridge_class is None or bridge_class.MIN_TRANSFORMERS_VERSION is None: + return + bridge_class.require_transformers_compatibility(action="load this model configuration") + + def safe_load_config_with_retry( path: Union[str, Path], trust_remote_code: bool = False, max_retries: int = 3, base_delay: float = 1.0, **kwargs ) -> PretrainedConfig: @@ -101,6 +135,7 @@ def safe_load_config_with_retry( return AutoConfig.from_pretrained(path, trust_remote_code=trust_remote_code, **kwargs) except Exception as e: + _raise_for_known_transformers_architecture_error(e) last_exception = e # Don't retry on certain types of errors diff --git a/src/megatron/bridge/models/qwen/qwen35_bridge.py b/src/megatron/bridge/models/qwen/qwen35_bridge.py index 06d7608ac1..3a2c85a4cd 100644 --- a/src/megatron/bridge/models/qwen/qwen35_bridge.py +++ 
b/src/megatron/bridge/models/qwen/qwen35_bridge.py @@ -17,7 +17,6 @@ get_transformer_block_with_experimental_attention_variant_spec, ) from megatron.core.models.gpt.gpt_model import GPTModel -from transformers import Qwen3_5ForCausalLM, Qwen3_5MoeForCausalLM from megatron.bridge.models.conversion.mapping_registry import MegatronMappingRegistry from megatron.bridge.models.conversion.model_bridge import MegatronModelBridge @@ -101,7 +100,13 @@ def _apply_qwen35_moe_config(provider: GPTModelProvider, text_config) -> None: provider.moe_permute_fusion = True -@MegatronModelBridge.register_bridge(source=Qwen3_5MoeForCausalLM, target=GPTModel, model_type="qwen3_5_moe_text") +@MegatronModelBridge.register_bridge( + source="Qwen3_5MoeForCausalLM", + target=GPTModel, + model_type="qwen3_5_moe_text", + min_transformers_version="5.2.0", + required_transformers_symbols=("transformers.Qwen3_5MoeForCausalLM",), +) class Qwen35MoEBridge(MegatronModelBridge): """ Megatron Bridge for Qwen3.5 Language Model (MoE variant). @@ -430,7 +435,13 @@ def mapping_registry(self) -> MegatronMappingRegistry: return MegatronMappingRegistry(*mapping_list) -@MegatronModelBridge.register_bridge(source=Qwen3_5ForCausalLM, target=GPTModel, model_type="qwen3_5_text") +@MegatronModelBridge.register_bridge( + source="Qwen3_5ForCausalLM", + target=GPTModel, + model_type="qwen3_5_text", + min_transformers_version="5.2.0", + required_transformers_symbols=("transformers.Qwen3_5ForCausalLM",), +) class Qwen35Bridge(MegatronModelBridge): """ Megatron Bridge for Qwen3.5 Dense Language Model. 
diff --git a/src/megatron/bridge/models/qwen_vl/qwen35_vl_bridge.py b/src/megatron/bridge/models/qwen_vl/qwen35_vl_bridge.py index 505d78af91..857e9b0deb 100644 --- a/src/megatron/bridge/models/qwen_vl/qwen35_vl_bridge.py +++ b/src/megatron/bridge/models/qwen_vl/qwen35_vl_bridge.py @@ -139,6 +139,10 @@ def _get_vision_mappings(): target=Qwen3VLModel, provider=Qwen35VLMoEModelProvider, model_type="qwen3_5_moe", + min_transformers_version="5.2.0", + required_transformers_symbols=( + "transformers.models.qwen3_5_moe.configuration_qwen3_5_moe.Qwen3_5MoeVisionConfig", + ), ) class Qwen35VLMoEBridge(MegatronModelBridge): """ @@ -269,6 +273,8 @@ def mapping_registry(self) -> MegatronMappingRegistry: target=Qwen3VLModel, provider=Qwen35VLModelProvider, model_type="qwen3_5", + min_transformers_version="5.2.0", + required_transformers_symbols=("transformers.models.qwen3_5.configuration_qwen3_5.Qwen3_5VisionConfig",), ) class Qwen35VLBridge(MegatronModelBridge): """ diff --git a/src/megatron/bridge/models/qwen_vl/qwen35_vl_provider.py b/src/megatron/bridge/models/qwen_vl/qwen35_vl_provider.py index 13651f1424..c8257bdcfd 100644 --- a/src/megatron/bridge/models/qwen_vl/qwen35_vl_provider.py +++ b/src/megatron/bridge/models/qwen_vl/qwen35_vl_provider.py @@ -34,38 +34,20 @@ from dataclasses import dataclass, field from typing import Any, Callable, ClassVar, List, Optional -import transformers -from megatron.core.models.gpt import GPTModel as MCoreGPTModel -from megatron.core.models.gpt.experimental_attention_variant_module_specs import ( - get_transformer_block_with_experimental_attention_variant_spec, -) -from megatron.core.transformer.spec_utils import ModuleSpec -from megatron.core.transformer.transformer_block import TransformerBlockSubmodules -from packaging.version import Version as PkgVersion - - -_TRANSFORMERS_HAS_QWEN3_5_MOE = PkgVersion(transformers.__version__) >= PkgVersion("5.2.0") - -if _TRANSFORMERS_HAS_QWEN3_5_MOE: - from 
transformers.models.qwen3_5_moe.configuration_qwen3_5_moe import Qwen3_5MoeVisionConfig -else: - Qwen3_5MoeVisionConfig = None # type: ignore[assignment,misc] - -try: - from transformers.models.qwen3_5.configuration_qwen3_5 import Qwen3_5VisionConfig - - _TRANSFORMERS_HAS_QWEN3_5 = True -except ImportError: - _TRANSFORMERS_HAS_QWEN3_5 = False - Qwen3_5VisionConfig = None # type: ignore[assignment,misc] - from megatron.core.extensions.transformer_engine import ( TEColumnParallelLinear, TENorm, TERowParallelLinear, ) +from megatron.core.models.gpt import GPTModel as MCoreGPTModel +from megatron.core.models.gpt.experimental_attention_variant_module_specs import ( + get_transformer_block_with_experimental_attention_variant_spec, +) from megatron.core.models.vision.vit_layer_specs import get_vit_layer_with_transformer_engine_spec +from megatron.core.transformer.spec_utils import ModuleSpec +from megatron.core.transformer.transformer_block import TransformerBlockSubmodules +from megatron.bridge.models.conversion.transformers_version import require_transformers_version from megatron.bridge.models.gpt_provider import GPTModelProvider from megatron.bridge.models.qwen_vl.modelling_qwen3_vl.attention import Qwen3VLSelfAttention from megatron.bridge.models.qwen_vl.modelling_qwen3_vl.model import Qwen3VLModel @@ -79,23 +61,32 @@ _IMAGES_MODALITY_KEY = "images" -def _check_qwen3_5_available() -> None: - """Raise a clear error if transformers doesn't have qwen3_5 (dense) support.""" - if not _TRANSFORMERS_HAS_QWEN3_5: - raise ImportError( - f"Qwen3.5 VL (dense) requires transformers with qwen3_5 model support, " - f"but found {transformers.__version__}. 
" - "Please upgrade: pip install --upgrade transformers" - ) +def _get_qwen3_5_vision_config_class() -> type: + """Return the guarded Qwen3.5 dense vision config class.""" + symbol = "transformers.models.qwen3_5.configuration_qwen3_5.Qwen3_5VisionConfig" + require_transformers_version( + "Qwen3.5 VL (dense)", + "5.2.0", + symbols=(symbol,), + action="construct the model provider", + ) + from transformers.models.qwen3_5.configuration_qwen3_5 import Qwen3_5VisionConfig + return Qwen3_5VisionConfig -def _check_qwen3_5_moe_available() -> None: - """Raise a clear error if transformers doesn't have qwen3_5_moe support.""" - if not _TRANSFORMERS_HAS_QWEN3_5_MOE: - raise ImportError( - f"Qwen3.5 VL (MoE) requires transformers >= 5.2.0, but found {transformers.__version__}. " - "Please upgrade: pip install --upgrade transformers" - ) + +def _get_qwen3_5_moe_vision_config_class() -> type: + """Return the guarded Qwen3.5 MoE vision config class.""" + symbol = "transformers.models.qwen3_5_moe.configuration_qwen3_5_moe.Qwen3_5MoeVisionConfig" + require_transformers_version( + "Qwen3.5 VL (MoE)", + "5.2.0", + symbols=(symbol,), + action="construct the model provider", + ) + from transformers.models.qwen3_5_moe.configuration_qwen3_5_moe import Qwen3_5MoeVisionConfig + + return Qwen3_5MoeVisionConfig @dataclass @@ -198,9 +189,9 @@ def special_token_ids(self) -> dict[str, int]: return {_IMAGES_MODALITY_KEY: self.image_token_id} def __post_init__(self): - _check_qwen3_5_available() + vision_config_class = _get_qwen3_5_vision_config_class() if self.vision_config is None: - self.vision_config = Qwen3_5VisionConfig() + self.vision_config = vision_config_class() super().__post_init__() def finalize(self) -> None: @@ -388,9 +379,9 @@ def special_token_ids(self) -> dict[str, int]: return {_IMAGES_MODALITY_KEY: self.image_token_id} def __post_init__(self): - _check_qwen3_5_moe_available() + vision_config_class = _get_qwen3_5_moe_vision_config_class() if self.vision_config is None: - 
self.vision_config = Qwen3_5MoeVisionConfig() + self.vision_config = vision_config_class() super().__post_init__() def finalize(self) -> None: diff --git a/tests/unit_tests/models/hf_pretrained/test_transformers_version_guard.py b/tests/unit_tests/models/hf_pretrained/test_transformers_version_guard.py new file mode 100644 index 0000000000..c9b440b331 --- /dev/null +++ b/tests/unit_tests/models/hf_pretrained/test_transformers_version_guard.py @@ -0,0 +1,89 @@ +from types import SimpleNamespace +from unittest.mock import patch + +import pytest +from packaging.version import Version + +from megatron.bridge.models.conversion import transformers_version +from megatron.bridge.models.conversion.auto_bridge import AutoBridge +from megatron.bridge.models.conversion.model_bridge import MegatronModelBridge +from megatron.bridge.models.conversion.transformers_version import TransformersVersionError +from megatron.bridge.models.hf_pretrained.safe_config_loader import safe_load_config_with_retry + + +pytestmark = pytest.mark.unit + + +class _ConfigGuardTarget: + pass + + +@MegatronModelBridge.register_bridge( + source="UnitTestConfigGuardForCausalLM", + target=_ConfigGuardTarget, + model_type="unit_test_config_guard", + min_transformers_version="99.0.0", +) +class _ConfigGuardBridge(MegatronModelBridge): + def mapping_registry(self): + raise NotImplementedError + + +def test_auto_bridge_validates_guard_before_provider_or_model_loading(monkeypatch): + monkeypatch.setattr(transformers_version, "get_transformers_version", lambda: Version("5.3.0")) + config = SimpleNamespace( + architectures=["UnitTestConfigGuardForCausalLM"], + model_type="unit_test_config_guard", + ) + + with pytest.raises(TransformersVersionError, match="UnitTestConfigGuardForCausalLM"): + AutoBridge.from_hf_config(config) + + +def test_safe_config_loader_wraps_known_architecture_version_failure(monkeypatch, tmp_path): + monkeypatch.setattr(transformers_version, "get_transformers_version", lambda: 
Version("5.3.0")) + monkeypatch.setenv("MEGATRON_CONFIG_LOCK_DIR", str(tmp_path)) + hf_error = ValueError( + "The checkpoint you are trying to load has model type `unit_test_config_guard` " + "but Transformers does not recognize this architecture." + ) + + with patch( + "megatron.bridge.models.hf_pretrained.safe_config_loader.AutoConfig.from_pretrained", + side_effect=hf_error, + ): + with pytest.raises(TransformersVersionError, match="Transformers>=99.0.0"): + safe_load_config_with_retry("unused/model", max_retries=0) + + +def test_safe_config_loader_does_not_wrap_unknown_architecture(monkeypatch, tmp_path): + monkeypatch.setenv("MEGATRON_CONFIG_LOCK_DIR", str(tmp_path)) + hf_error = ValueError( + "The checkpoint you are trying to load has model type `not_registered` " + "but Transformers does not recognize this architecture." + ) + + with patch( + "megatron.bridge.models.hf_pretrained.safe_config_loader.AutoConfig.from_pretrained", + side_effect=hf_error, + ): + with pytest.raises(ValueError) as error_info: + safe_load_config_with_retry("unused/model", max_retries=0) + + assert not isinstance(error_info.value, TransformersVersionError) + assert "not_registered" in str(error_info.value) + + +def test_safe_config_loader_does_not_wrap_network_or_auth_failure(monkeypatch, tmp_path): + monkeypatch.setenv("MEGATRON_CONFIG_LOCK_DIR", str(tmp_path)) + hf_error = OSError("401 Client Error: unauthorized") + + with patch( + "megatron.bridge.models.hf_pretrained.safe_config_loader.AutoConfig.from_pretrained", + side_effect=hf_error, + ): + with pytest.raises(ValueError) as error_info: + safe_load_config_with_retry("private/model", max_retries=0) + + assert not isinstance(error_info.value, TransformersVersionError) + assert "Ensure the path is valid" in str(error_info.value) diff --git a/tests/unit_tests/models/qwen_vl/test_qwen35_vl_bridge.py b/tests/unit_tests/models/qwen_vl/test_qwen35_vl_bridge.py index 1a57e608e3..2a72298b5b 100644 --- 
a/tests/unit_tests/models/qwen_vl/test_qwen35_vl_bridge.py +++ b/tests/unit_tests/models/qwen_vl/test_qwen35_vl_bridge.py @@ -20,16 +20,14 @@ from megatron.bridge.models.conversion.mapping_registry import MegatronMappingRegistry from megatron.bridge.models.conversion.model_bridge import WeightConversionTask +from megatron.bridge.models.conversion.transformers_version import is_transformers_min_version from megatron.bridge.models.hf_pretrained.vlm import PreTrainedVLM from megatron.bridge.models.qwen_vl.qwen35_vl_bridge import Qwen35VLBridge, Qwen35VLMoEBridge -from megatron.bridge.models.qwen_vl.qwen35_vl_provider import ( - _TRANSFORMERS_HAS_QWEN3_5, - _TRANSFORMERS_HAS_QWEN3_5_MOE, - Qwen35VLModelProvider, - Qwen35VLMoEModelProvider, -) +from megatron.bridge.models.qwen_vl.qwen35_vl_provider import Qwen35VLModelProvider, Qwen35VLMoEModelProvider +_TRANSFORMERS_HAS_QWEN3_5 = is_transformers_min_version("5.2.0") + pytestmark = pytest.mark.skipif(not _TRANSFORMERS_HAS_QWEN3_5, reason="transformers does not have qwen3_5 support") @@ -283,14 +281,14 @@ def test_mapping_registry_has_vision_patch_embed(self, bridge): # ===================================================================== -@pytest.mark.skipif(not _TRANSFORMERS_HAS_QWEN3_5_MOE, reason="transformers does not have qwen3_5_moe support") +@pytest.mark.skipif(not _TRANSFORMERS_HAS_QWEN3_5, reason="transformers does not have qwen3_5_moe support") class TestQwen35VLMoEBridgeInitialization: def test_bridge_initialization(self): bridge = Qwen35VLMoEBridge() assert isinstance(bridge, Qwen35VLMoEBridge) -@pytest.mark.skipif(not _TRANSFORMERS_HAS_QWEN3_5_MOE, reason="transformers does not have qwen3_5_moe support") +@pytest.mark.skipif(not _TRANSFORMERS_HAS_QWEN3_5, reason="transformers does not have qwen3_5_moe support") class TestQwen35VLMoEBridgeProviderBridge: @pytest.fixture def bridge(self): @@ -341,7 +339,7 @@ def test_provider_bridge_token_ids(self, bridge, mock_pretrained): assert provider.image_token_id == 
248056 -@pytest.mark.skipif(not _TRANSFORMERS_HAS_QWEN3_5_MOE, reason="transformers does not have qwen3_5_moe support") +@pytest.mark.skipif(not _TRANSFORMERS_HAS_QWEN3_5, reason="transformers does not have qwen3_5_moe support") class TestQwen35VLMoEBridgeMappingRegistry: @pytest.fixture def bridge(self): @@ -381,7 +379,7 @@ def test_mapping_registry_has_vision_params(self, bridge): @pytest.mark.unit -@pytest.mark.skipif(not _TRANSFORMERS_HAS_QWEN3_5_MOE, reason="transformers does not have qwen3_5_moe support") +@pytest.mark.skipif(not _TRANSFORMERS_HAS_QWEN3_5, reason="transformers does not have qwen3_5_moe support") class TestQwen35VLMoEBridgeExport: def test_maybe_modify_converted_hf_weight_keeps_explicit_mtp_expert_keys(self, monkeypatch): """Preserve already-expanded MTP expert keys without extra regrouping.""" diff --git a/tests/unit_tests/models/test_transformers_version.py b/tests/unit_tests/models/test_transformers_version.py new file mode 100644 index 0000000000..cba936db64 --- /dev/null +++ b/tests/unit_tests/models/test_transformers_version.py @@ -0,0 +1,206 @@ +import importlib.machinery +from unittest.mock import Mock + +import pytest +from packaging.version import Version + +from megatron.bridge.models.conversion import model_bridge, transformers_version +from megatron.bridge.models.conversion.model_bridge import MegatronModelBridge, get_model_bridge +from megatron.bridge.models.conversion.transformers_version import ( + TransformersVersionError, + get_transformers_version, + is_transformers_min_version, + require_transformers_version, +) + + +pytestmark = pytest.mark.unit + + +class _TargetModel: + pass + + +def test_get_transformers_version_parses_pep440_local_version(monkeypatch): + monkeypatch.setattr(transformers_version.importlib.metadata, "version", lambda _: "5.3.0rc1+bridge.1") + + assert get_transformers_version() == Version("5.3.0rc1+bridge.1") + + +@pytest.mark.parametrize( + ("installed", "required", "expected"), + [ + ("5.3.0", "5.3.0", 
True), + ("5.3.0+vendor.1", "5.3.0", True), + ("5.3.0rc1", "5.3.0", False), + ("5.4.0.dev1", "5.3.0", True), + ], +) +def test_is_transformers_min_version_uses_pep440(monkeypatch, installed, required, expected): + monkeypatch.setattr(transformers_version, "get_transformers_version", lambda: Version(installed)) + + assert is_transformers_min_version(required) is expected + + +def test_is_transformers_min_version_rejects_invalid_requirement(monkeypatch): + monkeypatch.setattr(transformers_version, "get_transformers_version", lambda: Version("5.3.0")) + + with pytest.raises(ValueError, match="Invalid minimum Transformers version"): + is_transformers_min_version("not-a-version") + + +def test_require_transformers_version_accepts_dotted_symbol(monkeypatch): + monkeypatch.setattr(transformers_version, "get_transformers_version", lambda: Version("5.3.0")) + + require_transformers_version( + "test model", + "5.3.0", + symbols=("transformers.configuration_utils.PretrainedConfig",), + ) + + +def test_require_transformers_version_reports_missing_symbol(monkeypatch): + monkeypatch.setattr(transformers_version, "get_transformers_version", lambda: Version("5.3.0")) + + with pytest.raises(TransformersVersionError) as error_info: + require_transformers_version( + "FutureModel", + "5.2.0", + symbols=("transformers.FutureModelForCausalLM",), + action="construct the provider", + ) + + error = error_info.value + assert error.model_name == "FutureModel" + assert error.installed_version == Version("5.3.0") + assert error.required_version == Version("5.2.0") + assert error.missing_symbols == ("transformers.FutureModelForCausalLM",) + assert "construct the provider" in str(error) + assert "Install or upgrade" in str(error) + + +def test_symbol_check_does_not_hide_unrelated_import_failure(monkeypatch): + monkeypatch.setattr(transformers_version, "get_transformers_version", lambda: Version("5.3.0")) + + def fake_find_spec(name): + if name == 
"transformers.models.future.configuration_future": + return importlib.machinery.ModuleSpec(name, loader=None) + return None + + def fake_import_module(name): + assert name == "transformers.models.future.configuration_future" + raise ModuleNotFoundError("No module named 'optional_dependency'", name="optional_dependency") + + monkeypatch.setattr(transformers_version.importlib.util, "find_spec", fake_find_spec) + monkeypatch.setattr(transformers_version.importlib, "import_module", fake_import_module) + + with pytest.raises(ModuleNotFoundError, match="optional_dependency"): + require_transformers_version( + "FutureModel", + "5.3.0", + symbols=("transformers.models.future.configuration_future.FutureConfig",), + ) + + +def test_symbol_probe_does_not_hide_unrelated_parent_import_failure(monkeypatch): + monkeypatch.setattr(transformers_version, "get_transformers_version", lambda: Version("5.3.0")) + + def fake_find_spec(name): + raise ModuleNotFoundError("No module named 'optional_dependency'", name="optional_dependency") + + monkeypatch.setattr(transformers_version.importlib.util, "find_spec", fake_find_spec) + + with pytest.raises(ModuleNotFoundError, match="optional_dependency"): + require_transformers_version( + "FutureModel", + "5.3.0", + symbols=("transformers.models.future.configuration_future.FutureConfig",), + ) + + +def test_register_bridge_persists_metadata_for_string_source(monkeypatch): + require_mock = Mock() + monkeypatch.setattr(model_bridge, "require_transformers_version", require_mock) + + @MegatronModelBridge.register_bridge( + source="UnitTestFutureForCausalLM", + target=_TargetModel, + model_type="unit_test_future", + min_transformers_version="99.0.0", + required_transformers_symbols=("transformers.UnitTestFutureForCausalLM",), + ) + class _StringSourceBridge(MegatronModelBridge): + def mapping_registry(self): + raise NotImplementedError + + assert _StringSourceBridge.SOURCE_NAME == "UnitTestFutureForCausalLM" + assert _StringSourceBridge.MODEL_TYPE == 
"unit_test_future" + assert _StringSourceBridge.MIN_TRANSFORMERS_VERSION == "99.0.0" + assert _StringSourceBridge.REQUIRED_TRANSFORMERS_SYMBOLS == ("transformers.UnitTestFutureForCausalLM",) + require_mock.assert_not_called() + + +def test_register_bridge_persists_metadata_for_class_source(): + class _ClassSource: + pass + + @MegatronModelBridge.register_bridge(source=_ClassSource, target=_TargetModel) + class _ClassSourceBridge(MegatronModelBridge): + def mapping_registry(self): + raise NotImplementedError + + assert _ClassSourceBridge.SOURCE_NAME == "_ClassSource" + assert _ClassSourceBridge.MIN_TRANSFORMERS_VERSION is None + assert _ClassSourceBridge.REQUIRED_TRANSFORMERS_SYMBOLS == () + + +def test_guard_is_lazy_until_bridge_selection(monkeypatch): + compatibility_error = TransformersVersionError( + "UnitTestLazyForCausalLM", + Version("5.3.0"), + Version("99.0.0"), + ) + require_mock = Mock(side_effect=compatibility_error) + monkeypatch.setattr(model_bridge, "require_transformers_version", require_mock) + + @MegatronModelBridge.register_bridge( + source="UnitTestLazyForCausalLM", + target=_TargetModel, + model_type="unit_test_lazy", + min_transformers_version="99.0.0", + ) + class _LazyBridge(MegatronModelBridge): + initialized = False + + def __init__(self): + type(self).initialized = True + + def mapping_registry(self): + raise NotImplementedError + + require_mock.assert_not_called() + assert _LazyBridge.initialized is False + + with pytest.raises(TransformersVersionError): + get_model_bridge("UnitTestLazyForCausalLM") + + assert _LazyBridge.initialized is False + require_mock.assert_called_once_with( + "UnitTestLazyForCausalLM", + "99.0.0", + symbols=(), + action="use this model bridge", + ) + + +def test_required_symbols_need_minimum_version(): + with pytest.raises(ValueError, match="requires min_transformers_version"): + + @MegatronModelBridge.register_bridge( + source="UnitTestInvalidPolicyForCausalLM", + target=_TargetModel, + 
required_transformers_symbols=("transformers.UnitTestInvalidPolicyForCausalLM",), + ) + class _InvalidPolicyBridge(MegatronModelBridge): + def mapping_registry(self): + raise NotImplementedError From e329f03849404467ba9cbc22b3c5e017a2150a8b Mon Sep 17 00:00:00 2001 From: Chen Cui Date: Wed, 1 Jul 2026 13:35:45 -0700 Subject: [PATCH 2/3] docs(model): document transformers compatibility metadata Signed-off-by: Chen Cui --- skills/adding-model-support/SKILL.md | 45 ++++++++++++++++++++++++++-- 1 file changed, 43 insertions(+), 2 deletions(-) diff --git a/skills/adding-model-support/SKILL.md b/skills/adding-model-support/SKILL.md index fa7f49885e..fad9bec38c 100644 --- a/skills/adding-model-support/SKILL.md +++ b/skills/adding-model-support/SKILL.md @@ -1,7 +1,6 @@ --- name: adding-model-support -description: Guide for adding support for new LLM or VLM models in Megatron-Bridge. Covers bridge, provider, recipe, tests, docs, and examples. -when_to_use: User asks to add, onboard, or integrate a new model family; 'add Qwen4 support', 'onboard Llama 5', 'create a bridge for X', 'write a recipe for Y'. +description: Guide for adding, onboarding, or integrating new LLM or VLM model families in Megatron-Bridge, including bridges, providers, recipes, tests, docs, and examples. --- # Adding New Model Support in Megatron-Bridge @@ -82,6 +81,38 @@ Also add or update focused tests when touching export/import quantization paths; ## Phase 2: Bridge Support +### Declare Transformers compatibility + +For every new bridge, determine the earliest verified upstream Transformers release that provides +the required model APIs. Test that boundary and the preceding release; do not copy the current +project pin or guess a minimum from the model publication date. 
+ +Register architectures that may be absent at the package's minimum supported Transformers +version by string name rather than by class, and attach the verified compatibility contract: + +```python +@MegatronModelBridge.register_bridge( + source="NewModelForCausalLM", + target=GPTModel, + model_type="new_model", + min_transformers_version="5.8.0", + required_transformers_symbols=("transformers.NewModelConfig", "transformers.NewModelForCausalLM"), +) +class NewModelBridge(MegatronModelBridge): + ... +``` +- `min_transformers_version` is the earliest verified release that can support the bridge. Do not + leave it `None` for newly added model support. +- `required_transformers_symbols` lists exact dotted config, model, processor, or internal module + attributes the bridge uses. Add it for newly introduced or version-sensitive APIs so vendor, + development, and incomplete builds fail with an actionable Bridge error instead of a later import error. +- Required symbols need an explicit minimum. The version check runs first; the + symbols are resolved lazily only after that bridge is selected. +- Keep version-sensitive Transformers imports out of module scope. Import them + inside a guarded provider/load/export path after bridge selection. +- Do not add native Transformers symbols for `trust_remote_code` classes that + exist only in a model repository; validate those through the remote-code path. + ### File structure **LLM** — Reference: Qwen2 (`src/megatron/bridge/models/qwen/qwen2_bridge.py`) @@ -352,6 +383,16 @@ tests/functional_tests/test_groups/models/<model>/ └── test_<model>_provider.py # compare_provider_configs (optional) ``` +Add compatibility boundary coverage for every new registration: + +- At the declared minimum, package/registry import and bridge selection pass. +- On the immediately preceding release, package/registry import still passes, + while selecting the incompatible bridge raises `TransformersVersionError` + before provider construction or model loading.
+- Assert required-symbol failures separately from old-version failures. +- Run version cases in fresh subprocesses or immutable environments so + Transformers lazy modules and Bridge registration state cannot leak. + For detailed test patterns, see @skills/adding-model-support/tests-and-examples.md. ## Phase 5: Docs and Examples From ff8772ec0ac8c4b1273af68929b5b2535a49e6f9 Mon Sep 17 00:00:00 2001 From: Chen Cui Date: Wed, 1 Jul 2026 16:14:08 -0700 Subject: [PATCH 3/3] fix(model): guard DeepSeek V4 Transformers compatibility Signed-off-by: Chen Cui --- .../bridge/models/deepseek/deepseek_v4_bridge.py | 5 +++++ .../models/autobridge_registration_check.py | 16 +++++++++++++++- .../models/deepseek/test_deepseek_v4_bridge.py | 7 +++++++ .../test_qwen35_vl_default_conversion.py | 7 +++++-- .../megatron_mimo/test_qwen35_vl_provider.py | 7 +++++-- .../models/qwen_vl/test_qwen35_vl_provider.py | 7 ++++--- .../test_autobridge_registration_matrix.py | 5 ++++- 7 files changed, 45 insertions(+), 9 deletions(-) diff --git a/src/megatron/bridge/models/deepseek/deepseek_v4_bridge.py b/src/megatron/bridge/models/deepseek/deepseek_v4_bridge.py index ec73c0d38d..dccc3d4ec6 100644 --- a/src/megatron/bridge/models/deepseek/deepseek_v4_bridge.py +++ b/src/megatron/bridge/models/deepseek/deepseek_v4_bridge.py @@ -340,6 +340,11 @@ def __init__(self, megatron_param: str, hf_param: str) -> None: target=GPTModel, provider=MLAModelProvider, model_type="deepseek_v4", + min_transformers_version="5.8.0", + required_transformers_symbols=( + "transformers.DeepseekV4Config", + "transformers.DeepseekV4ForCausalLM", + ), ) class DeepSeekV4Bridge(MegatronModelBridge): """Megatron Bridge implementation for DeepSeek-V4 causal language models.""" diff --git a/tests/unit_tests/models/autobridge_registration_check.py b/tests/unit_tests/models/autobridge_registration_check.py index 06e9b10a7d..d9d4b930d2 100644 --- a/tests/unit_tests/models/autobridge_registration_check.py +++ 
b/tests/unit_tests/models/autobridge_registration_check.py @@ -20,10 +20,12 @@ import sys from typing import cast +from packaging.version import Version from transformers import PretrainedConfig from megatron.bridge import AutoBridge from megatron.bridge.models.conversion import model_bridge +from megatron.bridge.models.conversion.transformers_version import TransformersVersionError def main() -> None: @@ -46,7 +48,19 @@ def main() -> None: config.update({"auto_map": {"AutoModelForCausalLM": f"modeling_test.{architecture}"}}) assert AutoBridge.supports(config), f"AutoBridge rejected {architecture}" - bridge = AutoBridge.from_hf_config(config) + try: + bridge = AutoBridge.from_hf_config(config) + except TransformersVersionError as error: + bridge_class = model_bridge.get_registered_bridge_class(source_name=architecture) + assert bridge_class is not None, f"missing guarded bridge class index for {architecture}" + min_version = bridge_class.MIN_TRANSFORMERS_VERSION + assert min_version is not None, f"{architecture} raised without compatibility metadata" + assert error.model_name == architecture + assert error.required_version == Version(min_version) + if error.installed_version >= error.required_version: + assert error.missing_symbols == bridge_class.REQUIRED_TRANSFORMERS_SYMBOLS + continue + selected_bridge = bridge._model_bridge actual_bridge_class = f"{type(selected_bridge).__module__}.{type(selected_bridge).__name__}" diff --git a/tests/unit_tests/models/deepseek/test_deepseek_v4_bridge.py b/tests/unit_tests/models/deepseek/test_deepseek_v4_bridge.py index 529fe4ffb8..b0adf80796 100644 --- a/tests/unit_tests/models/deepseek/test_deepseek_v4_bridge.py +++ b/tests/unit_tests/models/deepseek/test_deepseek_v4_bridge.py @@ -97,6 +97,13 @@ def _deepseek_v4_hf_config(): class TestNativeDeepSeekV4ConfigTranslation: """Native Transformers DSv4 config fields must map back to MCore fields.""" + def test_bridge_declares_native_transformers_requirement(self): + assert 
DeepSeekV4Bridge.MIN_TRANSFORMERS_VERSION == "5.8.0" + assert DeepSeekV4Bridge.REQUIRED_TRANSFORMERS_SYMBOLS == ( + "transformers.DeepseekV4Config", + "transformers.DeepseekV4ForCausalLM", + ) + def test_compress_ratios_from_native_layer_types(self): hf_config = SimpleNamespace( num_hidden_layers=4, diff --git a/tests/unit_tests/models/megatron_mimo/conversion/test_qwen35_vl_default_conversion.py b/tests/unit_tests/models/megatron_mimo/conversion/test_qwen35_vl_default_conversion.py index 7d339bd9a4..6a88990712 100644 --- a/tests/unit_tests/models/megatron_mimo/conversion/test_qwen35_vl_default_conversion.py +++ b/tests/unit_tests/models/megatron_mimo/conversion/test_qwen35_vl_default_conversion.py @@ -19,6 +19,7 @@ import pytest from megatron.core.transformer.spec_utils import ModuleSpec +from megatron.bridge.models.conversion.transformers_version import is_transformers_min_version from megatron.bridge.models.megatron_mimo.conversion import ( MIMOComponent, get_mimo_conversion_spec, @@ -32,10 +33,12 @@ ) from megatron.bridge.models.megatron_mimo.megatron_mimo_provider import MegatronMIMOProvider from megatron.bridge.models.qwen_vl.qwen35_vl_bridge import Qwen35VLBridge -from megatron.bridge.models.qwen_vl.qwen35_vl_provider import _TRANSFORMERS_HAS_QWEN3_5, Qwen35VLModelProvider +from megatron.bridge.models.qwen_vl.qwen35_vl_provider import Qwen35VLModelProvider -pytestmark = pytest.mark.skipif(not _TRANSFORMERS_HAS_QWEN3_5, reason="transformers does not have qwen3_5 support") +pytestmark = pytest.mark.skipif( + not is_transformers_min_version("5.2.0"), reason="transformers does not have qwen3_5 support" +) def _make_parallelism_config() -> MegatronMIMOParallelismConfig: diff --git a/tests/unit_tests/models/megatron_mimo/test_qwen35_vl_provider.py b/tests/unit_tests/models/megatron_mimo/test_qwen35_vl_provider.py index 6e40605fe2..af0a4c0c11 100644 --- a/tests/unit_tests/models/megatron_mimo/test_qwen35_vl_provider.py +++ 
b/tests/unit_tests/models/megatron_mimo/test_qwen35_vl_provider.py @@ -6,6 +6,7 @@ from megatron.core.models.mimo.submodules.vision import VisionModalitySubmodules from megatron.core.transformer.spec_utils import ModuleSpec +from megatron.bridge.models.conversion.transformers_version import is_transformers_min_version from megatron.bridge.models.megatron_mimo.conversion.mimo_model_io import _clear_derived_spec_fields from megatron.bridge.models.megatron_mimo.megatron_mimo_config import ( MegatronMIMOParallelismConfig, @@ -14,13 +15,15 @@ from megatron.bridge.models.megatron_mimo.megatron_mimo_provider import MegatronMIMOProvider from megatron.bridge.models.qwen_vl.modelling_qwen3_vl.text_model import Qwen3VLGPTModel from megatron.bridge.models.qwen_vl.modelling_qwen3_vl.vision_model import Qwen3VLVisionModel -from megatron.bridge.models.qwen_vl.qwen35_vl_provider import _TRANSFORMERS_HAS_QWEN3_5, Qwen35VLModelProvider +from megatron.bridge.models.qwen_vl.qwen35_vl_provider import Qwen35VLModelProvider from megatron.bridge.training.config import ConfigContainer from megatron.bridge.utils.instantiate_utils import instantiate from megatron.bridge.utils.yaml_utils import safe_yaml_representers -pytestmark = pytest.mark.skipif(not _TRANSFORMERS_HAS_QWEN3_5, reason="transformers does not have qwen3_5 support") +pytestmark = pytest.mark.skipif( + not is_transformers_min_version("5.2.0"), reason="transformers does not have qwen3_5 support" +) def _make_language_provider(**overrides) -> Qwen35VLModelProvider: diff --git a/tests/unit_tests/models/qwen_vl/test_qwen35_vl_provider.py b/tests/unit_tests/models/qwen_vl/test_qwen35_vl_provider.py index de704b3434..5e1db473e0 100644 --- a/tests/unit_tests/models/qwen_vl/test_qwen35_vl_provider.py +++ b/tests/unit_tests/models/qwen_vl/test_qwen35_vl_provider.py @@ -19,10 +19,9 @@ from megatron.core.transformer.attention import SelfAttention from megatron.core.transformer.spec_utils import ModuleSpec +from 
megatron.bridge.models.conversion.transformers_version import is_transformers_min_version from megatron.bridge.models.gpt_provider import GPTModelProvider from megatron.bridge.models.qwen_vl.qwen35_vl_provider import ( - _TRANSFORMERS_HAS_QWEN3_5, - _TRANSFORMERS_HAS_QWEN3_5_MOE, Qwen3VLSelfAttention, Qwen35VLModelProvider, Qwen35VLMoEModelProvider, @@ -30,6 +29,8 @@ ) +_TRANSFORMERS_HAS_QWEN3_5 = is_transformers_min_version("5.2.0") + pytestmark = pytest.mark.skipif(not _TRANSFORMERS_HAS_QWEN3_5, reason="transformers does not have qwen3_5 support") @@ -233,7 +234,7 @@ def test_patch_standard_attention_specs_recurses_into_mtp_specs(self): assert mtp_model_layer.submodules.self_attention.module is Qwen3VLSelfAttention -@pytest.mark.skipif(not _TRANSFORMERS_HAS_QWEN3_5_MOE, reason="transformers does not have qwen3_5_moe support") +@pytest.mark.skipif(not _TRANSFORMERS_HAS_QWEN3_5, reason="transformers does not have qwen3_5_moe support") class TestQwen35VLMoEModelProvider: """Tests for the MoE Qwen3.5 VL model provider.""" diff --git a/tests/unit_tests/models/test_autobridge_registration_matrix.py b/tests/unit_tests/models/test_autobridge_registration_matrix.py index 1ee32171de..05f51da938 100644 --- a/tests/unit_tests/models/test_autobridge_registration_matrix.py +++ b/tests/unit_tests/models/test_autobridge_registration_matrix.py @@ -103,6 +103,7 @@ "Gemma4ForCausalLM", "Gemma4ForConditionalGeneration", "Glm4MoeLiteForCausalLM", + "GlmMoeDsaForCausalLM", "KimiK25ForConditionalGeneration", "KimiK2ForCausalLM", "MiMoForCausalLM", @@ -113,7 +114,9 @@ "NemotronH_Nano_VL_V2", "NemotronLabsDiffusionModel", "Qwen3ASRForConditionalGeneration", + "Qwen3_5ForCausalLM", "Qwen3_5ForConditionalGeneration", + "Qwen3_5MoeForCausalLM", "Qwen3_5MoeForConditionalGeneration", "SarvamMLAForCausalLM", "SarvamMoEForCausalLM", @@ -123,7 +126,7 @@ def test_public_autobridge_import_registers_every_supported_model() -> None: - """The public package import must install every expected bridge 
registration.""" + """The public import must register every bridge and lazily reject incompatible ones.""" result = subprocess.run( [ sys.executable,