Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions transformer_engine/common/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import sys
import sysconfig
from typing import Optional, Tuple
import warnings


@functools.lru_cache(maxsize=None)
Expand Down Expand Up @@ -191,6 +192,21 @@ def load_framework_extension(framework: str) -> None:
sys.modules[module_name] = solib
spec.loader.exec_module(solib)

# Plugin system: if NVTE_ENABLE_PLUGIN=1, let plugin stub take over
# transformer_engine_torch and register original pybind as _nv for CUDA backend.
# Only applies to the PyTorch extension — JAX has no plugin stub.
if os.environ.get("NVTE_ENABLE_PLUGIN", "0") == "1" and framework == "torch":
sys.modules[module_name + "_nv"] = solib
try:
from transformer_engine_plugin_fl import load_plugins

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we instead have a separate env variable with the name of the module rather than hardcoding FlagOS version of it? Or maybe we could have 1 env variable NVTE_PLUGIN which could be set to the name of the plugin module (and setting it to None or not setting it at all would disable plugin functionality)?

load_plugins()
except ImportError as e:
warnings.warn(
f"NVTE_ENABLE_PLUGIN=1 but plugin import failed: {e}",
ImportWarning,
stacklevel=2,
)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Plugin initialization errors beyond ImportError left uncaught

sys.modules[module_name + "_nv"] is registered unconditionally at line 199 before load_plugins() runs. If transformer_engine_plugin_fl is installed but load_plugins() itself raises a non-ImportError (e.g., RuntimeError during backend registration, an AttributeError inside the plugin, or an OSError loading a shared library), the exception escapes the try/except block, crashes load_framework_extension, and leaves sys.modules in a partially inconsistent state: transformer_engine_torch_nv exists (native pybind) but transformer_engine_torch was never replaced by the plugin stub. TE initialization fails with an opaque traceback instead of the intended graceful fallback message.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

Comment on lines +200 to +216

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Incomplete sys.modules rollback on load_plugins() failure

sys.modules[module_name] (i.e., transformer_engine_torch) is set to solib at line 192, before this try block. If load_plugins() partially succeeds — for example, replaces sys.modules["transformer_engine_torch"] with the plugin stub before raising a RuntimeError during backend registration — the except block pops _nv but leaves sys.modules["transformer_engine_torch"] pointing to the partially-initialized stub. TE then continues with a broken tex module even though the warning says plugin loading failed.

Capture the pre-attempt value of sys.modules.get(module_name) before calling load_plugins() and restore it in the except block alongside the _nv pop.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed



def sanity_checks_for_pypi_installation() -> None:
"""Ensure that package is installed correctly if using PyPI."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,16 @@
_dpa_fp8ds_reduce_amax = os.getenv("NVTE_DPA_FP8DS_REDUCE_AMAX", "1") == "1"


# Plugin system: override FlashAttention and get_attention_backend if enabled
if os.environ.get("NVTE_ENABLE_PLUGIN", "0") == "1":
_FlashAttentionNative = FlashAttention
FlashAttention = getattr(tex, "flash_attention", _FlashAttentionNative)
_plugin_get_attention_backend = getattr(tex, "get_attention_backend", None)
if _plugin_get_attention_backend is not None:
dpa_utils._original_get_attention_backend = dpa_utils.get_attention_backend
dpa_utils.get_attention_backend = _plugin_get_attention_backend


__all__ = ["DotProductAttention"]


Expand Down