diff --git a/jax-inference-offloading/examples/example-standalone.sh b/jax-inference-offloading/examples/example-standalone.sh new file mode 100755 index 000000000..7a2385198 --- /dev/null +++ b/jax-inference-offloading/examples/example-standalone.sh @@ -0,0 +1,273 @@ +#!/bin/bash +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Standalone example: weight transfer + rollout generation without Tunix. +# This script launches the gateway, vLLM rollout worker, and trainer_standalone.py. 
+ +set -euo pipefail + +DIR="$(dirname "$0")" +JAX_COMPILATION_CACHE_DIR=${JAX_COMPILATION_CACHE_DIR:-/tmp/jax-compilation-cache} +mkdir -p ${JAX_COMPILATION_CACHE_DIR} + +# Set default values +DEBUG="false" +OUTPUT_DIR=${OUTPUT_DIR:-$(mktemp -d)} + +# Model configuration +MODEL_NAME="" +MODEL_PATH="" +PARAM_MAPPING_PATH="" + +# Transfer mode +TRANSFER_MODE="" + +# vLLM runtime +VLLM_ENFORCE_EAGER="0" +VLLM_GPU_MEMORY_UTILIZATION="0.9" + +# Debug-only: use dummy weights for JAX model +USE_DUMMY_WEIGHT="0" + +# Debug-only: skip weight transfer to test if vLLM is working correctly +SKIP_TRANSFER="0" + +# Device assignment +N_GPUS_VLLM="4" +N_GPUS_JAX="4" + +# Gateway +GATEWAY_PORT="50051" + +# Parse command line arguments +while [[ $# -gt 0 ]]; do + case "$1" in + # General + --debug) + DEBUG="true" + shift + ;; + --output-dir=*) + OUTPUT_DIR="${1#*=}" + shift + ;; + # Model configuration + --model-name=*) + MODEL_NAME="${1#*=}" + shift + ;; + --model-path=*) + MODEL_PATH="${1#*=}" + shift + ;; + --param-mapping-path=*) + PARAM_MAPPING_PATH="${1#*=}" + shift + ;; + # Transfer mode + --transfer-mode=*) + TRANSFER_MODE="${1#*=}" + shift + ;; + # vLLM runtime + --vllm-enforce-eager) + VLLM_ENFORCE_EAGER="1" + shift + ;; + --no-vllm-enforce-eager) + VLLM_ENFORCE_EAGER="0" + shift + ;; + --vllm-gpu-memory-utilization=*) + VLLM_GPU_MEMORY_UTILIZATION="${1#*=}" + shift + ;; + --use-dummy-weight) + USE_DUMMY_WEIGHT="1" + shift + ;; + --skip-transfer) + SKIP_TRANSFER="1" + shift + ;; + # Device assignment + --n-gpus-vllm=*) + N_GPUS_VLLM="${1#*=}" + shift + ;; + --n-gpus-jax=*) + N_GPUS_JAX="${1#*=}" + shift + ;; + # Gateway + --gateway-port=*) + GATEWAY_PORT="${1#*=}" + shift + ;; + --help) + echo "Usage: $0 [OPTIONS]" + echo "" + echo "Standalone example: weight transfer + rollout generation." + echo "" + echo "This example uses Tunix for model loading, but trainer_standalone.py" + echo "demonstrates how to integrate with custom RL frameworks. 
See the comments" + echo "in trainer_standalone.py for guidance on replacing the model loading code." + echo "" + echo "Options:" + echo " --debug Enable debug mode with verbose logging." + echo " --output-dir=DIR Directory to save logs and outputs. Default is a temporary directory." + echo "" + echo " --model-name=NAME HF model name (required by Tunix's loader for architecture selection)." + echo " --model-path=PATH HF snapshot directory containing model weights." + echo " --param-mapping-path=PATH Path to JSON param mapping file (optional, uses hardcoded if not set)." + echo "" + echo " --transfer-mode=MODE Transfer mode for trainer->vLLM weights (grouped/stacked/fused/unfused)." + echo "" + echo " --vllm-enforce-eager Force vLLM eager mode (sets VLLM_ENFORCE_EAGER=1)." + echo " --no-vllm-enforce-eager Disable vLLM eager mode (sets VLLM_ENFORCE_EAGER=0)." + echo " --vllm-gpu-memory-utilization=FLOAT vLLM GPU memory utilization (e.g., 0.7)." + echo " --use-dummy-weight Use randomly initialized JAX weights (DEBUG ONLY)." + echo "" + echo " --n-gpus-vllm=N Number of GPUs for vLLM (default: 4)." + echo " --n-gpus-jax=N Number of GPUs for JAX (default: 4)." + echo "" + echo " --gateway-port=PORT gRPC gateway port (default: 50051)." + echo " --help Show this help message and exit." + exit 0 + ;; + *) + echo "Unknown argument: $1" + shift + ;; + esac +done + +# Model selection default +MODEL_NAME=${MODEL_NAME:-"meta-llama/Llama-3.2-1B-Instruct"} + +# ------------------------------------------------------------------------------ +# Kill all processes when done. 
+# ------------------------------------------------------------------------------ +trap "trap - SIGTERM && kill -- -$$" SIGINT SIGTERM EXIT + +# ------------------------------------------------------------------------------ +# load environment variables from .env file +# ------------------------------------------------------------------------------ +if [[ -f "${PWD}/.env" ]]; then + echo "Loading ${PWD}/.env" + set -a && source "${PWD}/.env" && set +a +else + echo ".env not found in ${PWD}, skipping" +fi + +# ------------------------------------------------------------------------------ +# Ensure model is already present on disk (download only when using real weights) +# ------------------------------------------------------------------------------ + +if [[ -z "${HF_TOKEN:-}" ]]; then + echo "HF_TOKEN is not set. Please set it in the .env file or export it." +fi + +if [[ "${USE_DUMMY_WEIGHT}" == "1" ]]; then + echo "Using dummy weights for JAX model (DEBUG ONLY)." + MODEL_PATH= +else + if [[ -n "${MODEL_PATH:-}" ]]; then + echo "Using provided MODEL_PATH: ${MODEL_PATH}" + else + echo "MODEL_PATH not provided, downloading HF snapshot..." 
+ MODEL_PATH=$(python "${DIR}/download_model.py" --hub=hf --model="${MODEL_NAME}" --ignore="*.pth") + fi +fi + +# ------------------------------------------------------------------------------ +# assign GPUs to vLLM and JAX +# ------------------------------------------------------------------------------ +N_GPUS=$((N_GPUS_VLLM + N_GPUS_JAX)) + +# Derive CUDA_VISIBLE_DEVICES_ARRAY +if [[ -z "${CUDA_VISIBLE_DEVICES:-}" ]]; then + CUDA_VISIBLE_DEVICES_ARRAY=($(seq 0 $((N_GPUS - 1)))) +else + IFS=',' read -r -a CUDA_VISIBLE_DEVICES_ARRAY <<< "$CUDA_VISIBLE_DEVICES" +fi + +VLLM_GPU_ARRAY=("${CUDA_VISIBLE_DEVICES_ARRAY[@]:0:N_GPUS_VLLM}") +JAX_GPU_ARRAY=("${CUDA_VISIBLE_DEVICES_ARRAY[@]:N_GPUS_VLLM:N_GPUS_JAX}") # bash slice is offset:length, not offset:end + +# ------------------------------------------------------------------------------ +# common environment +# ------------------------------------------------------------------------------ +export CUDA_DEVICE_ORDER=PCI_BUS_ID +export CUDA_DEVICE_MAX_CONNECTIONS=16 +export NCCL_BUFFSIZE=16777216 +export GATEWAY_PORT +export GATEWAY_URL="localhost:${GATEWAY_PORT}" +export MODEL_NAME +export MODEL_PATH +export PARAM_MAPPING_PATH +export USE_DUMMY_WEIGHT +export SKIP_TRANSFER +export VLLM_ENFORCE_EAGER +export VLLM_GPU_MEMORY_UTILIZATION +export XLA_PYTHON_CLIENT_MEM_FRACTION=0.95 +export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true + --xla_gpu_enable_command_buffer=FUSION,CUBLAS,CUDNN,CUSTOM_CALL + --xla_gpu_collective_permute_combine_threshold_bytes=8589934592 + --xla_gpu_reduce_scatter_combine_threshold_bytes=8589934592 + --xla_gpu_all_gather_combine_threshold_bytes=8589934592 + --xla_gpu_all_reduce_combine_threshold_bytes=8589934592" +if [[ -n "${TRANSFER_MODE:-}" ]]; then + export TRANSFER_MODE +fi + +if [ "$DEBUG" == "true" ]; then + set -x + export TF_CPP_MIN_LOG_LEVEL=0 + export NCCL_DEBUG=INFO # Enable NCCL debug logs +else + export TF_CPP_MIN_LOG_LEVEL=2 # Suppress TensorFlow debug logs + export VLLM_CONFIGURE_LOGGING=0 # Suppress vLLM logging +fi + 
+PIDS=() + +mkdir -p "${OUTPUT_DIR}" +echo "Logs will be saved to: ${OUTPUT_DIR}" + +# ------------------------------------------------------------------------------ +# Launch components +# ------------------------------------------------------------------------------ + +# Gateway server (no GPU) +CUDA_VISIBLE_DEVICES= \ +python "${DIR}/../jax_inference_offloading/controller/gateway.py" 2>&1 | tee "${OUTPUT_DIR}/gateway.log" & +PIDS+=($!) + +# vLLM rollout worker +CUDA_VISIBLE_DEVICES=$(IFS=','; echo "${VLLM_GPU_ARRAY[*]}") \ +MODEL_NAME=${MODEL_PATH:-$MODEL_NAME} \ +python "${DIR}/rollout.py" 2>&1 | tee "${OUTPUT_DIR}/rollout.log" & +PIDS+=($!) + +# Standalone trainer (weight transfer + generation demo) +CUDA_VISIBLE_DEVICES=$(IFS=','; echo "${JAX_GPU_ARRAY[*]}") \ +JAX_COMPILATION_CACHE_DIR=${JAX_COMPILATION_CACHE_DIR} \ +JAX_PERSISTENT_CACHE_MIN_COMPILE_TIME_SECS=0.1 \ +python "${DIR}/trainer_standalone.py" 2>&1 | tee "${OUTPUT_DIR}/trainer.log" & +PIDS+=($!) + +wait "${PIDS[@]}" diff --git a/jax-inference-offloading/examples/mappings/llama3_1b_param_mapping.json b/jax-inference-offloading/examples/mappings/llama3_1b_param_mapping.json new file mode 100644 index 000000000..fd902f1e3 --- /dev/null +++ b/jax-inference-offloading/examples/mappings/llama3_1b_param_mapping.json @@ -0,0 +1,157 @@ +{ + "mesh_axes": ["fsdp", "tp"], + "num_layers": 16, + "mappings": [ + { + "jax_param": { + "name": "embedder.input_embedding" + }, + "vllm_param": { + "name": "model.embed_tokens.weight", + "shape": [128256, 2048] + } + }, + { + "jax_param": { + "name": "final_norm.w" + }, + "vllm_param": { + "name": "model.norm.weight", + "shape": [2048] + } + }, + { + "jax_param": { + "name": "layers.{layer}.input_layernorm.w" + }, + "vllm_param": { + "name": "model.layers.{layer}.input_layernorm.weight", + "shape": [2048] + } + }, + { + "jax_param": { + "name": "layers.{layer}.post_attention_layernorm.w" + }, + "vllm_param": { + "name": 
"model.layers.{layer}.post_attention_layernorm.weight", + "shape": [2048] + } + }, + { + "jax_param": { + "name": "layers.{layer}.mlp.gate_proj.kernel", + "transform": { + "transpose": [1, 0] + } + }, + "vllm_param": { + "name": "model.layers.{layer}.mlp.gate_proj.weight", + "shape": [8192, 2048], + "transform": { + "transpose": [1, 0] + } + } + }, + { + "jax_param": { + "name": "layers.{layer}.mlp.up_proj.kernel", + "transform": { + "transpose": [1, 0] + } + }, + "vllm_param": { + "name": "model.layers.{layer}.mlp.up_proj.weight", + "shape": [8192, 2048], + "transform": { + "transpose": [1, 0] + } + } + }, + { + "jax_param": { + "name": "layers.{layer}.mlp.down_proj.kernel", + "transform": { + "transpose": [1, 0] + } + }, + "vllm_param": { + "name": "model.layers.{layer}.mlp.down_proj.weight", + "shape": [2048, 8192], + "transform": { + "transpose": [1, 0] + } + } + }, + { + "jax_param": { + "name": "layers.{layer}.attn.q_proj.w", + "transform": { + "transpose": [1, 2, 0], + "reshape": [-1, 2048] + } + }, + "vllm_param": { + "name": "model.layers.{layer}.self_attn.q_proj.weight", + "shape": [2048, 2048], + "transform": { + "transpose": [1, 0], + "reshape": [2048, 32, 64] + } + } + }, + { + "jax_param": { + "name": "layers.{layer}.attn.k_proj.w", + "transform": { + "transpose": [1, 2, 0], + "reshape": [-1, 2048], + "replication_axis": 2 + } + }, + "vllm_param": { + "name": "model.layers.{layer}.self_attn.k_proj.weight", + "shape": [512, 2048], + "transform": { + "transpose": [1, 0], + "reshape": [2048, 8, 64] + } + } + }, + { + "jax_param": { + "name": "layers.{layer}.attn.v_proj.w", + "transform": { + "transpose": [1, 2, 0], + "reshape": [-1, 2048], + "replication_axis": 2 + } + }, + "vllm_param": { + "name": "model.layers.{layer}.self_attn.v_proj.weight", + "shape": [512, 2048], + "transform": { + "transpose": [1, 0], + "reshape": [2048, 8, 64] + } + } + }, + { + "jax_param": { + "name": "layers.{layer}.attn.o_proj.w", + "transform": { + "transpose": [2, 0, 1], + 
"reshape": [2048, -1] + } + }, + "vllm_param": { + "name": "model.layers.{layer}.self_attn.o_proj.weight", + "shape": [2048, 2048], + "transform": { + "transpose": [1, 0], + "reshape": [32, 64, 2048] + } + } + } + ] +} diff --git a/jax-inference-offloading/examples/rollout.py b/jax-inference-offloading/examples/rollout.py index 61e91df50..00b069cf4 100644 --- a/jax-inference-offloading/examples/rollout.py +++ b/jax-inference-offloading/examples/rollout.py @@ -40,6 +40,8 @@ def main(): model_name = os.environ.get("MODEL_NAME", "meta-llama/Llama-3.1-8B-Instruct") model_path = os.environ.get("MODEL_PATH", None) model = model_path or model_name + # Optional: path to custom param_mapping.json for JAX-to-vLLM parameter mapping + param_mapping_path = os.environ.get("PARAM_MAPPING_PATH", None) logging.basicConfig(level=logging.INFO) @@ -58,7 +60,7 @@ def main(): # subscribe to control messages from the gateway rollout_client = make_rollout_client(gateway_url) - rollout_client.subscribe_to_control_messages(llm) + rollout_client.subscribe_to_control_messages(llm, mapping_json_path=param_mapping_path) if __name__ == "__main__": diff --git a/jax-inference-offloading/examples/trainer.py b/jax-inference-offloading/examples/trainer.py index 63565e7bd..0687bfff6 100644 --- a/jax-inference-offloading/examples/trainer.py +++ b/jax-inference-offloading/examples/trainer.py @@ -29,7 +29,7 @@ from jax_inference_offloading.jax import OffloadingBridge from jax_inference_offloading.sharding import PolymorphicMesh from jax_inference_offloading.timer import Timer -from jax_inference_offloading.tunix.load_model import load_model +from jax_inference_offloading.integrations.tunix.load_model import load_model from jax_inference_offloading.models import get_named_parameters # logging.basicConfig(level=logging.INFO) diff --git a/jax-inference-offloading/examples/trainer_grpo.py b/jax-inference-offloading/examples/trainer_grpo.py index c7dca3811..ff36eb66b 100644 --- 
a/jax-inference-offloading/examples/trainer_grpo.py +++ b/jax-inference-offloading/examples/trainer_grpo.py @@ -35,8 +35,8 @@ from tunix.rl.rollout import base_rollout from jax_inference_offloading.timer import Timer -from jax_inference_offloading.tunix.load_model import load_model -from jax_inference_offloading.tunix.rollout import VllmGPURollout +from jax_inference_offloading.integrations.tunix.load_model import load_model +from jax_inference_offloading.integrations.tunix.rollout import VllmGPURollout logger = logging.getLogger(__name__) timer = Timer() diff --git a/jax-inference-offloading/examples/trainer_standalone.py b/jax-inference-offloading/examples/trainer_standalone.py new file mode 100644 index 000000000..9def7523b --- /dev/null +++ b/jax-inference-offloading/examples/trainer_standalone.py @@ -0,0 +1,224 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Standalone example: Using jax-inference-offloading without Tunix. + +This example demonstrates how to use OffloadingSession, VLLMTransferEngine, +and VLLMRolloutEngine directly without depending on any external framework. +This is useful for: +- Custom RL training loops +- Integration with other frameworks (OpenRLHF, TRL, custom, etc.) 
+- Testing and benchmarking + +The architecture separates concerns: +- MappingConfig: Loads mesh_axes and parameter mappings from JSON +- load_checkpoint_to_jax: Loads HuggingFace checkpoints into JAX arrays +- OffloadingSession: Handles gRPC connection and handshake +- VLLMTransferEngine: Handles weight transfer from JAX to vLLM +- VLLMRolloutEngine: Handles inference/rollout generation + +Prerequisites: +- Gateway server running (python -m jax_inference_offloading.controller.gateway) +- vLLM rollout worker running (python examples/rollout.py) + +Environment variables: +- GATEWAY_URL: URL of the gateway server (e.g., "localhost:50051") +- MODEL_PATH: Path to HuggingFace model checkpoint (required) +- PARAM_MAPPING_PATH: Path to JSON parameter mapping file (required) +""" + +import os + +import jax +import jax.numpy as jnp + +# Framework-agnostic imports from jax-inference-offloading +from jax_inference_offloading import ( + InferenceConfig, + OffloadingSession, + VLLMRolloutEngine, + VLLMTransferEngine, +) +from jax_inference_offloading.timer import Timer +from jax_inference_offloading.models.checkpoint import ( + load_mapping_config, + load_checkpoint_to_jax, +) + +from transformers import AutoTokenizer + +timer = Timer() + +# --- Configuration --- +model_path = os.environ.get("MODEL_PATH", None) +param_mapping_path = os.environ.get("PARAM_MAPPING_PATH", None) +gateway_url = os.environ.get("GATEWAY_URL", "localhost:50051") +transfer_mode = os.environ.get("TRANSFER_MODE", "grouped") + +# Validate: both are required; a variable exported as an empty string counts as missing +if not model_path: + raise ValueError("MODEL_PATH environment variable is required") +if not param_mapping_path: + raise ValueError("PARAM_MAPPING_PATH environment variable is required") + +# Load tokenizer for pad_id +tokenizer = AutoTokenizer.from_pretrained(model_path) +if tokenizer.pad_token_id is None: + tokenizer.pad_token_id = tokenizer.eos_token_id + +# --- Load Mapping Config --- +# This reads the JSON 
mapping file to get mesh_axes and parameter mappings +if jax.process_index() == 0: + print(f"Loading mapping config from {param_mapping_path}") + +with timer.section("load_mapping_config"): + mapping_config = load_mapping_config(param_mapping_path) + +# Create mesh using axis names from the mapping config +mesh_shape = (jax.process_count(), jax.local_device_count()) +mesh_axes = tuple(mapping_config.mesh_axes) +if len(mesh_shape) != len(mesh_axes): + raise ValueError( + f"Mesh shape {mesh_shape} does not match mesh_axes {mesh_axes}. " + f"Expected {len(mesh_shape)} axes, got {len(mesh_axes)}." + ) +mesh = jax.make_mesh(mesh_shape, mesh_axes) + +if jax.process_index() == 0: + print(f"Created mesh with shape {mesh_shape} and axes {mesh_axes}") + +# --- Load Checkpoint --- +# Load HuggingFace checkpoint directly into JAX arrays using the mapping +if jax.process_index() == 0: + print(f"Loading checkpoint from {model_path}") + +with timer.section("load_checkpoint"): + params = load_checkpoint_to_jax( + checkpoint_path=model_path, + mapping_specs=mapping_config.mapping_specs, + mesh=mesh, + dtype=jnp.bfloat16, + ) + +if jax.process_index() == 0: + print(f"Loaded {len(params)} parameters") + +# --- Create OffloadingSession and Engines --- +if jax.process_index() == 0: + print(f"Creating OffloadingSession with gateway_url={gateway_url}") + +with timer.section("create_session"): + session = OffloadingSession( + gateway_url=gateway_url, + mesh=mesh, + model_path=model_path, + param_mapping_path=param_mapping_path, + ) + +if jax.process_index() == 0: + print("Creating VLLMTransferEngine and VLLMRolloutEngine...") + +with timer.section("create_engines"): + transfer_engine = VLLMTransferEngine( + session=session, + transfer_mode=transfer_mode, + timer=timer, + ) + rollout_engine = VLLMRolloutEngine( + session=session, + timer=timer, + ) + +# --- Transfer Weights --- +SKIP_TRANSFER = os.environ.get("SKIP_TRANSFER", "0") == "1" + +if SKIP_TRANSFER: + if jax.process_index() == 0: 
+ print("SKIPPING weight transfer (SKIP_TRANSFER=1)") +else: + if jax.process_index() == 0: + print("Transferring weights to vLLM...") + + with timer.section("warmup_transfer"): + transfer_engine.update_weights(params) + + if jax.process_index() == 0: + print("Weights transferred successfully!") + + # --- Benchmark weight transfer --- + for r in range(3): + with timer.section(f"transfer.run{r}"): + transfer_engine.update_weights(params) + +# --- Generate Completions --- +if jax.process_index() == 0: + print("\n" + "=" * 80) + print("Generating completions...") + print("=" * 80) + + # Example 1: Simple text prompt + config = InferenceConfig( + max_tokens=256, + temperature=0.7, + top_p=0.95, + ) + output = rollout_engine.generate(["Messi's barcelona career:"], config) + + print("\n--- Text Prompt ---") + print(f"Prompt: Messi's barcelona career:") + print(f"Response: {output.texts[0]}") + + # Example 2: Multiple prompts with multiple outputs per prompt + config_multi = InferenceConfig( + max_tokens=100, + temperature=0.9, + top_p=0.95, + n=2, # Generate 2 completions per prompt + ) + prompts = [ + "Messi's barcelona career:", + "Name a color:", + ] + output = rollout_engine.generate(prompts, config_multi) + + print("\n--- Multiple Prompts (n=2) ---") + for i, completion in enumerate(output.completions): + print(f"\nCompletion {i + 1}:") + print(f" Text: {completion.text[:100]}...") + print(f" Token count: {len(completion.token_ids)}") + + # Example 3: Using to_arrays() for training + arrays = output.to_arrays( + max_prompt_length=64, + max_completion_length=100, + pad_id=tokenizer.pad_token_id, + ) + print("\n--- Arrays for Training ---") + print(f" prompt_tokens shape: {arrays['prompt_tokens'].shape}") + print(f" completion_tokens shape: {arrays['completion_tokens'].shape}") + +# --- Print timing summary --- +if jax.process_index() == 0: + print("\n" + "=" * 80) + print("Timing Summary") + print("=" * 80) + timer.summary(sort_by="name", precision=3) + +# --- 
Shutdown --- +rollout_engine.shutdown() +session.shutdown() +if jax.process_index() == 0: + print("\nShutdown complete. Exiting.") diff --git a/jax-inference-offloading/jax_inference_offloading/__init__.py b/jax-inference-offloading/jax_inference_offloading/__init__.py new file mode 100644 index 000000000..8e83bbccc --- /dev/null +++ b/jax-inference-offloading/jax_inference_offloading/__init__.py @@ -0,0 +1,74 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""JAX-vLLM Inference Offloading Bridge. + +This package provides infrastructure for offloading inference/rollout +generation from JAX training to vLLM, enabling efficient RL post-training. + +Quick Start: + >>> from jax_inference_offloading import ( + ... OffloadingSession, + ... VLLMTransferEngine, + ... VLLMRolloutEngine, + ... InferenceConfig, + ... ) + >>> + >>> # Create session (handles gRPC connection and handshake) + >>> session = OffloadingSession( + ... gateway_url="localhost:50051", + ... mesh=jax.make_mesh((8,), ("tp",)), + ... model_name="meta-llama/Llama-3.1-8B-Instruct", + ... 
) + >>> + >>> # Create separate engines for transfer and inference + >>> transfer_engine = VLLMTransferEngine(session) + >>> rollout_engine = VLLMRolloutEngine(session) + >>> + >>> # Transfer weights and generate + >>> transfer_engine.update_weights(my_params) + >>> output = rollout_engine.generate(prompts, InferenceConfig(max_tokens=128)) +""" + +# Core API types +from jax_inference_offloading.api import ( + CompletionOutput, + InferenceConfig, + InferenceOutput, +) + +# Session +from jax_inference_offloading.session import OffloadingSession + +# Engine implementations +from jax_inference_offloading.engines import VLLMRolloutEngine, VLLMTransferEngine + +# Low-level access (for advanced users / backward compatibility) +from jax_inference_offloading.jax import OffloadingBridge + +__all__ = [ + # Core API types + "CompletionOutput", + "InferenceConfig", + "InferenceOutput", + # Session + "OffloadingSession", + # Engines + "VLLMRolloutEngine", + "VLLMTransferEngine", + # Advanced / Legacy + "OffloadingBridge", +] diff --git a/jax-inference-offloading/jax_inference_offloading/api/__init__.py b/jax-inference-offloading/jax_inference_offloading/api/__init__.py new file mode 100644 index 000000000..4a8387cab --- /dev/null +++ b/jax-inference-offloading/jax_inference_offloading/api/__init__.py @@ -0,0 +1,33 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +"""Public API types for jax-inference-offloading.""" + +from jax_inference_offloading.api.types import ( + CompletionOutput, + InferenceConfig, + InferenceOutput, + pad_left, + pad_right, +) + +__all__ = [ + "CompletionOutput", + "InferenceConfig", + "InferenceOutput", + "pad_left", + "pad_right", +] diff --git a/jax-inference-offloading/jax_inference_offloading/api/controller.proto b/jax-inference-offloading/jax_inference_offloading/api/controller.proto index 6fa1b3f2d..eed119188 100644 --- a/jax-inference-offloading/jax_inference_offloading/api/controller.proto +++ b/jax-inference-offloading/jax_inference_offloading/api/controller.proto @@ -60,6 +60,7 @@ message HandshakeRequest { reserved 2; optional JaxParallelism jax_parallelism = 3; optional string model_name = 4; + optional string param_mapping_path = 5; } message HandshakeResponse { optional TpModelMappingSpecs mapping_specs = 1; diff --git a/jax-inference-offloading/jax_inference_offloading/api/param_mapping.proto b/jax-inference-offloading/jax_inference_offloading/api/param_mapping.proto index d4a70a589..577bb9827 100644 --- a/jax-inference-offloading/jax_inference_offloading/api/param_mapping.proto +++ b/jax-inference-offloading/jax_inference_offloading/api/param_mapping.proto @@ -29,6 +29,16 @@ message VllmParam { optional TpSharding tp_sharding = 3; optional string dtype = 4; + + // Transformation from the vLLM/HuggingFace parameter to the JAX parameter. + // Used when loading checkpoints into JAX model state. 
+ message Transform { + optional TensorSlice slice = 1; + repeated int32 transpose = 2; + repeated int32 reshape = 3; + } + + optional Transform transform = 5; } message TensorSlice { @@ -72,7 +82,9 @@ message JaxParam { repeated string axes = 1; } - optional PartitionSpecs partition_specs = 3; // TODO: maybe not needed + // Partition spec for sharding the JAX parameter when loading from checkpoint. + // Each element is an axis name from mesh_axes, or empty string for unsharded dimensions. + optional PartitionSpecs partition_specs = 3; } message ParamMapping { @@ -82,4 +94,6 @@ message ParamMapping { message TpModelMappingSpecs { repeated ParamMapping mappings = 1; + // Axis names for JAX mesh creation, e.g., ["fsdp", "tp"] + repeated string mesh_axes = 2; } \ No newline at end of file diff --git a/jax-inference-offloading/jax_inference_offloading/api/types.py b/jax-inference-offloading/jax_inference_offloading/api/types.py new file mode 100644 index 000000000..54412215a --- /dev/null +++ b/jax-inference-offloading/jax_inference_offloading/api/types.py @@ -0,0 +1,174 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
"""Framework-agnostic types for inference offloading."""

from dataclasses import dataclass, field
from typing import Dict, List, Optional

import numpy as np


def pad_left(seq: list, length: int, pad_value) -> list:
    """Left-pad a sequence to the specified length.

    Args:
        seq: Sequence to pad.
        length: Target length.
        pad_value: Value to use for padding.

    Returns:
        A new list of exactly ``length`` items: padding first, then ``seq``.

    Raises:
        AssertionError: If sequence is longer than target length.
    """
    # Raised explicitly (not via `assert`) so the check survives `python -O`,
    # while keeping the exception type callers already expect.
    if len(seq) > length:
        raise AssertionError(f"Sequence too long: {len(seq)} > {length}")
    return [pad_value] * (length - len(seq)) + list(seq)


def pad_right(seq: list, length: int, pad_value) -> list:
    """Right-pad a sequence to the specified length.

    Args:
        seq: Sequence to pad.
        length: Target length.
        pad_value: Value to use for padding.

    Returns:
        A new list of exactly ``length`` items: ``seq`` first, then padding.

    Raises:
        AssertionError: If sequence is longer than target length.
    """
    # Raised explicitly (not via `assert`) so the check survives `python -O`.
    if len(seq) > length:
        raise AssertionError(f"Sequence too long: {len(seq)} > {length}")
    return list(seq) + [pad_value] * (length - len(seq))


def _stack_rows(rows: list, width: int, dtype) -> np.ndarray:
    """Stack equally sized rows into a ``[len(rows), width]`` array.

    The explicit reshape is a no-op for a non-empty batch; for an empty batch
    it keeps the array 2-D, where ``np.array([])`` alone would have shape (0,).
    """
    return np.array(rows, dtype=dtype).reshape(len(rows), width)


@dataclass(frozen=True)
class InferenceConfig:
    """Framework-agnostic inference configuration.

    Maps to vLLM SamplingParams.

    Attributes:
        max_tokens: Maximum number of tokens to generate per output sequence.
        temperature: Temperature for sampling. 0.0 = greedy, higher = more random.
        top_p: Top-p (nucleus) sampling. 1.0 = no filtering.
        top_k: Top-k sampling. -1 = no filtering.
        n: Number of output sequences per prompt (for best-of-n, GRPO groups, etc.).
        seed: Random seed for reproducibility.
        stop_token_ids: Stop token IDs (e.g., EOS tokens).
    """

    max_tokens: int = 64
    temperature: float = 0.9
    top_p: float = 1.0
    top_k: int = -1
    n: int = 1
    seed: Optional[int] = None
    stop_token_ids: List[int] = field(default_factory=list)


@dataclass
class CompletionOutput:
    """Single completion/output from the model.

    Attributes:
        text: Generated text.
        token_ids: Generated token IDs.
        logprobs: Log probabilities per generated token (optional).
        prompt_token_ids: Prompt token IDs (useful for log-prob calculations).
    """

    text: str
    token_ids: List[int]
    logprobs: Optional[List[float]] = None
    prompt_token_ids: Optional[List[int]] = None


@dataclass
class InferenceOutput:
    """Output from inference/rollout generation.

    Contains one or more CompletionOutput per prompt (based on config.n).

    Attributes:
        completions: List of completions, flattened across all prompts.
            Length = num_prompts * config.n
    """

    completions: List[CompletionOutput]

    @property
    def texts(self) -> List[str]:
        """Get all generated texts."""
        return [c.text for c in self.completions]

    @property
    def token_ids(self) -> List[List[int]]:
        """Get all generated token ID sequences."""
        return [c.token_ids for c in self.completions]

    def to_arrays(
        self,
        max_prompt_length: int,
        max_completion_length: int,
        pad_id: int,
    ) -> Dict[str, np.ndarray]:
        """Convert to padded numpy arrays for training.

        Args:
            max_prompt_length: Maximum prompt length (for left-padding).
            max_completion_length: Maximum completion length (for right-padding).
            pad_id: Padding token ID.

        Returns:
            dict with keys:
                - 'prompt_tokens': int32 [batch, max_prompt_length], left-padded.
                  A completion without prompt_token_ids becomes an all-pad row.
                - 'completion_tokens': int32 [batch, max_completion_length],
                  right-padded.
                - 'completion_logprobs': float32 [batch, max_completion_length],
                  right-padded with 0.0; present only when every completion
                  carries logprobs.
            Arrays stay 2-D even when there are no completions (batch == 0).

        Raises:
            AssertionError: If any sequence exceeds the specified max length.
        """
        result: Dict[str, np.ndarray] = {
            "prompt_tokens": _stack_rows(
                [
                    pad_left(c.prompt_token_ids or [], max_prompt_length, pad_id)
                    for c in self.completions
                ],
                max_prompt_length,
                np.int32,
            ),
            "completion_tokens": _stack_rows(
                [
                    pad_right(c.token_ids, max_completion_length, pad_id)
                    for c in self.completions
                ],
                max_completion_length,
                np.int32,
            ),
        }

        # Logprobs are all-or-nothing: a partially populated batch cannot be
        # represented as a dense array, so the key is omitted in that case.
        if all(c.logprobs is not None for c in self.completions):
            result["completion_logprobs"] = _stack_rows(
                [
                    pad_right(c.logprobs or [], max_completion_length, 0.0)
                    for c in self.completions
                ],
                max_completion_length,
                np.float32,
            )

        return result
mapping_specs @@ -160,10 +164,10 @@ def __init__(self, executor, controller_stub, broker_stub, channel=None): super().__init__(executor, controller_stub, broker_stub, channel) self._update_future = None - def subscribe_to_control_messages(self, llm): + def subscribe_to_control_messages(self, llm, mapping_json_path=None): assert self._update_future is None - servicer = RolloutServicer(llm) + servicer = RolloutServicer(llm, mapping_json_path=mapping_json_path) def call(): try: diff --git a/jax-inference-offloading/jax_inference_offloading/engines/__init__.py b/jax-inference-offloading/jax_inference_offloading/engines/__init__.py new file mode 100644 index 000000000..a55867829 --- /dev/null +++ b/jax-inference-offloading/jax_inference_offloading/engines/__init__.py @@ -0,0 +1,22 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Inference engine implementations.""" + +from jax_inference_offloading.engines.vllm_rollout_engine import VLLMRolloutEngine +from jax_inference_offloading.engines.vllm_transfer_engine import VLLMTransferEngine + +__all__ = ["VLLMRolloutEngine", "VLLMTransferEngine"] diff --git a/jax-inference-offloading/jax_inference_offloading/engines/vllm_rollout_engine.py b/jax-inference-offloading/jax_inference_offloading/engines/vllm_rollout_engine.py new file mode 100644 index 000000000..8391752fd --- /dev/null +++ b/jax-inference-offloading/jax_inference_offloading/engines/vllm_rollout_engine.py @@ -0,0 +1,250 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
"""VLLMRolloutEngine: Handles inference requests to vLLM."""

import json
import secrets
import traceback
from logging import getLogger
from queue import Empty, Queue
from typing import List, Optional, Union

import jax_inference_offloading.api.controller_pb2 as ctrl
from jax_inference_offloading.api.message_broker_pb2 import SubscribeRequest
from jax_inference_offloading.api.types import (
    CompletionOutput,
    InferenceConfig,
    InferenceOutput,
)
from jax_inference_offloading.controller.spmd import on_spmd_leader
from jax_inference_offloading.controller.utils import create_topic
from jax_inference_offloading.session import OffloadingSession
from jax_inference_offloading.timer import Timer

logger = getLogger(__name__)

# Seed sent to vLLM when the caller leaves InferenceConfig.seed as None.
_DEFAULT_SEED = 42

# How often generate() wakes up to check that the response reader is still alive.
_RESPONSE_POLL_SECONDS = 1.0


class VLLMRolloutEngine:
    """vLLM-based rollout engine for inference offloading.

    This engine handles inference/rollout generation by sending requests to vLLM
    via the gateway. It is designed to work with an OffloadingSession and does
    not handle weight transfer (use VLLMTransferEngine for that).

    Args:
        session: An initialized OffloadingSession.
        timer: Optional timer for performance profiling.

    Example:
        >>> session = OffloadingSession(
        ...     gateway_url="localhost:50051",
        ...     mesh=mesh,
        ...     model_name="meta-llama/Llama-3.1-8B",
        ... )
        >>> rollout_engine = VLLMRolloutEngine(session)
        >>> config = InferenceConfig(max_tokens=128, temperature=0.9)
        >>> output = rollout_engine.generate(["What is 2+2?"], config)
        >>> print(output.texts[0])
    """

    def __init__(
        self,
        session: OffloadingSession,
        timer: Optional[Timer] = None,
    ):
        self._session = session
        self._timer = timer or Timer()

        # Set up inference response stream. The topic id is random so each
        # engine instance only receives the responses addressed to it.
        self._inference_topic_id = f"inference/results/{secrets.token_hex(16)}"
        self._response_queue: Queue = Queue()
        self._stream = None
        self._response_future = None

        # Start background thread to handle responses
        self._setup_response_stream()

        logger.warning("VLLMRolloutEngine initialized")

    def _setup_response_stream(self):
        """Set up the background thread for receiving inference responses."""
        self._stream = self._session.broker_stub.SubscriptionStream(
            SubscribeRequest(topics=[create_topic(self._inference_topic_id)])
        )

        def handle_responses():
            """Unpack every delivery from the stream into the response queue."""
            try:
                for delivery in self._stream:
                    result = ctrl.InferenceResponse()
                    delivery.message.payload.Unpack(result)
                    self._response_queue.put(result)
            except Exception as e:
                # Treat intentional cancellations/unavailable server as graceful closure
                import grpc

                if isinstance(e, grpc.RpcError) and e.code() in (
                    grpc.StatusCode.CANCELLED,
                    grpc.StatusCode.UNAVAILABLE,
                ):
                    return
                logger.error(f"Error in inference response stream: {e}")
                traceback.print_exc()
                # Bare raise keeps the original traceback on the future.
                raise

        self._response_future = self._session.executor.submit(handle_responses)

    def generate(
        self,
        prompts: Union[List[str], List[List[int]]],
        config: InferenceConfig,
    ) -> InferenceOutput:
        """Generate completions using vLLM.

        This is a blocking call that waits until vLLM returns the response.

        Args:
            prompts: Text prompts or pre-tokenized prompts. Supports:
                - List[str]: Text prompts
                - List[List[int]]: Pre-tokenized prompts
            config: Inference configuration.

        Returns:
            InferenceOutput with generated completions.

        Raises:
            ValueError: If ``prompts`` is not in a supported format.
            RuntimeError: If the response stream closed before a response
                arrived (previously this case blocked forever).
        """
        with self._timer.section("inference"):
            # Build protobuf config. `is None` (not `or`) so that an explicit
            # seed of 0 is honoured instead of being replaced by the default.
            proto_config = ctrl.RolloutConfig(
                max_tokens=config.max_tokens,
                temperature=config.temperature,
                top_p=config.top_p,
                top_k=config.top_k,
                num_outputs=config.n,
                seed=_DEFAULT_SEED if config.seed is None else config.seed,
            )
            proto_config.stop_token_ids.extend(config.stop_token_ids)

            # Build inference request
            request = ctrl.InferenceRequest()
            request.response_topic = self._inference_topic_id
            request.config.CopyFrom(proto_config)

            # Add prompts to request
            self._add_prompts_to_request(prompts, request)

            # Send inference request (only on leader, but all ranks wait for response)
            self._send_inference_request(request)

            # Wait for response
            response = self._await_response()

            # Convert response to framework-agnostic output
            return self._convert_response(response)

    def _await_response(self) -> ctrl.InferenceResponse:
        """Block until a response is queued, failing if the reader has exited."""
        while True:
            try:
                return self._response_queue.get(timeout=_RESPONSE_POLL_SECONDS)
            except Empty:
                future = self._response_future
                if future is None or not future.done():
                    continue

            # The reader finished; pick up a response it may have queued
            # between our timeout and the done() check.
            try:
                return self._response_queue.get_nowait()
            except Empty:
                pass

            cause = None if future.cancelled() else future.exception()
            raise RuntimeError(
                "Inference response stream closed before a response arrived"
            ) from cause

    @on_spmd_leader(broadcast_result=False)
    def _send_inference_request(self, request: ctrl.InferenceRequest):
        """Send inference request to gateway. Only executed on leader."""
        self._session.controller_stub.AsyncInference(request)

    def _add_prompts_to_request(
        self,
        prompts: Union[str, List[str], List[int], List[List[int]], List[dict], List[List[dict]]],
        request: ctrl.InferenceRequest,
    ):
        """Add prompts to the inference request in the appropriate format.

        Accepted shapes, checked in this order:
            - str: one text prompt.
            - list of int: one pre-tokenized prompt.
            - list of str: one text prompt per item.
            - list of dict: one chat conversation (JSON-encoded).
            - list of lists: one prompt per inner list, each either a chat
              conversation (list of dict) or token ids (list of int).

        Raises:
            ValueError: If ``prompts`` matches none of the shapes above.
        """
        # NOTE(review): an empty list satisfies the "list of int" check (all()
        # over nothing is True) and is sent as one prompt with no token ids.
        if isinstance(prompts, str):
            request.prompts.append(ctrl.Prompt(text_prompt=prompts))
        elif isinstance(prompts, list) and all(isinstance(p, int) for p in prompts):
            tids = ctrl.TokenIds()
            tids.ids.extend(prompts)
            request.prompts.append(ctrl.Prompt(tokenized_prompt=tids))
        elif isinstance(prompts, list) and all(isinstance(p, str) for p in prompts):
            for p in prompts:
                request.prompts.append(ctrl.Prompt(text_prompt=p))
        elif isinstance(prompts, list) and all(isinstance(p, dict) for p in prompts):
            request.prompts.append(ctrl.Prompt(chat_messages_json=json.dumps(prompts)))
        elif isinstance(prompts, list) and all(isinstance(p, list) for p in prompts):
            for p in prompts:
                if all(isinstance(m, dict) for m in p):
                    request.prompts.append(ctrl.Prompt(chat_messages_json=json.dumps(p)))
                elif all(isinstance(m, int) for m in p):
                    tids = ctrl.TokenIds()
                    tids.ids.extend(p)
                    request.prompts.append(ctrl.Prompt(tokenized_prompt=tids))
                else:
                    raise ValueError(
                        f"Invalid prompt format. Expected a list of dicts or ints. Got {p}."
                    )
        else:
            raise ValueError(f"Invalid prompt format: {prompts}")

    def _convert_response(self, response: ctrl.InferenceResponse) -> InferenceOutput:
        """Convert protobuf response to framework-agnostic output.

        Empty repeated fields for logprobs / prompt token ids are mapped to
        ``None`` rather than to empty lists.
        """
        completions = [
            CompletionOutput(
                text=output.generated_text,
                token_ids=list(output.generated_tokens.ids),
                logprobs=(
                    list(output.generated_token_logps)
                    if output.generated_token_logps
                    else None
                ),
                prompt_token_ids=(
                    list(output.tokenized_prompt.ids)
                    if output.tokenized_prompt.ids
                    else None
                ),
            )
            for output in response.outputs
        ]
        return InferenceOutput(completions=completions)

    @property
    def session(self) -> OffloadingSession:
        """Access the underlying session."""
        return self._session

    @property
    def timer(self) -> Timer:
        """Access the timer for performance analysis."""
        return self._timer

    def shutdown(self) -> None:
        """Shutdown the rollout engine.

        Best-effort: failures while cancelling the stream or joining the
        reader are logged at debug level and never propagated.
        """
        # Cancel the gRPC stream to unblock the background thread
        if self._stream is not None:
            try:
                self._stream.cancel()
            except Exception:
                logger.debug("Ignoring error while cancelling response stream", exc_info=True)

        # Wait for the background thread to finish
        if self._response_future is not None:
            try:
                self._response_future.result(timeout=5)
            except Exception:
                logger.debug("Ignoring error while joining response reader", exc_info=True)

    def __enter__(self):
        """Context manager entry."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit."""
        self.shutdown()
        return False
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""VLLMTransferEngine: Handles weight transfer from JAX to vLLM."""

import json
from logging import getLogger
from typing import Dict, Optional, Union

import jax

import jax_inference_offloading.api.controller_pb2 as ctrl
from jax_inference_offloading.controller.spmd import on_spmd_leader
from jax_inference_offloading.models import flatten_state, get_named_parameters
from jax_inference_offloading.session import OffloadingSession
from jax_inference_offloading.timer import Timer
from jax_inference_offloading.transport.model.nccl_fused import NcclFusedModelTransport
from jax_inference_offloading.transport.model.nccl_grouped import NcclGroupedModelTransport
from jax_inference_offloading.transport.model.nccl_unfused import NcclUnfusedModelTransport
from jax_inference_offloading.transport.tensor.nccl_star import NcclStarTransport

logger = getLogger(__name__)


class VLLMTransferEngine:
    """Engine for transferring model weights from JAX to vLLM.

    This engine handles NCCL transport creation and weight transfer operations.
    It is designed to work with an OffloadingSession and does not use TrainerClient.

    Args:
        session: An initialized OffloadingSession.
        transfer_mode: Weight transfer mode ('fused', 'unfused', 'grouped').
            Default is 'grouped' which batches all transfers for efficiency.
        timer: Optional timer for performance profiling.

    Raises:
        ValueError: If ``transfer_mode`` is not one of the supported modes.
            Raised before any transport is created.

    Example:
        >>> session = OffloadingSession(
        ...     gateway_url="localhost:50051",
        ...     mesh=mesh,
        ...     model_name="meta-llama/Llama-3.1-8B",
        ... )
        >>> transfer_engine = VLLMTransferEngine(session)
        >>> transfer_engine.update_weights(my_params)
    """

    # transfer_mode -> model transport class. All three take the same
    # constructor arguments, so the mode only selects the class.
    _MODEL_TRANSPORTS = {
        "fused": NcclFusedModelTransport,
        "unfused": NcclUnfusedModelTransport,
        "grouped": NcclGroupedModelTransport,
    }

    def __init__(
        self,
        session: OffloadingSession,
        transfer_mode: str = "grouped",
        timer: Optional[Timer] = None,
    ):
        # Validate first: creating the NCCL transports below also signals vLLM
        # (CreateTransport), so a bad mode must fail before that happens.
        try:
            model_transport_cls = self._MODEL_TRANSPORTS[transfer_mode]
        except (KeyError, TypeError):
            raise ValueError(f"Unknown transfer_mode: {transfer_mode}") from None

        self._session = session
        self._transfer_mode = transfer_mode
        self._timer = timer or Timer()

        # Create NCCL transports
        with self._timer.section("create_transport"):
            self._transports, self._transport_config = self._create_transport()

        # Create model transport based on transfer mode
        with self._timer.section("create_model_transport"):
            self._model_transport = model_transport_cls(
                session.mesh,
                session.mapping_specs,
                self,  # Pass self as gateway (provides start_weight_transfer)
                self._transports,
                self._transport_config,
                timer=self._timer,
            )

        logger.warning(
            f"VLLMTransferEngine initialized with mode={transfer_mode}, "
            f"transports={len(self._transports)}"
        )

    def _create_transport(self):
        """Create NCCL transports for JAX-vLLM communication.

        Returns:
            Tuple of (trainer-side transports, transport config).
        """
        transport_cls = NcclStarTransport

        # Configure transport (calls get_nccl_id via self)
        @on_spmd_leader()
        def _configure():
            cfg = transport_cls.configure(
                self,
                trainer_ranks=self._session.jax_parallelism.tp,
                rollout_ranks=self._session.vllm_parallelism.tp,
            )
            # Signal vLLM to create its side of the transport
            self._session.controller_stub.CreateTransport(
                ctrl.CreateTransportRequest(config_json=json.dumps(cfg))
            )
            return cfg

        transport_config = _configure()

        # Create JAX-side transports
        transports = transport_cls.create_trainer_transport(transport_config)

        logger.warning(
            f"Created {len(transports)} NCCL transports in {transport_config['MODE']} mode"
        )

        return transports, transport_config

    def get_nccl_id(self):
        """Get NCCL unique ID from gateway. Used by NcclStarTransport.configure()."""
        return self._session.get_nccl_id()

    @on_spmd_leader(broadcast_result=False)
    def start_weight_transfer(self, mode: str):
        """Signal vLLM to start receiving weights. Used by model transport."""
        self._session.controller_stub.StartWeightUpdate(
            ctrl.StartWeightUpdateRequest(mode=mode)
        )

    @staticmethod
    def _to_named_parameters(params):
        """Normalize the accepted parameter containers to a name -> array dict."""
        if isinstance(params, dict):
            return params

        # flax is imported lazily so plain-dict callers do not need it.
        try:
            from flax import nnx
        except ImportError as err:
            raise TypeError(
                f"Unsupported params type: {type(params)}. "
                "Expected Dict[str, jax.Array] or install flax for nnx support."
            ) from err

        if isinstance(params, nnx.Module):
            return get_named_parameters(params)
        if isinstance(params, nnx.State):
            return flatten_state(params)
        raise TypeError(f"Unsupported params type: {type(params)}")

    def update_weights(
        self,
        params: Union[Dict[str, jax.Array], "nnx.State", "nnx.Module"],  # noqa: F821
    ) -> None:
        """Transfer model weights to vLLM.

        This is a blocking call that waits until all weights are transferred.

        Args:
            params: Model parameters in various formats:
                - Dict[str, jax.Array]: Direct flattened params
                - flax.nnx.State: Flax state object
                - flax.nnx.Module: Flax module (state extracted automatically)

        Raises:
            TypeError: If ``params`` is none of the supported types, or is not
                a dict and flax is not installed.
        """
        with self._timer.section("update_weights"):
            # Handle different input formats
            with self._timer.section("to_named_parameters"):
                named_params = self._to_named_parameters(params)

            # Transfer via model transport
            with self._timer.section("transfer"):
                self._model_transport(named_params)

    @property
    def session(self) -> OffloadingSession:
        """Access the underlying session."""
        return self._session

    @property
    def timer(self) -> Timer:
        """Access the timer for performance analysis."""
        return self._timer

    @property
    def transfer_mode(self) -> str:
        """Get the current transfer mode."""
        return self._transfer_mode
+"""Framework-specific integrations for jax-inference-offloading.""" diff --git a/jax-inference-offloading/jax_inference_offloading/integrations/standalone/__init__.py b/jax-inference-offloading/jax_inference_offloading/integrations/standalone/__init__.py new file mode 100644 index 000000000..3372482ef --- /dev/null +++ b/jax-inference-offloading/jax_inference_offloading/integrations/standalone/__init__.py @@ -0,0 +1,21 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Standalone model implementations using Flax NNX.""" + +from jax_inference_offloading.integrations.standalone.model import Llama3 + +__all__ = ["Llama3"] diff --git a/jax-inference-offloading/jax_inference_offloading/integrations/standalone/model.py b/jax-inference-offloading/jax_inference_offloading/integrations/standalone/model.py new file mode 100644 index 000000000..819ac1ff6 --- /dev/null +++ b/jax-inference-offloading/jax_inference_offloading/integrations/standalone/model.py @@ -0,0 +1,459 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Standalone Llama3 model implementation using Flax NNX.

This is a minimal, self-contained implementation without config dataclasses.
Shardings are not hardcoded - they come from the loaded parameter pytree.
"""

from typing import Tuple

from flax import nnx
import jax
from jax import numpy as jnp
import jaxtyping

# Fill value written into masked-out attention logits before the softmax
# (see Attention.__call__).
K_MASK = -2.3819763e38

# Per-layer KV cache; Attention reads and writes the keys "k", "v" and
# "end_index".
LayerCache = dict[str, jaxtyping.Array]
# Whole-model cache, keyed by layer name.
Cache = dict[str, LayerCache]


class Einsum(nnx.Module):
    """Einsum is a convenience module for parameterized tensor multiplication.

    Holds one weight tensor ``w`` of the given ``shape`` and contracts the
    input with it according to ``einsum_str``.
    """

    def __init__(
        self,
        einsum_str: str,
        shape: Tuple[int, ...],
        *,
        rngs: nnx.Rngs,
    ):
        """Create the weight.

        Args:
            einsum_str: Two-operand einsum spec; the input is the first
                operand and the weight the second.
            shape: Shape of the weight tensor.
            rngs: RNG container; one key is drawn from its ``params`` stream.
        """
        self.einsum_str = einsum_str
        self.shape = shape
        self.w = nnx.Param(nnx.initializers.normal()(rngs.params(), shape))

    @jax.named_scope("einsum")
    def __call__(self, x: jaxtyping.ArrayLike) -> jaxtyping.Array:
        """Contract ``x`` with the weight using ``einsum_str``."""
        return jnp.einsum(self.einsum_str, x, self.w.value)


class Embedder(nnx.Module):
    """Embedder module.

    Owns a single ``(vocab_size, embed_dim)`` table used both to embed tokens
    (``encode``) and to project hidden states back to the vocabulary
    (``decode``).
    """

    def __init__(
        self,
        vocab_size: int,
        embed_dim: int,
        *,
        rngs: nnx.Rngs,
    ):
        """Create the embedding table.

        Args:
            vocab_size: Number of rows in the table.
            embed_dim: Number of columns in the table.
            rngs: RNG container; one key is drawn from its ``params`` stream.
        """
        self.input_embedding = nnx.Param(
            nnx.initializers.normal()(rngs.params(), (vocab_size, embed_dim))
        )

    @jax.named_scope("embedder_encode")
    def encode(self, x: jaxtyping.ArrayLike) -> jaxtyping.Array:
        """Index the embedding table with ``x`` (presumably token ids)."""
        return self.input_embedding[(x,)]

    @jax.named_scope("embedder_decode")
    def decode(self, x: jaxtyping.ArrayLike) -> jaxtyping.Array:
        """Multiply ``x`` by the transposed embedding table."""
        return jnp.dot(x, self.input_embedding.value.T)


def apply_rope(
    inputs: jaxtyping.Array,  # [B, L, N, H]
    positions: jaxtyping.Array,  # [B, L]
    head_dim: int,
    rope_theta: int = 500_000,
) -> jaxtyping.Array:
    """Applies Rotary Position Embedding (RoPE).

    Args:
        inputs: Projected queries or keys, [B, L, N, H].
        positions: Position of each token, [B, L].
        head_dim: Size of the last axis of ``inputs``; determines the number
            of rotation frequencies (``head_dim // 2``).
        rope_theta: Base of the geometric frequency progression.

    Returns:
        Array with the same shape and dtype as ``inputs``.
    """
    # One timescale per rotated pair: rope_theta ** (2i / head_dim).
    fraction = 2 * jnp.arange(0, head_dim // 2, dtype=jnp.float32) / head_dim
    timescale = rope_theta**fraction

    # [B, L, 1] / [1, 1, H/2] -> [B, L, H/2]
    sinusoid_inp = (
        positions[..., jnp.newaxis] / timescale[jnp.newaxis, jnp.newaxis, :]
    )
    # Insert a singleton axis so the angles broadcast over the heads axis.
    sinusoid_inp = sinusoid_inp[..., jnp.newaxis, :]
    sin = jnp.sin(sinusoid_inp)
    cos = jnp.cos(sinusoid_inp)

    # The last axis is rotated as two contiguous halves rather than as
    # interleaved even/odd pairs.
    first_half, second_half = jnp.split(inputs, 2, axis=-1)
    first_part = first_half * cos - second_half * sin
    second_part = second_half * cos + first_half * sin
    out = jnp.concatenate([first_part, second_part], axis=-1)
    return out.astype(inputs.dtype)


class RMSNorm(nnx.Module):
    """RMSNorm layer.

    Scales the input by a learned per-feature weight divided by the root mean
    square of the input over its last axis.
    """

    def __init__(
        self,
        dim: int,
        *,
        norm_eps: float = 1e-06,
        rngs: nnx.Rngs,
    ):
        """Create the scale weight.

        Args:
            dim: Size of the normalized (last) axis.
            norm_eps: Constant added to the mean square before the square root.
            rngs: RNG container; one key is drawn from its ``params`` stream.
        """
        self.w = nnx.Param(nnx.initializers.ones_init()(rngs.params(), (dim,)))
        self.norm_eps = norm_eps

    @jax.named_scope("rms_norm")
    def __call__(self, x: jaxtyping.Array) -> jaxtyping.Array:
        """Normalize ``x`` over its last axis; the result has ``x``'s dtype."""
        dtype = x.dtype
        # The mean square is computed in float32 regardless of the input dtype.
        rms = jnp.sqrt(
            jnp.mean(jnp.astype(x, jnp.float32) ** 2, axis=-1, keepdims=True)
            + self.norm_eps
        )
        return jnp.astype(self.w * x / rms, dtype)


class Attention(nnx.Module):
    """Multi-head attention with Grouped Query Attention (GQA) support."""

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        head_dim: int,
        num_kv_heads: int,
        rope_theta: int = 500_000,
        *,
        rngs: nnx.Rngs,
    ):
        """Create the q/k/v/o projections.

        Args:
            embed_dim: Model (residual stream) width.
            num_heads: Number of query heads.
            head_dim: Width of each head.
            num_kv_heads: Number of key/value heads shared by the query heads.
            rope_theta: RoPE base forwarded to ``apply_rope``.
            rngs: RNG container passed to each projection, in the order
                q, k, v, o.
        """
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.num_kv_heads = num_kv_heads
        self.rope_theta = rope_theta
        # NOTE(review): n_rep is stored but not read anywhere in this class.
        self.n_rep = num_heads // num_kv_heads
        self.scale = head_dim**-0.5

        self.q_proj = Einsum(
            einsum_str="BTD,DNH->BTNH",
            shape=(embed_dim, num_heads, head_dim),
            rngs=rngs,
        )
        self.k_proj = Einsum(
            einsum_str="BSD,DKH->BSKH",
            shape=(embed_dim, num_kv_heads, head_dim),
            rngs=rngs,
        )
        self.v_proj = Einsum(
            einsum_str="BSD,DKH->BSKH",
            shape=(embed_dim, num_kv_heads, head_dim),
            rngs=rngs,
        )
        self.o_proj = Einsum(
            einsum_str="BTNH,NHD->BTD",
            shape=(num_heads, head_dim, embed_dim),
            rngs=rngs,
        )

    @jax.named_scope("attention")
    def __call__(
        self,
        x: jaxtyping.Array,
        segment_pos: jaxtyping.Array,
        cache: LayerCache | None,
        attn_mask: jaxtyping.Array | None,
    ) -> tuple[LayerCache | None, jaxtyping.Array]:
        """Attention forward pass.

        Args:
            x: Input activations; axis 1 is the sequence axis.
            segment_pos: Token positions passed to ``apply_rope``.
            cache: Optional layer cache with keys "k", "v" and "end_index".
            attn_mask: Optional mask; entries that evaluate false are replaced
                by ``K_MASK`` before the softmax.

        Returns:
            Tuple ``(new_cache, outputs)``. ``new_cache`` is ``None`` when
            ``cache`` is ``None``.
        """
        seq_len = x.shape[1]

        query_proj = self.q_proj(x)
        key_proj = self.k_proj(x)
        value_proj = self.v_proj(x)

        # RoPE is applied to queries and keys only; values are left untouched.
        query_proj = apply_rope(
            query_proj,
            segment_pos,
            head_dim=self.head_dim,
            rope_theta=self.rope_theta,
        )
        key_proj = apply_rope(
            key_proj,
            segment_pos,
            head_dim=self.head_dim,
            rope_theta=self.rope_theta,
        )

        if cache is not None:
            # Write the new keys/values into the cache buffers at the current
            # end index (wrapped by the cache length along axis 1). From here
            # on key_proj / value_proj refer to the full cache buffers.
            end_index = cache["end_index"][0]
            slice_indices = (0, end_index % cache["v"].shape[1], 0, 0)
            value_proj = jax.lax.dynamic_update_slice(
                cache["v"],
                value_proj,
                slice_indices,
            )
            key_proj = jax.lax.dynamic_update_slice(
                cache["k"], key_proj, slice_indices
            )

        b, t, qh, d = query_proj.shape
        _, s, kh, _ = key_proj.shape

        # Grouped Query Attention: split the query heads into kh groups of
        # qh // kh heads; each group attends against one KV head.
        query_proj = query_proj.reshape((b, t, kh, qh // kh, d))
        attn = jnp.einsum("BTHGD,BSHD->BHGTS", query_proj, key_proj) * self.scale
        attn = attn.reshape((b, qh, t, s))

        if attn_mask is not None:
            # Add a heads axis to the mask so it broadcasts over all heads.
            attn = jnp.where((jnp.expand_dims(attn_mask, -3)), attn, K_MASK)

        # Softmax is computed in float32, then cast back to the key dtype.
        attn = jax.nn.softmax(attn.astype(jnp.float32), axis=-1).astype(
            key_proj.dtype
        )

        attn = attn.reshape((b, kh, qh // kh, t, s))
        qkv = jnp.einsum("BHGTS,BSHD->BTHGD", attn, value_proj)
        qkv = qkv.reshape((b, t, qh, d))

        outputs = self.o_proj(qkv)

        if cache is not None:
            new_cache = {
                "v": value_proj,
                "k": key_proj,
                "end_index": cache["end_index"] + seq_len,
            }
        else:
            new_cache = None

        return new_cache, outputs


class MLP(nnx.Module):
    """MLP module with SwiGLU activation."""

    def __init__(
        self,
        embed_dim: int,
        hidden_dim: int,
        *,
        rngs: nnx.Rngs,
    ):
        """Create the three bias-free linear layers.

        Args:
            embed_dim: Input and output width.
            hidden_dim: Width of the gated hidden layer.
            rngs: RNG container passed to each layer, in the order
                gate, up, down.
        """
        self.gate_proj = nnx.Linear(
            in_features=embed_dim,
            out_features=hidden_dim,
            use_bias=False,
            rngs=rngs,
        )
        self.up_proj = nnx.Linear(
            in_features=embed_dim,
            out_features=hidden_dim,
            use_bias=False,
            rngs=rngs,
        )
        self.down_proj = nnx.Linear(
            in_features=hidden_dim,
            out_features=embed_dim,
            use_bias=False,
            rngs=rngs,
        )

    @jax.named_scope("feed_forward")
    def __call__(self, x: jaxtyping.ArrayLike) -> jaxtyping.Array:
        """Compute ``down_proj(silu(gate_proj(x)) * up_proj(x))``."""
        activations = nnx.silu(self.gate_proj(x)) * self.up_proj(x)
        outputs = self.down_proj(activations)
        return outputs


class DecoderLayer(nnx.Module):
    """Single transformer decoder layer.

    Normalizes before each sub-block and adds the sub-block output back to
    its input: attention first, then the MLP.
    """

    def __init__(
        self,
        embed_dim: int,
        hidden_dim: int,
        num_heads: int,
        head_dim: int,
        num_kv_heads: int,
        rope_theta: int = 500_000,
        norm_eps: float = 1e-5,
        *,
        rngs: nnx.Rngs,
    ):
        """Create the layer's sub-modules.

        Args:
            embed_dim: Model (residual stream) width.
            hidden_dim: MLP hidden width.
            num_heads: Number of query heads.
            head_dim: Width of each attention head.
            num_kv_heads: Number of key/value heads.
            rope_theta: RoPE base forwarded to the attention block.
            norm_eps: Epsilon used by both RMSNorm layers.
            rngs: RNG container shared by all sub-modules; they are built in
                the order shown below.
        """
        self.input_layernorm = RMSNorm(
            embed_dim,
            norm_eps=norm_eps,
            rngs=rngs,
        )
        self.attn = Attention(
            embed_dim=embed_dim,
            num_heads=num_heads,
            head_dim=head_dim,
            num_kv_heads=num_kv_heads,
            rope_theta=rope_theta,
            rngs=rngs,
        )
        self.mlp = MLP(
            embed_dim=embed_dim,
            hidden_dim=hidden_dim,
            rngs=rngs,
        )
        self.post_attention_layernorm = RMSNorm(
            embed_dim,
            norm_eps=norm_eps,
            rngs=rngs,
        )

    def __call__(
        self,
        x: jaxtyping.Array,
        segment_pos: jaxtyping.Array,
        cache: LayerCache | None,
        attn_mask: jaxtyping.Array,
    ) -> tuple[LayerCache | None, jaxtyping.Array]:
        """Run attention and MLP with residual connections.

        Args:
            x: Input activations.
            segment_pos: Token positions forwarded to the attention block.
            cache: Optional cache for this layer.
            attn_mask: Attention mask forwarded to the attention block.

        Returns:
            Tuple ``(cache, outputs)`` where ``cache`` is whatever the
            attention block returned.
        """
        inputs_normalized = self.input_layernorm(x)
        cache, attn_output = self.attn(
            inputs_normalized,
            segment_pos,
            cache,
            attn_mask,
        )
        # First residual: attention output added to the un-normalized input.
        attn_output = attn_output + x
        residual = attn_output
        attn_output = self.post_attention_layernorm(attn_output)
        # Second residual: MLP output added to the post-attention stream.
        outputs = residual + self.mlp(attn_output)
        return cache, outputs
class Llama3(nnx.Module):
    """Standalone Llama3 model.

    This implementation takes all model hyperparameters directly as constructor
    arguments rather than using a config dataclass. Shardings are not hardcoded;
    they are inherited from the loaded parameter pytree.

    Example usage for Llama3.2 1B:
        model = Llama3(
            num_layers=16,
            vocab_size=128256,
            embed_dim=2048,
            hidden_dim=8192,
            num_heads=32,
            head_dim=64,
            num_kv_heads=8,
            rope_theta=500_000,
            norm_eps=1e-5,
            weight_tying=True,
            rngs=nnx.Rngs(0),
        )
    """

    def __init__(
        self,
        num_layers: int,
        vocab_size: int,
        embed_dim: int,
        hidden_dim: int,
        num_heads: int,
        head_dim: int,
        num_kv_heads: int,
        rope_theta: int = 500_000,
        norm_eps: float = 1e-5,
        weight_tying: bool = False,
        *,
        rngs: nnx.Rngs,
    ):
        """Build the embedder, decoder stack, final norm and optional LM head.

        Args:
            num_layers: Number of decoder layers.
            vocab_size: Vocabulary size (embedding rows / logit width).
            embed_dim: Model (residual stream) width.
            hidden_dim: MLP hidden width in each decoder layer.
            num_heads: Number of query heads per layer.
            head_dim: Width of each attention head.
            num_kv_heads: Number of key/value heads per layer.
            rope_theta: RoPE base forwarded to every layer.
            norm_eps: Epsilon for every RMSNorm, including the final one.
            weight_tying: If true, logits are computed with the embedding
                table and no separate ``lm_head`` attribute is created.
            rngs: RNG container shared by all sub-modules; they are built in
                the order embedder, layers, final norm, lm_head.
        """
        self.num_layers = num_layers
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.num_kv_heads = num_kv_heads
        self.rope_theta = rope_theta
        self.norm_eps = norm_eps
        self.weight_tying = weight_tying

        self.embedder = Embedder(
            vocab_size=vocab_size,
            embed_dim=embed_dim,
            rngs=rngs,
        )

        self.layers = [
            DecoderLayer(
                embed_dim=embed_dim,
                hidden_dim=hidden_dim,
                num_heads=num_heads,
                head_dim=head_dim,
                num_kv_heads=num_kv_heads,
                rope_theta=rope_theta,
                norm_eps=norm_eps,
                rngs=rngs,
            )
            for _ in range(num_layers)
        ]

        self.final_norm = RMSNorm(
            embed_dim,
            rngs=rngs,
            norm_eps=norm_eps,
        )

        # With weight tying the embedding table doubles as the output
        # projection, so `lm_head` only exists in the untied case.
        if not weight_tying:
            self.lm_head = Einsum(
                einsum_str="BTD,DV->BTV",
                shape=(embed_dim, vocab_size),
                rngs=rngs,
            )

    def __call__(
        self,
        input_tokens: jaxtyping.Array,  # [B, L]
        positions: jaxtyping.Array,  # [B, L]
        cache: Cache | None,  # (sequence length L')
        attention_mask: jaxtyping.Array,  # [B, L, L']
    ) -> tuple[jaxtyping.Array, Cache | None]:
        """Llama3 forward pass.

        Args:
            input_tokens: Input sequence of tokens.
            positions: Input absolute positions.
            cache: Attention KV cache or None. Per-layer entries are looked up
                under the keys ``"layer_0"``, ``"layer_1"``, ...
            attention_mask: Transformer input mask.

        Returns:
            Tuple of (predicted_logits, new_cache):
                - predicted_logits: Output logits predicted by the model.
                - new_cache: Updated cache if input cache is not None, else None.
        """
        new_cache = None if cache is None else {}
        x = self.embedder.encode(input_tokens)

        for i, layer in enumerate(self.layers):
            layer_name = f"layer_{i}"
            # NOTE(review): this is a truthiness test while the checks around
            # it use `is None`; an empty dict therefore runs every layer
            # without a cache and stores None per layer in new_cache.
            # Confirm whether that is intended.
            layer_cache = cache[layer_name] if cache else None
            layer_cache, x = layer(
                x,
                positions,
                layer_cache,
                attention_mask,
            )
            if cache is not None:
                new_cache[layer_name] = layer_cache

        x = self.final_norm(x)

        if self.weight_tying:
            logits = self.embedder.decode(x)
        else:
            logits = self.lm_head(x)

        return logits, new_cache
ModelConfig.llama3_2_1b, - '3B': ModelConfig.llama3_2_3b, - '8B': ModelConfig.llama3_1_8b, - '70B': ModelConfig.llama3_70b, - '405B': ModelConfig.llama3_405b, + '1B': ModelConfig.llama3p2_1b, + '3B': ModelConfig.llama3p2_3b, + '8B': ModelConfig.llama3p1_8b, + '70B': ModelConfig.llama3p1_70b, + '405B': ModelConfig.llama3p1_405b, } try: config = config_factory[m.group('size')]() diff --git a/jax-inference-offloading/jax_inference_offloading/integrations/tunix/rollout.py b/jax-inference-offloading/jax_inference_offloading/integrations/tunix/rollout.py new file mode 100644 index 000000000..c77b792a5 --- /dev/null +++ b/jax-inference-offloading/jax_inference_offloading/integrations/tunix/rollout.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tunix adapter for vLLM rollout offloading. + +This module provides VllmGPURollout, a Tunix BaseRollout implementation +that delegates to the framework-agnostic VLLMRolloutEngine. 
+""" +from typing import Any, List, Optional, Tuple + +import jax +import jax.numpy as jnp +import jaxtyping + +from jax_inference_offloading.api import InferenceConfig, pad_left, pad_right +from jax_inference_offloading.engines import VLLMRolloutEngine +from jax_inference_offloading.timer import Timer +from tunix.rl.rollout.base_rollout import BaseRollout, RolloutConfig, RolloutOutput + + +class VllmGPURollout(BaseRollout): + """Tunix adapter wrapping VLLMRolloutEngine. + + This class implements Tunix's BaseRollout interface by delegating + to the framework-agnostic VLLMRolloutEngine. It handles the conversion + between Tunix's RolloutConfig/RolloutOutput and the bridge's + InferenceConfig/InferenceOutput. + """ + + def __init__( + self, + gateway_url: str, + model_name: str, + *, + rollout_actor, # AKA rollout model (unused for remote engine) + tokenizer, + mesh: jax.sharding.Mesh, + rollout_config: RolloutConfig, # Initial config (unused, passed per-call) + extra_stop_tokens: List[str] | None = None, + transfer_mode: str = "fused", + timer: Any | None = None, + ): + """Initialize the Tunix vLLM rollout adapter. + + Args: + gateway_url: URL of the gateway server. + model_name: HuggingFace model name for tensor mapping. + rollout_actor: The rollout model (unused for remote engine). + tokenizer: Tunix tokenizer for encoding/decoding. + mesh: JAX device mesh. + rollout_config: Initial rollout config (unused, provided per generate call). + extra_stop_tokens: Additional stop tokens as strings. + transfer_mode: Weight transfer mode ('fused', 'unfused', 'grouped'). + timer: Optional timer for profiling. 
+ """ + del rollout_actor # Not used for remote engine + del rollout_config # Config passed per generate() call + + self._timer = timer or Timer() + self._tokenizer = tokenizer + + # Resolve extra stop tokens to IDs + self._extra_stop_token_ids: List[int] = [] + for t in extra_stop_tokens or []: + token_ids = self._tokenizer.encode(t) + assert len(token_ids) == 1, f"Stop token {t} must be a single token, got {token_ids}" + self._extra_stop_token_ids.extend(token_ids) + + # Delegate to the framework-agnostic engine + self._engine = VLLMRolloutEngine( + gateway_url=gateway_url, + model_name=model_name, + mesh=mesh, + transfer_mode=transfer_mode, + timer=self._timer, + ) + + def generate( + self, + prompts: List[str], + rollout_config: RolloutConfig, + ) -> RolloutOutput: + """Generate completions for the given prompts. + + Args: + prompts: List of text prompts. + rollout_config: Tunix rollout configuration. + + Returns: + Tunix RolloutOutput with generated samples. + """ + with self._timer.section("rollout.generate"): + # Convert Tunix RolloutConfig -> InferenceConfig + stop_token_ids = list(self._extra_stop_token_ids) + if rollout_config.eos_tokens is not None: + stop_token_ids = list(rollout_config.eos_tokens) + stop_token_ids + else: + stop_token_ids = [self._tokenizer.eos_id()] + stop_token_ids + + config = InferenceConfig( + max_tokens=rollout_config.max_tokens_to_generate, + temperature=rollout_config.temperature, + top_p=rollout_config.top_p if rollout_config.top_p is not None else 1.0, + top_k=rollout_config.top_k if rollout_config.top_k is not None else -1, + seed=rollout_config.seed, + stop_token_ids=stop_token_ids, + ) + + # Call engine + with self._timer.section("inference"): + output = self._engine.generate([str(p) for p in prompts], config) + + # Convert InferenceOutput -> Tunix RolloutOutput + with self._timer.section("process_outputs"): + generated_text = [] + input_tokens = [] + output_tokens = [] + + for i, completion in 
enumerate(output.completions): + generated_text.append(completion.text) + input_tokens.append( + pad_left( + completion.prompt_token_ids or [], + rollout_config.max_prompt_length, + self._tokenizer.pad_id(), + ) + ) + output_tokens.append( + pad_right( + completion.token_ids, + rollout_config.max_tokens_to_generate, + self._tokenizer.pad_id(), + ) + ) + + return RolloutOutput( + text=generated_text, + logits=[], # Not needed for GRPO + tokens=jnp.array(output_tokens, dtype=jnp.int32), + left_padded_prompt_tokens=jnp.array(input_tokens, dtype=jnp.int32), + logprobs=None, # GRPOLearner will recalculate + ) + + def get_per_token_logps( + self, + prompt_tokens: jax.Array, + completion_tokens: jax.Array, + completion_mask: jax.Array | None = None, + ) -> jax.Array: + """Get per-token log probabilities. + + Not implemented for remote engine - use GRPOLearner's recalculation. + """ + raise NotImplementedError( + "get_per_token_logps is not supported for remote vLLM engine. " + "Use GRPOLearner which recalculates logprobs locally." + ) + + def update_params( + self, + params: jaxtyping.PyTree, + filter_types: Optional[Tuple[Any, ...]] = None, + ) -> None: + """Update the rollout model parameters. + + Args: + params: Model parameters to transfer. + filter_types: Unused for remote engine. 
+ """ + del filter_types # Not used for remote engine + with self._timer.section("rollout.update_params"): + self._engine.update_weights(params) + + def pad_id(self) -> int: + """Return the padding token ID.""" + return self._tokenizer.pad_id() + + def eos_id(self) -> int: + """Return the end-of-sequence token ID.""" + return self._tokenizer.eos_id() + + def model(self): + """Return the local model (None for remote engine).""" + return None + + def shutdown(self) -> None: + """Gracefully shutdown the remote gateway.""" + self._engine.shutdown() + + def __del__(self): + """Destructor - attempt graceful shutdown.""" + try: + self.shutdown() + except Exception: + # Suppress destructor-time errors during interpreter shutdown. + pass diff --git a/jax-inference-offloading/jax_inference_offloading/models/__init__.py b/jax-inference-offloading/jax_inference_offloading/models/__init__.py index abf2b1652..f2230dde1 100644 --- a/jax-inference-offloading/jax_inference_offloading/models/__init__.py +++ b/jax-inference-offloading/jax_inference_offloading/models/__init__.py @@ -32,18 +32,18 @@ def make_transform(slice=[], transpose=[], reshape=[], replication_axis=None, re def make_mapping( - jax_name, vllm_name, vllm_shape, *, transform=None, jax_prefix="model", vllm_prefix="model" + jax_name, vllm_name, vllm_shape, *, transform=None, vllm_prefix="model" ): result = mapping.ParamMapping() result.vllm_param.name = f"{vllm_prefix}.{vllm_name}".lstrip(".") result.vllm_param.shape.extend(vllm_shape) - result.jax_param.name = f"{jax_prefix}.{jax_name}".lstrip(".") + result.jax_param.name = jax_name if transform is not None: result.jax_param.transform.CopyFrom(transform) return result -def flatten_state(nnx_state, prefix="model"): +def flatten_state(nnx_state): """Flatten an NNX state tree into a dictionary with dot-separated keys.""" def _flatten_dict(nnx_state, prefix=""): @@ -51,14 +51,15 @@ def _flatten_dict(nnx_state, prefix=""): yield prefix, nnx_state.value except 
AttributeError: for k in nnx_state.keys(): - yield from _flatten_dict(nnx_state[k], prefix=".".join([prefix, str(k)])) + new_prefix = f"{prefix}.{k}" if prefix else str(k) + yield from _flatten_dict(nnx_state[k], prefix=new_prefix) - return dict(_flatten_dict(nnx_state, prefix=prefix)) + return dict(_flatten_dict(nnx_state)) -def get_named_parameters(nnx_model, prefix="model", *filters): +def get_named_parameters(nnx_model, *filters): """Flatten an NNX model into a dictionary with dot-separated keys.""" from flax import nnx nnx_state = nnx.state(nnx_model, *filters) - return flatten_state(nnx_state, prefix=prefix) + return flatten_state(nnx_state) diff --git a/jax-inference-offloading/jax_inference_offloading/models/auto.py b/jax-inference-offloading/jax_inference_offloading/models/auto.py index c6902d970..565dea76a 100644 --- a/jax-inference-offloading/jax_inference_offloading/models/auto.py +++ b/jax-inference-offloading/jax_inference_offloading/models/auto.py @@ -14,27 +14,44 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os + import jax_inference_offloading.api.param_mapping_pb2 as mapping from .gemma import get_gemma_2b_mapping, get_gemma_7b_mapping from .gemma3 import get_gemma3_1b_mapping -from .llama3 import get_llama3_8b_mapping, get_llama3_70b_mapping, get_llama3_405b_mapping +from .llama3 import get_llama3_1b_mapping, get_llama3_8b_mapping, get_llama3_70b_mapping, get_llama3_405b_mapping +from .mapping_util import load_mapping_from_json + + +def get_tp_model_mapping(model_name, vllm_prefix="model", mapping_json_path=None) -> mapping.TpModelMappingSpecs: + """Get the parameter mapping for a model. + Args: + model_name: HuggingFace model name (e.g., "meta-llama/Llama-3.1-8B"). + vllm_prefix: Prefix for vLLM parameter names. + mapping_json_path: Optional path to a custom param_mapping.json file. 
+ If provided and the file exists, it will be used instead of hardcoded mappings. -def get_tp_model_mapping(model_name, jax_prefix="model", vllm_prefix="model") -> mapping.TpModelMappingSpecs: + Returns: + TpModelMappingSpecs protobuf with parameter mappings. + """ + # Check for custom JSON mapping file first + if mapping_json_path is not None and os.path.exists(mapping_json_path): + return load_mapping_from_json(mapping_json_path) if model_name in ("google/gemma-2b", "google/gemma-2b-it"): - return get_gemma_2b_mapping(jax_prefix, vllm_prefix) + return get_gemma_2b_mapping(vllm_prefix) elif model_name in ("google/gemma-7b", "google/gemma-7b-it"): - return get_gemma_7b_mapping(jax_prefix, vllm_prefix) + return get_gemma_7b_mapping(vllm_prefix) elif model_name in ("google/gemma-3-1b", "google/gemma-3-1b-it"): - return get_gemma3_1b_mapping(jax_prefix, vllm_prefix) + return get_gemma3_1b_mapping(vllm_prefix) elif model_name in ( "meta-llama/Meta-Llama-3-8B", "meta-llama/Meta-Llama-3-8B-Instruct", "meta-llama/Llama-3.1-8B", "meta-llama/Llama-3.1-8B-Instruct", ): - return get_llama3_8b_mapping(jax_prefix, vllm_prefix) + return get_llama3_8b_mapping(vllm_prefix) elif model_name in ( # Llama 3.2 covers only 1B/3B # Llama 3.3 is instruct-only at 70B @@ -47,7 +64,7 @@ def get_tp_model_mapping(model_name, jax_prefix="model", vllm_prefix="model") -> "meta-llama/Llama-3.1-70B-Instruct", "meta-llama/Llama-3.3-70B-Instruct", ): - return get_llama3_70b_mapping(jax_prefix, vllm_prefix) + return get_llama3_70b_mapping(vllm_prefix) elif model_name in ( "meta-llama/Meta-Llama-3.1-405B", "meta-llama/Meta-Llama-3.1-405B-Instruct", @@ -56,5 +73,10 @@ def get_tp_model_mapping(model_name, jax_prefix="model", vllm_prefix="model") -> "meta-llama/Llama-3.1-405B-Instruct", "meta-llama/Llama-3.1-405B-Instruct-FP8", ): - return get_llama3_405b_mapping(jax_prefix, vllm_prefix) + return get_llama3_405b_mapping(vllm_prefix) + elif model_name in ( + "meta-llama/Llama-3.2-1B", + 
"meta-llama/Llama-3.2-1B-Instruct", + ): + return get_llama3_1b_mapping(vllm_prefix) raise Exception(f"Unknown model {model_name}.") diff --git a/jax-inference-offloading/jax_inference_offloading/models/checkpoint.py b/jax-inference-offloading/jax_inference_offloading/models/checkpoint.py new file mode 100644 index 000000000..40879cecf --- /dev/null +++ b/jax-inference-offloading/jax_inference_offloading/models/checkpoint.py @@ -0,0 +1,172 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Checkpoint loading utilities for JAX models. + +This module provides functions to load HuggingFace/vLLM checkpoints directly +into JAX arrays, using the parameter mapping JSON to handle shape transformations +and sharding. 
+""" +from dataclasses import dataclass +from pathlib import Path +from typing import Dict, List, Optional + +import jax +import jax.numpy as jnp +from jax.sharding import NamedSharding, PartitionSpec + +import jax_inference_offloading.api.param_mapping_pb2 as mapping +from jax_inference_offloading.models.mapping_util import ( + apply_vllm_transform, + load_mapping_from_json, +) + + +@dataclass +class MappingConfig: + """Configuration loaded from a parameter mapping JSON file.""" + mesh_axes: List[str] + mapping_specs: mapping.TpModelMappingSpecs + + +def load_mapping_config(json_path: str) -> MappingConfig: + """Load a parameter mapping JSON file and return mesh_axes + mapping specs. + + Args: + json_path: Path to the JSON configuration file. + + Returns: + MappingConfig with mesh_axes and mapping_specs. + """ + mapping_specs = load_mapping_from_json(json_path) + return MappingConfig( + mesh_axes=list(mapping_specs.mesh_axes), + mapping_specs=mapping_specs, + ) + + +def _partition_spec_from_proto( + partition_specs: mapping.JaxParam.PartitionSpecs, + mesh_axes: List[str], +) -> PartitionSpec: + """Convert a PartitionSpecs proto to a JAX PartitionSpec. + + Args: + partition_specs: Proto with axis names (empty string = unsharded). + mesh_axes: List of valid axis names from mesh_axes config. + + Returns: + A JAX PartitionSpec. + """ + axes = [] + for axis in partition_specs.axes: + if axis == "": + axes.append(None) + else: + if axis not in mesh_axes: + raise ValueError( + f"Partition spec axis '{axis}' not found in mesh_axes {mesh_axes}" + ) + axes.append(axis) + return PartitionSpec(*axes) + + +def load_checkpoint_to_jax( + checkpoint_path: str, + mapping_specs: mapping.TpModelMappingSpecs, + mesh: Optional[jax.sharding.Mesh] = None, + dtype: jnp.dtype = jnp.bfloat16, +) -> Dict[str, jax.Array]: + """Load vLLM/HuggingFace checkpoint weights into JAX arrays. 
+ + This function reads safetensors files from a checkpoint directory and + transforms them to match JAX model parameter shapes using the vllm_param.transform + specifications. If a mesh is provided, parameters are sharded according to + jax_param.partition_specs. + + Args: + checkpoint_path: Path to checkpoint directory containing safetensors files. + mapping_specs: TpModelMappingSpecs proto with parameter mappings. + mesh: Optional JAX mesh for sharding. If None, arrays are not sharded. + dtype: Target dtype for loaded arrays (default: bfloat16). + + Returns: + Dictionary mapping JAX parameter names to JAX arrays. + """ + try: + from safetensors import safe_open + except ImportError: + raise ImportError( + "safetensors package is required for checkpoint loading. " + "Install with: pip install safetensors" + ) + + checkpoint_path = Path(checkpoint_path) + mesh_axes = list(mapping_specs.mesh_axes) + + # Find all safetensors files in the checkpoint directory + safetensor_files = list(checkpoint_path.glob("*.safetensors")) + if not safetensor_files: + raise FileNotFoundError( + f"No safetensors files found in {checkpoint_path}" + ) + + # Build a mapping from vLLM param names to file paths + vllm_name_to_file = {} + for st_file in safetensor_files: + with safe_open(st_file, framework="numpy") as f: + for key in f.keys(): + vllm_name_to_file[key] = st_file + + # Load and transform each parameter + jax_params = {} + + for param_mapping in mapping_specs.mappings: + jax_name = param_mapping.jax_param.name + vllm_name = param_mapping.vllm_param.name + + # Find the safetensors file containing this parameter + if vllm_name not in vllm_name_to_file: + raise KeyError( + f"Parameter '{vllm_name}' not found in checkpoint. " + f"Available parameters: {list(vllm_name_to_file.keys())[:10]}..." 
+ ) + + st_file = vllm_name_to_file[vllm_name] + + # Load the tensor + with safe_open(st_file, framework="numpy") as f: + tensor = f.get_tensor(vllm_name) + + # Apply vLLM -> JAX transform if specified + if param_mapping.vllm_param.HasField("transform"): + tensor = apply_vllm_transform(tensor, param_mapping.vllm_param.transform) + + # Convert to JAX array with target dtype + tensor = jnp.asarray(tensor, dtype=dtype) + + # Apply sharding if mesh is provided and partition_specs are defined + if mesh is not None and param_mapping.jax_param.HasField("partition_specs"): + partition_spec = _partition_spec_from_proto( + param_mapping.jax_param.partition_specs, + mesh_axes, + ) + sharding = NamedSharding(mesh, partition_spec) + tensor = jax.device_put(tensor, sharding) + + jax_params[jax_name] = tensor + + return jax_params diff --git a/jax-inference-offloading/jax_inference_offloading/models/gemma.py b/jax-inference-offloading/jax_inference_offloading/models/gemma.py index 7b1a4064d..9de1435d6 100644 --- a/jax-inference-offloading/jax_inference_offloading/models/gemma.py +++ b/jax-inference-offloading/jax_inference_offloading/models/gemma.py @@ -29,10 +29,9 @@ def _get_gemma_mapping( head_dim, ffn_size, qkv_merged, - jax_prefix="model", vllm_prefix="model", ) -> mapping.TpModelMappingSpecs: - param_mapping = partial(make_mapping, jax_prefix=jax_prefix, vllm_prefix=vllm_prefix) + param_mapping = partial(make_mapping, vllm_prefix=vllm_prefix) def qkv_param(layer_id, jax_name, vllm_name, slice=[], heads=kv_heads, replication_axis=None): return param_mapping( @@ -130,7 +129,7 @@ def qkv_specs(layer_id): return result -def get_gemma_2b_mapping(jax_prefix="model", vllm_prefix="model") -> mapping.TpModelMappingSpecs: +def get_gemma_2b_mapping(vllm_prefix="model") -> mapping.TpModelMappingSpecs: return _get_gemma_mapping( vocab_size=256_000, n_layers=18, @@ -139,13 +138,12 @@ def get_gemma_2b_mapping(jax_prefix="model", vllm_prefix="model") -> mapping.TpM kv_heads=1, 
head_dim=256, ffn_size=16384, - jax_prefix=jax_prefix, vllm_prefix=vllm_prefix, qkv_merged=False, ) -def get_gemma_7b_mapping(jax_prefix="model", vllm_prefix="model") -> mapping.TpModelMappingSpecs: +def get_gemma_7b_mapping(vllm_prefix="model") -> mapping.TpModelMappingSpecs: return _get_gemma_mapping( vocab_size=256_000, n_layers=28, @@ -154,7 +152,6 @@ def get_gemma_7b_mapping(jax_prefix="model", vllm_prefix="model") -> mapping.TpM kv_heads=16, head_dim=256, ffn_size=24576, - jax_prefix=jax_prefix, vllm_prefix=vllm_prefix, qkv_merged=True, ) diff --git a/jax-inference-offloading/jax_inference_offloading/models/gemma3.py b/jax-inference-offloading/jax_inference_offloading/models/gemma3.py index 811888e1e..380b961c2 100644 --- a/jax-inference-offloading/jax_inference_offloading/models/gemma3.py +++ b/jax-inference-offloading/jax_inference_offloading/models/gemma3.py @@ -29,10 +29,9 @@ def _get_gemma3_mapping( head_dim, ffn_size, qkv_merged, - jax_prefix: str = "model", vllm_prefix: str = "model", ) -> mapping.TpModelMappingSpecs: - param_mapping = partial(make_mapping, jax_prefix=jax_prefix, vllm_prefix=vllm_prefix) + param_mapping = partial(make_mapping, vllm_prefix=vllm_prefix) def qkv_param(layer_id, jax_name, vllm_name, slice_spec=None, heads=kv_heads, replication_axis=None): return param_mapping( @@ -145,7 +144,7 @@ def qkv_specs(layer_id): return model_mapping -def get_gemma3_1b_mapping(jax_prefix: str = "model", vllm_prefix: str = "model") -> mapping.TpModelMappingSpecs: +def get_gemma3_1b_mapping(vllm_prefix: str = "model") -> mapping.TpModelMappingSpecs: return _get_gemma3_mapping( vocab_size=262_144, n_layers=26, @@ -155,6 +154,5 @@ def get_gemma3_1b_mapping(jax_prefix: str = "model", vllm_prefix: str = "model") head_dim=256, ffn_size=6912, qkv_merged=False, - jax_prefix=jax_prefix, vllm_prefix=vllm_prefix, ) diff --git a/jax-inference-offloading/jax_inference_offloading/models/llama3.py 
b/jax-inference-offloading/jax_inference_offloading/models/llama3.py index 2346c9090..c638e6e27 100644 --- a/jax-inference-offloading/jax_inference_offloading/models/llama3.py +++ b/jax-inference-offloading/jax_inference_offloading/models/llama3.py @@ -28,10 +28,10 @@ def _get_llama3_mapping( kv_heads, head_dim, ffn_size, - jax_prefix: str = "model", vllm_prefix: str = "model", + tie_word_embeddings: bool = False, ) -> mapping.TpModelMappingSpecs: - param_mapping = partial(make_mapping, jax_prefix=jax_prefix, vllm_prefix=vllm_prefix) + param_mapping = partial(make_mapping, vllm_prefix=vllm_prefix) params = [ # singletons @@ -41,13 +41,18 @@ def _get_llama3_mapping( [vocab_size, hidden_size], ), param_mapping("final_norm.w", "norm.weight", [hidden_size]), - make_mapping( - "lm_head.w", "lm_head.weight", [vocab_size, hidden_size], - transform=make_transform(transpose=[1, 0]), - jax_prefix=jax_prefix, vllm_prefix='' - ), ] + # Only add lm_head mapping if embeddings are not tied + if not tie_word_embeddings: + params.append( + make_mapping( + "lm_head.w", "lm_head.weight", [vocab_size, hidden_size], + transform=make_transform(transpose=[1, 0]), + vllm_prefix='' + ) + ) + # per-layer for layer_id in range(n_layers): params.extend( @@ -116,7 +121,20 @@ def _get_llama3_mapping( model_mapping.mappings.extend(params) return model_mapping -def get_llama3_8b_mapping(jax_prefix: str = "model", vllm_prefix: str = "model") -> mapping.TpModelMappingSpecs: +def get_llama3_1b_mapping(vllm_prefix: str = "model") -> mapping.TpModelMappingSpecs: + return _get_llama3_mapping( + vocab_size=128_256, + n_layers=16, + hidden_size=2048, + q_heads=32, + kv_heads=8, + head_dim=64, + ffn_size=8192, + vllm_prefix=vllm_prefix, + tie_word_embeddings=True, + ) + +def get_llama3_8b_mapping(vllm_prefix: str = "model") -> mapping.TpModelMappingSpecs: return _get_llama3_mapping( vocab_size=128_256, n_layers=32, @@ -125,11 +143,10 @@ def get_llama3_8b_mapping(jax_prefix: str = "model", vllm_prefix: str 
= "model") kv_heads=8, head_dim=128, ffn_size=14336, - jax_prefix=jax_prefix, vllm_prefix=vllm_prefix, ) -def get_llama3_70b_mapping(jax_prefix: str = "model", vllm_prefix: str = "model") -> mapping.TpModelMappingSpecs: +def get_llama3_70b_mapping(vllm_prefix: str = "model") -> mapping.TpModelMappingSpecs: return _get_llama3_mapping( vocab_size=128_256, n_layers=80, @@ -138,11 +155,10 @@ def get_llama3_70b_mapping(jax_prefix: str = "model", vllm_prefix: str = "model" kv_heads=8, head_dim=128, ffn_size=28672, - jax_prefix=jax_prefix, vllm_prefix=vllm_prefix, ) -def get_llama3_405b_mapping(jax_prefix: str = "model", vllm_prefix: str = "model") -> mapping.TpModelMappingSpecs: +def get_llama3_405b_mapping(vllm_prefix: str = "model") -> mapping.TpModelMappingSpecs: return _get_llama3_mapping( vocab_size=128_256, n_layers=126, @@ -151,6 +167,5 @@ def get_llama3_405b_mapping(jax_prefix: str = "model", vllm_prefix: str = "model kv_heads=8, head_dim=128, ffn_size=53248, - jax_prefix=jax_prefix, vllm_prefix=vllm_prefix, ) diff --git a/jax-inference-offloading/jax_inference_offloading/models/mapping_util.py b/jax-inference-offloading/jax_inference_offloading/models/mapping_util.py index 5030a455d..d1293f607 100644 --- a/jax-inference-offloading/jax_inference_offloading/models/mapping_util.py +++ b/jax-inference-offloading/jax_inference_offloading/models/mapping_util.py @@ -14,8 +14,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from copy import deepcopy - +import json import google.protobuf.text_format as text_format import numpy as np from vllm import LLM @@ -59,6 +58,18 @@ def _proto_to_slice(proto: mapping.TensorSlice): def apply_transform(tensor, transform: mapping.JaxParam.Transform): + """Apply a JAX -> vLLM transform to a tensor.""" + if transform.slice.dims: + tensor = tensor[_proto_to_slice(transform.slice)] + if transform.transpose: + tensor = tensor.transpose(transform.transpose) + if transform.reshape: + tensor = tensor.reshape(transform.reshape) + return tensor + + +def apply_vllm_transform(tensor, transform: mapping.VllmParam.Transform): + """Apply a vLLM -> JAX transform to a tensor (for checkpoint loading).""" if transform.slice.dims: tensor = tensor[_proto_to_slice(transform.slice)] if transform.transpose: @@ -73,6 +84,221 @@ def load_mapping_spec(filename: str) -> mapping.TpModelMappingSpecs: return text_format.Parse(file.read(), mapping.TpModelMappingSpecs()) +def _parse_slice_from_json(slice_list: list) -> mapping.TensorSlice: + """Parse a JSON slice specification into a TensorSlice proto. + + Args: + slice_list: A list of slice specifications. Each element can be: + - "..." for ellipsis + - An integer for index + - A list [start, stop] for slice (use null for None) + + Returns: + A TensorSlice protobuf message. 
+ """ + result = mapping.TensorSlice() + for dim in slice_list: + slice_dim = mapping.TensorSlice.Dim() + if dim == "...": + slice_dim.ellipsis.SetInParent() + elif isinstance(dim, int): + slice_dim.index.index = dim + elif isinstance(dim, list): + slice_dim.slice.SetInParent() + if len(dim) >= 1 and dim[0] is not None: + slice_dim.slice.start = dim[0] + if len(dim) >= 2 and dim[1] is not None: + slice_dim.slice.stop = dim[1] + else: + raise ValueError(f"Invalid slice specification: {dim}") + result.dims.append(slice_dim) + return result + + +def _parse_jax_transform_from_json(transform_dict: dict) -> mapping.JaxParam.Transform: + """Parse a JSON transform specification into a JaxParam.Transform proto. + + Args: + transform_dict: A dictionary with optional keys: + - "transpose": list of ints (axis permutation) + - "reshape": list of ints (new shape, -1 for inferred) + - "slice": list of slice specs + - "replication_axis": int + - "replication_count": int + + Returns: + A JaxParam.Transform protobuf message. + """ + result = mapping.JaxParam.Transform() + + if "slice" in transform_dict: + result.slice.CopyFrom(_parse_slice_from_json(transform_dict["slice"])) + + if "transpose" in transform_dict: + result.transpose.extend(transform_dict["transpose"]) + + if "reshape" in transform_dict: + result.reshape.extend(transform_dict["reshape"]) + + if "replication_axis" in transform_dict: + result.replication_axis = int(transform_dict["replication_axis"]) + + if "replication_count" in transform_dict: + result.replication_count = int(transform_dict["replication_count"]) + + return result + + +def _parse_vllm_transform_from_json(transform_dict: dict) -> mapping.VllmParam.Transform: + """Parse a JSON transform specification into a VllmParam.Transform proto. + + Used for vLLM -> JAX transforms when loading checkpoints. 
+ + Args: + transform_dict: A dictionary with optional keys: + - "transpose": list of ints (axis permutation) + - "reshape": list of ints (new shape, -1 for inferred) + - "slice": list of slice specs + + Returns: + A VllmParam.Transform protobuf message. + """ + result = mapping.VllmParam.Transform() + + if "slice" in transform_dict: + result.slice.CopyFrom(_parse_slice_from_json(transform_dict["slice"])) + + if "transpose" in transform_dict: + result.transpose.extend(transform_dict["transpose"]) + + if "reshape" in transform_dict: + result.reshape.extend(transform_dict["reshape"]) + + return result + + +def _parse_partition_spec_from_json(partition_spec_list: list) -> mapping.JaxParam.PartitionSpecs: + """Parse a JSON partition spec into a PartitionSpecs proto. + + Args: + partition_spec_list: A list of axis names or null for unsharded dimensions. + e.g., ["fsdp", null, "tp"] means shard dim 0 on "fsdp", dim 2 on "tp". + + Returns: + A JaxParam.PartitionSpecs protobuf message. + """ + result = mapping.JaxParam.PartitionSpecs() + for axis in partition_spec_list: + # Use empty string for null/None (unsharded dimensions) + result.axes.append(axis if axis is not None else "") + return result + + +def load_mapping_from_json(json_path: str) -> mapping.TpModelMappingSpecs: + """Load parameter mapping from a JSON configuration file. + + The JSON schema is: + { + "mesh_axes": ["fsdp", "tp"], + "num_layers": 32, + "mappings": [ + { + "jax_param": { + "name": "layers.{layer}.attn.q_proj.w", + "partition_spec": ["fsdp", null, "tp"], + "transform": { "transpose": [1, 2, 0], "reshape": [-1, 4096] } + }, + "vllm_param": { + "name": "model.layers.{layer}.self_attn.q_proj.weight", + "shape": [4096, 4096], + "transform": { "transpose": [1, 0] } + } + }, + ... + ] + } + + - mesh_axes: Axis names for JAX mesh creation (e.g., ["fsdp", "tp"]) + - JAX parameter names should match the NNX module structure (no prefix) + - vLLM parameter names typically have a "model." 
prefix (matching vLLM's model structure) + - Mappings with `{layer}` placeholder are expanded into `num_layers` copies + - Mappings without `{layer}` are kept as singletons + - jax_param.transform: Applied when sending JAX -> vLLM (optional) + - jax_param.partition_spec: Sharding spec for checkpoint loading (optional) + - vllm_param.transform: Applied when loading vLLM checkpoint -> JAX (optional) + + Args: + json_path: Path to the JSON configuration file. + + Returns: + A TpModelMappingSpecs protobuf message with all mappings expanded. + """ + with open(json_path, "r") as f: + config = json.load(f) + + num_layers = config.get("num_layers", 0) + mesh_axes = config.get("mesh_axes", []) + json_mappings = config.get("mappings", []) + + model_mapping = mapping.TpModelMappingSpecs() + + # Set mesh axes + model_mapping.mesh_axes.extend(mesh_axes) + + for json_mapping in json_mappings: + jax_param_spec = json_mapping["jax_param"] + vllm_param_spec = json_mapping["vllm_param"] + + jax_name = jax_param_spec["name"] + vllm_name = vllm_param_spec["name"] + vllm_shape = vllm_param_spec["shape"] + + # Check if this is a templated per-layer mapping + if "{layer}" in jax_name or "{layer}" in vllm_name: + # Expand for all layers + layer_indices = range(num_layers) + else: + # Singleton mapping - use None as sentinel + layer_indices = [None] + + for layer_idx in layer_indices: + param_mapping = mapping.ParamMapping() + + # Expand layer placeholder if present + if layer_idx is not None: + expanded_jax_name = jax_name.replace("{layer}", str(layer_idx)) + expanded_vllm_name = vllm_name.replace("{layer}", str(layer_idx)) + else: + expanded_jax_name = jax_name + expanded_vllm_name = vllm_name + + # Set JAX param + param_mapping.jax_param.name = expanded_jax_name + + # Parse and set JAX transform if present (JAX -> vLLM) + if "transform" in jax_param_spec: + transform = _parse_jax_transform_from_json(jax_param_spec["transform"]) + param_mapping.jax_param.transform.CopyFrom(transform) + + 
# Parse and set partition spec if present (for checkpoint loading) + if "partition_spec" in jax_param_spec: + partition_specs = _parse_partition_spec_from_json(jax_param_spec["partition_spec"]) + param_mapping.jax_param.partition_specs.CopyFrom(partition_specs) + + # Set vLLM param + param_mapping.vllm_param.name = expanded_vllm_name + param_mapping.vllm_param.shape.extend(vllm_shape) + + # Parse and set vLLM transform if present (vLLM -> JAX for checkpoint loading) + if "transform" in vllm_param_spec: + vllm_transform = _parse_vllm_transform_from_json(vllm_param_spec["transform"]) + param_mapping.vllm_param.transform.CopyFrom(vllm_transform) + + model_mapping.mappings.append(param_mapping) + + return model_mapping + + def add_sharding_specs(model_mapping: mapping.TpModelMappingSpecs, llm: LLM, jax_tp_size: int): per_rank_sharding_specs = llm.collective_rpc("get_tp_sharding_specs") vllm_tp_size = len(per_rank_sharding_specs) @@ -93,7 +319,9 @@ def add_sharding_specs(model_mapping: mapping.TpModelMappingSpecs, llm: LLM, jax per_tensor_sharding_specs[name] = specs # convert to VllmTPShardingSpecs - augmented_mapping = deepcopy(model_mapping) + # Use protobuf's CopyFrom instead of Python's deepcopy for proper message copying + augmented_mapping = mapping.TpModelMappingSpecs() + augmented_mapping.CopyFrom(model_mapping) for param in augmented_mapping.mappings: if spec := per_tensor_sharding_specs.get(param.vllm_param.name): dim, parallelism = spec["dim"], spec["parallelism"] diff --git a/jax-inference-offloading/jax_inference_offloading/session.py b/jax-inference-offloading/jax_inference_offloading/session.py new file mode 100644 index 000000000..cbf5143e1 --- /dev/null +++ b/jax-inference-offloading/jax_inference_offloading/session.py @@ -0,0 +1,214 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
class OffloadingSession:
    """Manages gRPC connection and handshake for JAX-vLLM offloading.

    This class handles the initial setup for offloading inference from JAX to vLLM,
    including gRPC channel management and the handshake protocol. It does not use
    TrainerClient, instead managing gRPC stubs directly.

    The session is a context manager; leaving the ``with`` block calls
    :meth:`shutdown` with its default arguments.

    Args:
        gateway_url: URL of the gateway server (e.g., "localhost:50051").
        mesh: JAX device mesh for parallelism info.
        model_path: Path to model checkpoint.
        param_mapping_path: Path to custom JSON parameter mapping file.
        model_name: Optional HuggingFace model name. If not provided and
            param_mapping_path is set, uses param_mapping_path for resolution.
        timeout: Timeout in seconds for gRPC channel readiness.

    Raises:
        ValueError: If neither model_name nor param_mapping_path is provided.

    Example:
        >>> session = OffloadingSession(
        ...     gateway_url="localhost:50051",
        ...     mesh=jax.make_mesh((8,), ("tp",)),
        ...     model_path="/path/to/checkpoint",
        ...     param_mapping_path="/path/to/mapping.json",
        ... )
        >>> # Use session with VLLMTransferEngine and VLLMRolloutEngine
    """

    def __init__(
        self,
        gateway_url: str,
        mesh: jax.sharding.Mesh,
        *,
        model_path: Optional[str] = None,
        param_mapping_path: Optional[str] = None,
        model_name: Optional[str] = None,
        timeout: int = 60,
    ):
        # Validation: need at least param_mapping_path or model_name
        if param_mapping_path is None and model_name is None:
            raise ValueError(
                "Either param_mapping_path or model_name must be provided"
            )

        # Store configuration
        self.gateway_url = gateway_url
        self.mesh = mesh
        self.model_name = model_name
        self.model_path = model_path
        self.param_mapping_path = param_mapping_path

        # Initialise the bookkeeping that shutdown() reads *before* acquiring
        # anything, so a failure part-way through can be cleaned up safely.
        self._shutdown = False
        self._executor = None

        # Set up gRPC channel and stubs (no TrainerClient)
        self._channel = grpc.insecure_channel(gateway_url)
        try:
            grpc.channel_ready_future(self._channel).result(timeout=timeout)
            self._controller_stub = ctrl_grpc.CouplingControllerStub(self._channel)
            self._broker_stub = broker_grpc.MessageBrokerStub(self._channel)
            # ThreadPoolExecutor workers are NOT daemon threads; they are joined
            # at interpreter exit, so shutdown() must run to release them.
            self._executor = ThreadPoolExecutor(
                thread_name_prefix="offloading_session"
            )

            # Perform handshake
            self._handshake_result = self._do_handshake()

            # Parse handshake results
            self.mapping_specs = proto_to_dataclass(
                self._handshake_result.mapping_specs, 'mapping_specs'
            )
            self.jax_parallelism = self._handshake_result.jax_parallelism
            self.vllm_parallelism = self._handshake_result.vllm_parallelism
        except BaseException:
            # Do not leak the channel/executor when the readiness wait or the
            # handshake fails. Only local resources are released here: the
            # gateway is not told to shut down because of a client-side error.
            self.shutdown(shutdown_gateway=False)
            raise

        logger.info(
            "OffloadingSession initialized: JAX TP=%s, vLLM TP=%s",
            self.jax_parallelism.tp,
            self.vllm_parallelism.tp,
        )

    @staticmethod
    def _require_power_of_2(value: int, side: str) -> None:
        """Raise AssertionError unless *value* is a positive power of two.

        AssertionError is raised explicitly (rather than via ``assert``) so the
        check keeps the exception type callers may already see while still
        running under ``python -O``.
        """
        if value <= 0 or (value & (value - 1)) != 0:
            raise AssertionError(f"{side} TP size must be a power of 2.")

    @on_spmd_leader(
        serializer=lambda m: m.SerializeToString(),
        deserializer=lambda b: ctrl.HandshakeResponse.FromString(b),
    )
    def _do_handshake(self) -> ctrl.HandshakeResponse:
        """Perform handshake with vLLM. Executed on leader, broadcast to all ranks.

        Returns:
            The first HandshakeResponse delivered on the per-call response topic.

        Raises:
            AssertionError: If either reported TP size is not a power of 2.
            RuntimeError: If the subscription stream ends without a response.
        """
        # Random topic suffix keeps this handshake's reply separate from others.
        response_topic_id = f"handshake/results/{secrets.token_hex(16)}"

        # Subscribe BEFORE sending the request so the reply cannot be missed.
        stream = self._broker_stub.SubscriptionStream(
            SubscribeRequest(topics=[create_topic(response_topic_id)])
        )
        try:
            # Build handshake request
            request = ctrl.HandshakeRequest(
                response_topic=response_topic_id,
                model_name=self.model_name or "",
                jax_parallelism=ctrl.JaxParallelism(tp=self.mesh.devices.size),
            )

            # Include param_mapping_path if provided
            if self.param_mapping_path:
                request.param_mapping_path = self.param_mapping_path

            # Send handshake request
            self._controller_stub.AsyncHandshake(request)

            # Wait for response; only the first delivery is consumed.
            for delivery in stream:
                result = ctrl.HandshakeResponse()
                delivery.message.payload.Unpack(result)

                # Validate TP sizes
                self._require_power_of_2(result.vllm_parallelism.tp, "vLLM")
                self._require_power_of_2(result.jax_parallelism.tp, "JAX")

                return result
        finally:
            # Cancel the server stream deterministically instead of relying on
            # garbage collection of the call object.
            cancel = getattr(stream, "cancel", None)
            if callable(cancel):
                cancel()

        # Reaching here means the stream was exhausted with no delivery;
        # fail loudly instead of returning None to the caller.
        raise RuntimeError(
            "Handshake subscription stream closed before a response was received"
        )

    def get_nccl_id(self) -> Tuple[int, ...]:
        """Get NCCL unique ID from gateway."""
        return tuple(self._controller_stub.GetNcclId(ctrl.GetNcclIdRequest()).ids)

    @property
    def controller_stub(self):
        """Access the gRPC controller stub."""
        return self._controller_stub

    @property
    def broker_stub(self):
        """Access the gRPC message broker stub."""
        return self._broker_stub

    @property
    def executor(self):
        """Access the thread pool executor."""
        return self._executor

    def shutdown(self, shutdown_gateway: bool = True, grace_period: int = 1):
        """Close the gRPC channel and executor.

        Idempotent: calls after the first return immediately. Every step is
        best-effort, so a failure in one does not prevent the others.

        Args:
            shutdown_gateway: If True, send shutdown signal to gateway server.
            grace_period: Grace period in seconds for gateway shutdown.
        """
        if self._shutdown:
            return
        self._shutdown = True

        # Send shutdown signal to gateway (this also shuts down vLLM rollout)
        if shutdown_gateway:
            try:
                self._controller_stub.Shutdown(
                    ctrl.ShutdownRequest(grace_period=grace_period)
                )
            except Exception:
                # Still best-effort, but leave a trace for debugging.
                logger.debug("Gateway shutdown request failed", exc_info=True)

        # Close gRPC channel - this will cause streams to fail
        try:
            self._channel.close()
        except Exception:
            logger.debug("Closing gRPC channel failed", exc_info=True)

        # Shutdown executor and wait for threads to terminate. The executor is
        # None when __init__ failed before creating it.
        if self._executor is not None:
            try:
                self._executor.shutdown(wait=True, cancel_futures=True)
            except TypeError:
                # Python < 3.9 doesn't support cancel_futures
                self._executor.shutdown(wait=True)

    def __enter__(self):
        """Context manager entry."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit - ensures shutdown is called."""
        self.shutdown()
        return False
@runtime_checkable
class WeightTransferGateway(Protocol):
    """Structural interface for anything that can kick off a weight transfer.

    Any object exposing a ``start_weight_transfer`` method satisfies this
    protocol; no inheritance is required. Because the class is
    ``runtime_checkable``, ``isinstance`` only verifies that the method
    exists, not that its signature matches.
    """

    def start_weight_transfer(self, mode: str) -> None:
        """Tell the rollout side to begin receiving weights.

        Args:
            mode: Transfer mode string handed to the implementation.
        """
@runtime_checkable
class WeightTransferGateway(Protocol):
    """Duck-typed contract for objects able to trigger a weight transfer.

    Implementers only need a ``start_weight_transfer`` method. The
    ``runtime_checkable`` decorator permits ``isinstance`` checks, which test
    for the method's presence only and do not validate its signature.
    """

    def start_weight_transfer(self, mode: str) -> None:
        """Notify the rollout side that it should start receiving weights.

        Args:
            mode: Transfer mode string handed to the implementation.
        """
@runtime_checkable
class WeightTransferGateway(Protocol):
    """Minimal interface required of a gateway that starts weight transfers.

    Satisfied structurally by any object with a ``start_weight_transfer``
    method. ``isinstance`` checks are allowed via ``runtime_checkable`` and
    look only at whether the method is present.
    """

    def start_weight_transfer(self, mode: str) -> None:
        """Signal the rollout side to start receiving weights.

        Args:
            mode: Transfer mode string handed to the implementation.
        """
- gathered_tensors = [] - + # Pre-allocate all shard buffers and collect recv operations + all_shards = [] # List of shard lists, one per parameter + with cuda.Device(self._comm.device_id()): - # For each parameter, we need to receive shards from all peers + stream = cuda.get_current_stream().ptr + + # Single NCCL group for ALL receives to match JAX's single send group + nccl.groupStart() for shape, dtype, dim, parallelism in param_specs: shards = [] shard_shape = np.array(shape, dtype=np.int32) shard_shape[dim] //= parallelism shard_shape = shard_shape.tolist() - # Start grouped NCCL operations for this parameter across all peers - nccl.groupStart() for peer in range(1, self._comm.size()): # rank 0 is the current GPU shard = torch.empty(shard_shape, dtype=getattr(torch, dtype), device=torch.cuda.current_device()) self._comm.recv( @@ -225,15 +227,16 @@ def gather_grouped( count=shard.numel(), datatype=nccl_type(dtype), peer=peer, - stream=cuda.get_current_stream().ptr, + stream=stream, ) shards.append(shard) - nccl.groupEnd() - - gathered_tensors.append(torch.cat(shards, dim=dim)) + all_shards.append((shards, dim)) + nccl.groupEnd() cudart.deviceSynchronize() + # Concatenate shards for each parameter + gathered_tensors = [torch.cat(shards, dim=dim) for shards, dim in all_shards] return gathered_tensors def scatter(self, buffers: List[Any]) -> None: diff --git a/jax-inference-offloading/jax_inference_offloading/tunix/rollout.py b/jax-inference-offloading/jax_inference_offloading/tunix/rollout.py deleted file mode 100644 index af8dc9adf..000000000 --- a/jax-inference-offloading/jax_inference_offloading/tunix/rollout.py +++ /dev/null @@ -1,157 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Rollout worker with offloading to vLLM.""" -from typing import Any, Optional, Tuple - -import jax -import jax.numpy as jnp -import jaxtyping - -import jax_inference_offloading.api.controller_pb2 as ctrl -from jax_inference_offloading.jax import OffloadingBridge -from jax_inference_offloading.timer import Timer -from tunix.rl.rollout.base_rollout import BaseRollout, RolloutConfig, RolloutOutput - - -class VllmGPURollout(BaseRollout): - - def __init__( - self, - gateway_url, - model_name, - *, - rollout_actor, # AKA rollout model - tokenizer, - mesh, - rollout_config, - extra_stop_tokens: list[str] | None = None, - transfer_mode: str = 'fused', - timer: Any | None = None, - ): - self._timer = timer or Timer() - self._tokenizer = tokenizer - self._bridge = OffloadingBridge( - gateway_url=gateway_url, - model_name=model_name, - mesh=mesh, - transfer_mode=transfer_mode, - timer=self._timer, - ) - self._extra_stop_token_ids = [] - for t in extra_stop_tokens or []: - i = self._tokenizer.encode(t) - assert len(i) == 1, f"Stop token {t} must be a single token, got {i}" - self._extra_stop_token_ids.extend(i) - - def generate( - self, - prompts: list[str], - rollout_config: RolloutConfig, - ): - """Generates samples from the model.""" - with self._timer.section("rollout.generate"): - remote_rollout_config = ctrl.RolloutConfig( - top_p=rollout_config.top_p, - top_k=rollout_config.top_k, - 
temperature=rollout_config.temperature, - max_tokens=rollout_config.max_tokens_to_generate, - seed=rollout_config.seed, - ) - if rollout_config.eos_tokens is not None: - remote_rollout_config.stop_token_ids.extend(rollout_config.eos_tokens) - else: - remote_rollout_config.stop_token_ids.extend([self._tokenizer.eos_id()]) - remote_rollout_config.stop_token_ids.extend(self._extra_stop_token_ids) - - with self._timer.section("inference"): - response = self._bridge.gateway.inference([str(p) for p in prompts], config=remote_rollout_config) - - with self._timer.section("process_outputs"): - generated_text = [] - input_tokens = [] - output_tokens = [] - - def pad_to_left(original, length, pad_value): - assert len(original) <= length - return [pad_value] * (length - len(original)) + original - - def pad_to_right(original, length, pad_value): - assert len(original) <= length - return original + [pad_value] * (length - len(original)) - - for i, output in enumerate(response.outputs): - if i < 1: - print(f"# Rollout {i} of {len(prompts)}") - print(f"## Prompt:\n{prompts[i]}") - print(f"## Response:\n{output.generated_text}") - print("-" * 80) - generated_text.append(output.generated_text) - input_tokens.append( - pad_to_left(list(output.tokenized_prompt.ids), rollout_config.max_prompt_length, self._tokenizer.pad_id()) - ) - output_tokens.append( - pad_to_right(list(output.generated_tokens.ids), rollout_config.max_tokens_to_generate, self._tokenizer.pad_id()) - ) - - return RolloutOutput( - text=generated_text, - logits=[], # not needed for GRPO - tokens=jnp.array(output_tokens, dtype=jnp.int32), - left_padded_prompt_tokens=jnp.array(input_tokens, dtype=jnp.int32), - logprobs=None, # needed for GRPO, GRPOLearner will recalc - ) - - def get_per_token_logps( - self, - prompt_tokens: jax.Array, - completion_tokens: jax.Array, - completion_mask: jax.Array | None = None, - ) -> jax.Array: - raise NotImplementedError() - - def update_params( - self, - params: jaxtyping.PyTree, - 
filter_types: Optional[Tuple[Any, ...]] = None, - ) -> None: - """Updates the rollout model parameters.""" - with self._timer.section("rollout.update_params"): - self._bridge.transfer(params) - - def pad_id(self) -> int: - return self._tokenizer.pad_id() - - def eos_id(self) -> int: - return self._tokenizer.eos_id() - - def model(self): - return None - - def shutdown(self) -> None: - """Gracefully shutdown the remote gateway if available.""" - try: - self._bridge.gateway.shutdown() - except Exception: - # Ignore shutdown errors; process teardown or remote unavailability is expected. - pass - - def __del__(self): - try: - self.shutdown() - except Exception: - # Suppress destructor-time errors during interpreter shutdown. - pass diff --git a/jax-inference-offloading/jax_inference_offloading/vllm/extension.py b/jax-inference-offloading/jax_inference_offloading/vllm/extension.py index 61e4d1394..7051bfe58 100644 --- a/jax-inference-offloading/jax_inference_offloading/vllm/extension.py +++ b/jax-inference-offloading/jax_inference_offloading/vllm/extension.py @@ -60,7 +60,7 @@ def set_sharding(self): # Prevent unquantized linear modules from using V2 weight loader if "UnquantizedLinearMethod" in WEIGHT_LOADER_V2_SUPPORTED: WEIGHT_LOADER_V2_SUPPORTED.remove("UnquantizedLinearMethod") - logger.warning("Removed UnquantizedLinearMethod from WEIGHT_LOADER_V2_SUPPORTED.") + logger.debug("Removed UnquantizedLinearMethod from WEIGHT_LOADER_V2_SUPPORTED.") for name, module in self.model_runner.model.named_modules(): if type(module) in [ @@ -69,7 +69,6 @@ def set_sharding(self): MergedColumnParallelLinear, QKVParallelLinear, ]: - logger.debug(f"Setting sharding for module: {name} of type {type(module)}") # force to use the V1 weight_loader module.weight.weight_loader = module.weight_loader # instruct V1 loader to treat the incoming weight as pre-sharded @@ -214,7 +213,7 @@ def update_weights(self, mapping_specs: TpModelMappingSpecs): self.commit_staged_weights(reset=True) 
def commit_staged_weights(self, reset=True):
    """Load every currently staged weight into the model.

    Args:
        reset: When true, call ``reset_stage()`` after the weights have been
            handed to the model.

    Returns:
        Whatever the model's ``load_weights`` call returns.
    """
    pending = self._staged_weights
    logger.debug(f"Committing {len(pending)} staged weights")

    outcome = self.model_runner.model.load_weights(weights=pending)
    if not reset:
        return outcome

    # The stage is cleared only after the load call has returned.
    self.reset_stage()
    return outcome
a/jax-inference-offloading/setup.py b/jax-inference-offloading/setup.py index 6bd052756..030934ac7 100644 --- a/jax-inference-offloading/setup.py +++ b/jax-inference-offloading/setup.py @@ -66,7 +66,7 @@ def run(self): 'jax==0.8.1', 'jaxtyping', 'kagglehub', - 'vllm==0.11.2', + 'vllm==0.14.0', ], extras_require={ 'test': ['pytest>=7.0'],