From b3c368cda4973803e8b59e6bfd9a77f77ab8ca32 Mon Sep 17 00:00:00 2001 From: Masahiro Ogawa Date: Mon, 23 Mar 2026 16:22:53 +0900 Subject: [PATCH 1/3] feat: add speed benchmark, interactive example, and RTX 4080 test - Add benchmarks/benchmark_speed_mamba123.py: compares Mamba1 vs Mamba2 vs Mamba3 forward/backward speed on identical workload with visual chart - Add examples/predict_next_token.py: interactive next-token prediction with example outputs on launch showing what base LM prediction looks like - Add tests/test_rtx4080.py: RTX 4080 specific test for pretrained inference and Mamba3 SISO module - Add configs/rtx4080.json: model presets sized for 12GB VRAM - Update benchmarks/benchmark_README.md with both benchmarks - Update README.md with new sections Co-Authored-By: Claude Opus 4.6 (1M context) --- README.md | 30 ++++ benchmarks/benchmark_README.md | 39 +++++ benchmarks/benchmark_speed_mamba123.py | 188 +++++++++++++++++++++++++ configs/rtx4080.json | 53 +++++++ examples/predict_next_token.py | 116 +++++++++++++++ tests/test_rtx4080.py | 156 ++++++++++++++++++++ 6 files changed, 582 insertions(+) create mode 100644 benchmarks/benchmark_README.md create mode 100644 benchmarks/benchmark_speed_mamba123.py create mode 100644 configs/rtx4080.json create mode 100644 examples/predict_next_token.py create mode 100644 tests/test_rtx4080.py diff --git a/README.md b/README.md index c6435be3d..6e660ae93 100755 --- a/README.md +++ b/README.md @@ -134,6 +134,21 @@ Source: [models/mixer_seq_simple.py](mamba_ssm/models/mixer_seq_simple.py). This is an example of how to integrate Mamba into an end-to-end neural network. This example is used in the generation scripts below. +### Interactive Next-Token Prediction + +Try Mamba interactively with [examples/predict_next_token.py](examples/predict_next_token.py). It loads a pretrained model, shows example outputs on launch, then lets you type prompts and see continuations. + +``` sh +# Mamba-1 +python examples/predict_next_token.py --model "state-spaces/mamba-130m" + +# Mamba-2 +python examples/predict_next_token.py --model "state-spaces/mamba2-130m" +``` + +Available models: `mamba-130m`, `mamba-370m`, `mamba-790m`, `mamba-1.4b`, `mamba-2.8b`, `mamba2-130m`, `mamba2-370m`, `mamba2-780m`, `mamba2-1.3b`, `mamba2-2.7b`. + +Note: These are **base language models**, not chatbots. They predict the most likely continuation of your text based on training data (The Pile). Longer, specific prompts work best — short inputs like "hello" may produce random document fragments. Mamba-3 is not supported here as no pretrained language model is available yet. ## Pretrained Models @@ -222,6 +237,21 @@ python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-space ``` +## Forward/Backward Speed Comparison + +The script [benchmarks/benchmark_speed_mamba123.py](benchmarks/benchmark_speed_mamba123.py) +compares forward and backward pass speed of Mamba1 vs Mamba2 vs Mamba3 on the same workload (same d_model, sequence length, batch size, dtype). + +While the inference benchmark above measures text generation latency, this measures the raw computation speed of each architecture's forward and backward pass. + +``` sh +python benchmarks/benchmark_speed_mamba123.py +``` + +The script outputs per-module timing (ms) and peak VRAM (GB), and saves a visual comparison chart. + +**Note:** Mamba3 MIMO mode requires >100KB shared memory per SM (available on H100/sm_90+). On consumer GPUs like RTX 4080 (sm_89), use SISO mode instead. + ## Troubleshooting ### Precision diff --git a/benchmarks/benchmark_README.md b/benchmarks/benchmark_README.md new file mode 100644 index 000000000..804e8e6eb --- /dev/null +++ b/benchmarks/benchmark_README.md @@ -0,0 +1,39 @@ +# Benchmarks + +## Overview + +| Script | Purpose | +|---|---| +| `benchmark_text_generation_latency.py` | How fast can the model write text? | +| `benchmark_speed_mamba123.py` | Which Mamba generation computes fastest? | + +## benchmark_text_generation_latency.py + +Loads a pretrained model (Mamba or Transformer), feeds a prompt, generates tokens, and measures the total time. + +- **Measures:** End-to-end latency — prompt processing + token-by-token decoding (ms) +- **Models:** One pretrained model at a time (e.g. `state-spaces/mamba-2.8b`) +- **Includes backward?** No (inference only) +- **Use case:** Evaluating deployment/serving performance + +```sh +python benchmarks/benchmark_text_generation_latency.py \ + --model-name "state-spaces/mamba-2.8b" \ + --prompt "My cat wrote all this CUDA code for a new language model and" \ + --topp 0.9 --temperature 0.7 +``` + +## benchmark_speed_mamba123.py + +Creates Mamba1, Mamba2, and Mamba3 modules with identical settings (same d_model, sequence length, batch size, dtype), runs the same tensor through each, and times forward + backward passes. + +- **Measures:** Raw computation speed — forward pass (ms) + backward pass (ms) + VRAM (GB) +- **Models:** Mamba1, Mamba2, Mamba3 side by side (random weights, same config) +- **Includes backward?** Yes +- **Use case:** Comparing architecture efficiency for research/training + +```sh +python benchmarks/benchmark_speed_mamba123.py +``` + +**Note:** Mamba3 MIMO mode requires >100KB shared memory per SM (available on H100/sm_90+). On consumer GPUs like RTX 4080 (sm_89), use SISO mode instead. diff --git a/benchmarks/benchmark_speed_mamba123.py b/benchmarks/benchmark_speed_mamba123.py new file mode 100644 index 000000000..7f8fb1431 --- /dev/null +++ b/benchmarks/benchmark_speed_mamba123.py @@ -0,0 +1,188 @@ +#!/usr/bin/env python3 +"""Benchmark forward/backward speed: Mamba1 vs Mamba2 vs Mamba3. + +Compares all three SSM generations on the same workload (same d_model, +sequence length, batch size) measuring forward/backward speed and VRAM. + +Usage: + python benchmarks/benchmark_speed_mamba123.py +""" + +import gc +import time +from pathlib import Path + +import matplotlib.pyplot as plt +import matplotlib +import torch + +matplotlib.rcParams.update({"font.size": 11, "figure.dpi": 150}) + +GB = 1024**3 +OUTPUT_PATH = Path(__file__).parent / "speed_mamba123.png" + +# Common benchmark parameters +D_MODEL = 768 +BATCH = 2 +SEQ_LEN = 2048 +DTYPE = torch.bfloat16 +N_WARMUP = 2 +N_RUNS = 5 + + +def bench_module(name, create_fn): + """Benchmark a single Mamba module, returning timing and VRAM stats.""" + print(f" Benchmarking {name}...") + model = create_fn().to("cuda") + params = sum(p.numel() for p in model.parameters()) / 1e6 + + x = torch.randn(BATCH, SEQ_LEN, D_MODEL, dtype=DTYPE, device="cuda") + + # Warm-up + for _ in range(N_WARMUP): + with torch.no_grad(): + model(x) + torch.cuda.synchronize() + + # Forward + torch.cuda.reset_peak_memory_stats() + fwd_times = [] + for _ in range(N_RUNS): + x_fwd = torch.randn(BATCH, SEQ_LEN, D_MODEL, dtype=DTYPE, device="cuda") + torch.cuda.synchronize() + t0 = time.time() + with torch.no_grad(): + model(x_fwd) + torch.cuda.synchronize() + fwd_times.append((time.time() - t0) * 1000) + fwd_ms = sum(fwd_times) / len(fwd_times) + fwd_vram = torch.cuda.max_memory_allocated() / GB + + # Backward + torch.cuda.reset_peak_memory_stats() + bwd_times = [] + for _ in range(N_RUNS): + x_bwd = torch.randn(BATCH, SEQ_LEN, D_MODEL, dtype=DTYPE, device="cuda", requires_grad=True) + y = model(x_bwd) + loss = y.sum() + torch.cuda.synchronize() + t0 = time.time() + loss.backward() + torch.cuda.synchronize() + bwd_times.append((time.time() - t0) * 1000) + bwd_ms = sum(bwd_times) / len(bwd_times) + bwd_vram = torch.cuda.max_memory_allocated() / GB + + del model, x + gc.collect() + torch.cuda.empty_cache() + + return { + "params_m": params, + "fwd_ms": fwd_ms, + "bwd_ms": bwd_ms, + "fwd_vram_gb": fwd_vram, + "bwd_vram_gb": bwd_vram, + } + + +def collect_results(): + from mamba_ssm.modules.mamba_simple import Mamba + from mamba_ssm.modules.mamba2 import Mamba2 + from mamba_ssm.modules.mamba3 import Mamba3 + + models = { + "Mamba1": lambda: Mamba(d_model=D_MODEL, d_state=16, dtype=DTYPE), + "Mamba2": lambda: Mamba2(d_model=D_MODEL, d_state=128, headdim=64, dtype=DTYPE), + "Mamba3\n(SISO)": lambda: Mamba3( + d_model=D_MODEL, d_state=128, headdim=64, + is_mimo=False, chunk_size=64, is_outproj_norm=False, dtype=DTYPE, + ), + } + + results = {} + for name, create_fn in models.items(): + results[name] = bench_module(name, create_fn) + return results + + +def plot_results(results): + names = list(results.keys()) + colors = ["#4C78A8", "#F58518", "#E45756"] + + fig, axes = plt.subplots(2, 2, figsize=(12, 9)) + fig.suptitle( + f"Mamba1 vs Mamba2 vs Mamba3 — {torch.cuda.get_device_name(0)}\n" + f"d_model={D_MODEL}, seq_len={SEQ_LEN}, batch={BATCH}, dtype=bf16 | " + f"PyTorch {torch.__version__} | CUDA {torch.version.cuda}", + fontsize=12, fontweight="bold", + ) + + def add_bar_labels(ax, bars, values, fmt="{:.1f}"): + for bar, v in zip(bars, values): + ax.text(bar.get_x() + bar.get_width() / 2, + bar.get_height() + max(values) * 0.03, + fmt.format(v), ha="center", va="bottom", fontsize=10) + + # 1) Forward time + ax = axes[0, 0] + vals = [results[n]["fwd_ms"] for n in names] + bars = ax.bar(names, vals, color=colors) + ax.set_ylabel("ms") + ax.set_title("Forward Pass (lower is better)") + add_bar_labels(ax, bars, vals, "{:.1f}ms") + + # 2) Backward time + ax = axes[0, 1] + vals = [results[n]["bwd_ms"] for n in names] + bars = ax.bar(names, vals, color=colors) + ax.set_ylabel("ms") + ax.set_title("Backward Pass (lower is better)") + add_bar_labels(ax, bars, vals, "{:.1f}ms") + + # 3) VRAM (forward vs fwd+bwd side by side) + ax = axes[1, 0] + x_pos = range(len(names)) + w = 0.35 + fwd_vram = [results[n]["fwd_vram_gb"] for n in names] + bwd_vram = [results[n]["bwd_vram_gb"] for n in names] + ax.bar([p - w / 2 for p in x_pos], fwd_vram, w, label="Forward", color="#4C78A8") + ax.bar([p + w / 2 for p in x_pos], bwd_vram, w, label="Fwd+Backward", color="#E45756") + ax.axhline(y=11.6, color="red", linestyle="--", alpha=0.4, label="VRAM limit (11.6 GB)") + ax.set_ylabel("GB") + ax.set_title("Peak VRAM Usage") + ax.set_xticks(list(x_pos)) + ax.set_xticklabels(names) + ax.legend(fontsize=9) + + # 4) Parameters + ax = axes[1, 1] + vals = [results[n]["params_m"] for n in names] + bars = ax.bar(names, vals, color=colors) + ax.set_ylabel("Parameters (M)") + ax.set_title("Module Parameters") + add_bar_labels(ax, bars, vals, "{:.1f}M") + + plt.tight_layout() + plt.savefig(OUTPUT_PATH, bbox_inches="tight") + print(f"\nSaved: {OUTPUT_PATH}") + + +def main(): + print(f"GPU: {torch.cuda.get_device_name(0)}") + print(f"Benchmark: d_model={D_MODEL}, seq_len={SEQ_LEN}, batch={BATCH}, bf16") + print(f"Averaging over {N_RUNS} runs after {N_WARMUP} warmup\n") + + results = collect_results() + + print("\n--- Results ---") + for name, r in results.items(): + print(f" {name:12s} params={r['params_m']:.1f}M " + f"fwd={r['fwd_ms']:.1f}ms bwd={r['bwd_ms']:.1f}ms " + f"vram_fwd={r['fwd_vram_gb']:.2f}GB vram_bwd={r['bwd_vram_gb']:.2f}GB") + + plot_results(results) + + +if __name__ == "__main__": + main() diff --git a/configs/rtx4080.json b/configs/rtx4080.json new file mode 100644 index 000000000..ef77ff856 --- /dev/null +++ b/configs/rtx4080.json @@ -0,0 +1,53 @@ +{ + "gpu": { + "name": "RTX 4080", + "compute_capability": "8.9", + "vram_gb": 12, + "architecture": "Ada Lovelace" + }, + "models": { + "mamba1_130m": { + "pretrained": "state-spaces/mamba-130m", + "d_model": 768, + "n_layer": 24, + "dtype": "bfloat16", + "note": "Smallest model, fast iteration" + }, + "mamba1_370m": { + "pretrained": "state-spaces/mamba-370m", + "d_model": 1024, + "n_layer": 48, + "dtype": "bfloat16", + "note": "Good balance of quality and speed" + }, + "mamba1_1.4b": { + "pretrained": "state-spaces/mamba-1.4b", + "d_model": 2048, + "n_layer": 48, + "dtype": "bfloat16", + "note": "Larger model, fits in 12GB VRAM in bf16" + }, + "mamba1_2.8b": { + "pretrained": "state-spaces/mamba-2.8b", + "d_model": 2560, + "n_layer": 64, + "dtype": "bfloat16", + "note": "Largest mamba-1, tight fit in 12GB VRAM" + }, + "mamba3_test": { + "d_model": 768, + "d_state": 128, + "n_layer": 12, + "headdim": 64, + "is_mimo": false, + "chunk_size": 64, + "dtype": "bfloat16", + "note": "Mamba3 SISO mode for RTX 4080 (MIMO backward needs >100KB shared mem, only available on H100+)" + } + }, + "defaults": { + "dtype": "bfloat16", + "max_batch_size": 1, + "max_seq_len": 2048 + } +} diff --git a/examples/predict_next_token.py b/examples/predict_next_token.py new file mode 100644 index 000000000..64d19f2b0 --- /dev/null +++ b/examples/predict_next_token.py @@ -0,0 +1,116 @@ +#!/usr/bin/env python3 +"""Interactive next-token prediction with Mamba. + +Mamba is a BASE language model (not a chatbot). It predicts the most likely +continuation of your text, as if completing a document from its training data +(The Pile: web text, code, Wikipedia, academic papers, etc.). + +Usage: + python examples/predict_next_token.py [--model MODEL] [--genlen N] [--skip-examples] + +Models: state-spaces/mamba-130m (default), state-spaces/mamba-370m, + state-spaces/mamba-1.4b, state-spaces/mamba-2.8b, + state-spaces/mamba2-130m, state-spaces/mamba2-370m, + state-spaces/mamba2-1.3b, state-spaces/mamba2-2.7b +""" + +import argparse + +import torch +from transformers import AutoTokenizer + +from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel + +EXAMPLES = [ + ("The theory of relativity states that", + "Science/factual continuation"), + ("Once upon a time, in a land far away,", + "Story continuation"), + ("def fibonacci(n):", + "Code completion"), + ("The capital of Japan is", + "Factual knowledge"), +] + + +def run_examples(model, tokenizer, genlen, temperature, top_k, top_p): + """Run example prompts to show what the model does.""" + print("=" * 60) + print("EXAMPLE OUTPUTS (showing what next-token prediction does)") + print("=" * 60) + + for prompt, description in EXAMPLES: + input_ids = torch.tensor([tokenizer.encode(prompt)], device="cuda") + with torch.no_grad(): + out = model.generate( + input_ids=input_ids, + max_length=input_ids.shape[1] + min(genlen, 50), + temperature=temperature, + top_k=top_k, + top_p=top_p, + ) + generated = tokenizer.decode(out[0].cpu().tolist()) + print(f"\n[{description}]") + print(f" Prompt: {prompt}") + print(f" Mamba -> {generated}") + + print("\n" + "=" * 60) + print("NOTE: This is NOT a chatbot. It predicts the next tokens") + print("based on training data. Longer, specific prompts work best.") + print("Short inputs like 'hello' may produce random document fragments.") + print("=" * 60) + + +def main(): + parser = argparse.ArgumentParser(description="Interactive Mamba next-token prediction") + parser.add_argument("--model", default="state-spaces/mamba-130m", help="Pretrained model name") + parser.add_argument("--genlen", type=int, default=100, help="Max tokens to generate") + parser.add_argument("--temperature", type=float, default=0.7) + parser.add_argument("--top-k", type=int, default=50) + parser.add_argument("--top-p", type=float, default=0.9) + parser.add_argument("--dtype", default="bfloat16", choices=["float16", "bfloat16"]) + parser.add_argument("--skip-examples", action="store_true", help="Skip example outputs on launch") + args = parser.parse_args() + + dtype = getattr(torch, args.dtype) + + print(f"Loading {args.model}...") + tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b") + model = MambaLMHeadModel.from_pretrained(args.model, device="cuda", dtype=dtype) + model.eval() + params = sum(p.numel() for p in model.parameters()) / 1e6 + print(f"Ready! ({params:.0f}M params, {args.dtype})\n") + + if not args.skip_examples: + run_examples(model, tokenizer, args.genlen, args.temperature, args.top_k, args.top_p) + + print(f"\nSettings: genlen={args.genlen}, temperature={args.temperature}, top_k={args.top_k}, top_p={args.top_p}") + print("Enter a prompt and Mamba will predict the continuation.") + print("Type 'quit' to exit.\n") + + while True: + try: + prompt = input(">>> ") + except (EOFError, KeyboardInterrupt): + print() + break + if prompt.strip().lower() in ("quit", "exit", "q"): + break + if not prompt.strip(): + continue + + input_ids = torch.tensor([tokenizer.encode(prompt)], device="cuda") + with torch.no_grad(): + out = model.generate( + input_ids=input_ids, + max_length=input_ids.shape[1] + args.genlen, + temperature=args.temperature, + top_k=args.top_k, + top_p=args.top_p, + ) + print(tokenizer.decode(out[0].cpu().tolist())) + print() + + +if __name__ == "__main__": + main() diff --git a/tests/test_rtx4080.py b/tests/test_rtx4080.py new file mode 100644 index 000000000..5b530ed97 --- /dev/null +++ b/tests/test_rtx4080.py @@ -0,0 +1,156 @@ +#!/usr/bin/env python3 +"""Test Mamba models on RTX 4080 (sm_89, 12GB VRAM). + +Usage: + uv run --no-project --python .venv/bin/python tests/test_rtx4080.py [--model MODEL] [--config CONFIG] + +Models: mamba1_130m, mamba1_370m, mamba1_1.4b, mamba1_2.8b, mamba3_test +""" + +import argparse +import json +import time +from pathlib import Path + +import torch + +CONFIGS_DIR = Path(__file__).parent.parent / "configs" +GB = 1024**3 + + +def load_gpu_config(config_path: str = None): + config_path = config_path or str(CONFIGS_DIR / "rtx4080.json") + with open(config_path) as f: + return json.load(f) + + +def print_gpu_info(): + print(f"GPU: {torch.cuda.get_device_name(0)}") + print(f"Compute capability: {torch.cuda.get_device_capability(0)}") + print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / GB:.1f} GB") + print(f"PyTorch: {torch.__version__}") + print(f"CUDA: {torch.version.cuda}") + print() + + +def test_mamba1_pretrained(model_name: str, config: dict): + """Test Mamba1/2 pretrained model via MambaLMHeadModel.""" + from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel + from transformers import AutoTokenizer + + model_cfg = config["models"][model_name] + pretrained = model_cfg["pretrained"] + dtype = getattr(torch, model_cfg.get("dtype", config["defaults"]["dtype"])) + + print(f"--- Loading {model_name} from {pretrained} ---") + t0 = time.time() + model = MambaLMHeadModel.from_pretrained(pretrained, device="cuda", dtype=dtype) + model.eval() + print(f"Model loaded in {time.time() - t0:.1f}s") + print(f"Parameters: {sum(p.numel() for p in model.parameters()) / 1e6:.1f}M") + print(f"VRAM used: {torch.cuda.memory_allocated() / GB:.2f} GB") + + tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b") + prompt = "The capital of France is" + input_ids = torch.tensor([tokenizer.encode(prompt)], device="cuda") + + print(f"\nPrompt: {prompt!r}") + with torch.no_grad(): + t0 = time.time() + out = model.generate( + input_ids=input_ids, + max_length=input_ids.shape[1] + 50, + temperature=0.7, + top_p=0.9, + top_k=50, + ) + elapsed = time.time() - t0 + + generated = tokenizer.decode(out[0].cpu().tolist()) + tokens_generated = out.shape[1] - input_ids.shape[1] + print(f"Generated ({tokens_generated} tokens in {elapsed:.2f}s, {tokens_generated/elapsed:.1f} tok/s):") + print(generated) + print(f"Peak VRAM: {torch.cuda.max_memory_allocated() / GB:.2f} GB") + print() + + +def test_mamba3(config: dict): + """Test Mamba3 module with random input (no pretrained model available yet).""" + from mamba_ssm.modules.mamba3 import Mamba3 + + model_cfg = config["models"]["mamba3_test"] + dtype = getattr(torch, model_cfg.get("dtype", config["defaults"]["dtype"])) + + is_mimo = model_cfg.get("is_mimo", False) + mode_str = "MIMO" if is_mimo else "SISO" + print(f"--- Testing Mamba3 module ({mode_str}, random weights) ---") + + mamba3_kwargs = dict( + d_model=model_cfg["d_model"], + d_state=model_cfg["d_state"], + headdim=model_cfg["headdim"], + is_mimo=is_mimo, + chunk_size=model_cfg["chunk_size"], + is_outproj_norm=False, + dtype=dtype, + ) + if is_mimo: + mamba3_kwargs["mimo_rank"] = model_cfg.get("mimo_rank", 4) + + model = Mamba3(**mamba3_kwargs).to("cuda") + + print(f"Parameters: {sum(p.numel() for p in model.parameters()) / 1e6:.1f}M") + + batch, length, dim = 2, 2048, model_cfg["d_model"] + x = torch.randn(batch, length, dim, dtype=dtype, device="cuda") + + print(f"Input shape: {x.shape}") + with torch.no_grad(): + t0 = time.time() + y = model(x) + torch.cuda.synchronize() + elapsed = time.time() - t0 + + print(f"Output shape: {y.shape}") + print(f"Forward pass: {elapsed*1000:.1f}ms") + print(f"Peak VRAM: {torch.cuda.max_memory_allocated() / GB:.2f} GB") + + # Test backward pass + torch.cuda.reset_peak_memory_stats() + x.requires_grad_(True) + y = model(x) + loss = y.sum() + t0 = time.time() + loss.backward() + torch.cuda.synchronize() + elapsed = time.time() - t0 + print(f"Backward pass: {elapsed*1000:.1f}ms") + print(f"Peak VRAM (fwd+bwd): {torch.cuda.max_memory_allocated() / GB:.2f} GB") + print() + + +def main(): + parser = argparse.ArgumentParser(description="Test Mamba models on RTX 4080") + parser.add_argument( + "--model", default="mamba1_130m", + choices=["mamba1_130m", "mamba1_370m", "mamba1_1.4b", "mamba1_2.8b", "mamba3_test", "all"], + help="Model to test (default: mamba1_130m)", + ) + parser.add_argument("--config", default=None, help="Path to GPU config JSON") + args = parser.parse_args() + + config = load_gpu_config(args.config) + print_gpu_info() + + models = list(config["models"]) if args.model == "all" else [args.model] + for name in models: + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + if name == "mamba3_test": + test_mamba3(config) + else: + test_mamba1_pretrained(name, config) + + +if __name__ == "__main__": + main() From c6e611acd852c47b3730d5e927d4624b6d319ecc Mon Sep 17 00:00:00 2001 From: Masahiro Ogawa Date: Mon, 23 Mar 2026 16:31:49 +0900 Subject: [PATCH 2/3] fix: use dynamic VRAM limit and remove uv-specific usage from docstring - benchmark_speed_mamba123.py: read VRAM from GPU instead of hardcoded 11.6GB - tests/test_rtx4080.py: use generic python invocation in docstring Co-Authored-By: Claude Opus 4.6 (1M context) --- benchmarks/benchmark_speed_mamba123.py | 3 ++- tests/test_rtx4080.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/benchmarks/benchmark_speed_mamba123.py b/benchmarks/benchmark_speed_mamba123.py index 7f8fb1431..4136f3226 100644 --- a/benchmarks/benchmark_speed_mamba123.py +++ b/benchmarks/benchmark_speed_mamba123.py @@ -148,7 +148,8 @@ def add_bar_labels(ax, bars, values, fmt="{:.1f}"): bwd_vram = [results[n]["bwd_vram_gb"] for n in names] ax.bar([p - w / 2 for p in x_pos], fwd_vram, w, label="Forward", color="#4C78A8") ax.bar([p + w / 2 for p in x_pos], bwd_vram, w, label="Fwd+Backward", color="#E45756") - ax.axhline(y=11.6, color="red", linestyle="--", alpha=0.4, label="VRAM limit (11.6 GB)") + vram_total = torch.cuda.get_device_properties(0).total_memory / GB + ax.axhline(y=vram_total, color="red", linestyle="--", alpha=0.4, label=f"VRAM limit ({vram_total:.1f} GB)") ax.set_ylabel("GB") ax.set_title("Peak VRAM Usage") ax.set_xticks(list(x_pos)) diff --git a/tests/test_rtx4080.py b/tests/test_rtx4080.py index 5b530ed97..b8e744ebe 100644 --- a/tests/test_rtx4080.py +++ b/tests/test_rtx4080.py @@ -2,7 +2,7 @@ """Test Mamba models on RTX 4080 (sm_89, 12GB VRAM). Usage: - uv run --no-project --python .venv/bin/python tests/test_rtx4080.py [--model MODEL] [--config CONFIG] + python tests/test_rtx4080.py [--model MODEL] [--config CONFIG] Models: mamba1_130m, mamba1_370m, mamba1_1.4b, mamba1_2.8b, mamba3_test """ From d7a8313f2f21b0e3cd04ea0ceae8b0e4423ffb17 Mon Sep 17 00:00:00 2001 From: Masahiro Ogawa Date: Mon, 23 Mar 2026 16:55:05 +0900 Subject: [PATCH 3/3] feat: add visual chart for text generation latency comparison - Add benchmark_text_generation_latency_visual.py: compares generation latency across multiple Mamba model sizes with chart output - Uses dynamic VRAM limit detection - Preserves original benchmark_generation_mamba_simple.py untouched - Update benchmark_README.md to document all three benchmarks Co-Authored-By: Claude Opus 4.6 (1M context) --- benchmarks/benchmark_README.md | 28 ++- ...enchmark_text_generation_latency_visual.py | 197 ++++++++++++++++++ 2 files changed, 218 insertions(+), 7 deletions(-) create mode 100644 benchmarks/benchmark_text_generation_latency_visual.py diff --git a/benchmarks/benchmark_README.md b/benchmarks/benchmark_README.md index 804e8e6eb..dab3e930a 100644 --- a/benchmarks/benchmark_README.md +++ b/benchmarks/benchmark_README.md @@ -4,25 +4,39 @@ | Script | Purpose | |---|---| -| `benchmark_text_generation_latency.py` | How fast can the model write text? | -| `benchmark_speed_mamba123.py` | Which Mamba generation computes fastest? | +| `benchmark_generation_mamba_simple.py` | How fast can one model generate text? (single model, full CLI) | +| `benchmark_text_generation_latency_visual.py` | Compare generation latency across model sizes (visual chart) | +| `benchmark_speed_mamba123.py` | Which Mamba generation computes fastest? (visual chart) | -## benchmark_text_generation_latency.py +## benchmark_generation_mamba_simple.py -Loads a pretrained model (Mamba or Transformer), feeds a prompt, generates tokens, and measures the total time. +Loads a single pretrained model (Mamba or Transformer), feeds a prompt, generates tokens, and measures the total time. Full CLI with sampling options. - **Measures:** End-to-end latency — prompt processing + token-by-token decoding (ms) -- **Models:** One pretrained model at a time (e.g. `state-spaces/mamba-2.8b`) +- **Models:** One pretrained model at a time (Mamba or HuggingFace Transformer) - **Includes backward?** No (inference only) -- **Use case:** Evaluating deployment/serving performance +- **Use case:** Evaluating deployment/serving performance for a specific model ```sh -python benchmarks/benchmark_text_generation_latency.py \ +python benchmarks/benchmark_generation_mamba_simple.py \ --model-name "state-spaces/mamba-2.8b" \ --prompt "My cat wrote all this CUDA code for a new language model and" \ --topp 0.9 --temperature 0.7 ``` +## benchmark_text_generation_latency_visual.py + +Loads multiple pretrained Mamba models, generates tokens, and produces a visual comparison chart of throughput, latency, and VRAM usage. + +- **Measures:** Throughput (tok/s), latency (ms), VRAM (GB) across model sizes +- **Models:** Multiple Mamba models side by side (mamba-130m, mamba-370m) +- **Includes backward?** No (inference only) +- **Use case:** Comparing generation performance across model sizes + +```sh +python benchmarks/benchmark_text_generation_latency_visual.py +``` + ## benchmark_speed_mamba123.py Creates Mamba1, Mamba2, and Mamba3 modules with identical settings (same d_model, sequence length, batch size, dtype), runs the same tensor through each, and times forward + backward passes. diff --git a/benchmarks/benchmark_text_generation_latency_visual.py b/benchmarks/benchmark_text_generation_latency_visual.py new file mode 100644 index 000000000..7ae014065 --- /dev/null +++ b/benchmarks/benchmark_text_generation_latency_visual.py @@ -0,0 +1,197 @@ +#!/usr/bin/env python3 +"""Visual comparison of text generation latency across Mamba model sizes. + +Loads multiple pretrained Mamba models, generates tokens from a prompt, +and produces a comparison chart of throughput, latency, and VRAM usage. + +For single-model benchmarking with full CLI options, use +benchmark_generation_mamba_simple.py instead. + +Usage: + python benchmarks/benchmark_text_generation_latency_visual.py [--genlen N] [--dtype DTYPE] +""" + +import argparse +import gc +import time +from pathlib import Path + +import matplotlib.pyplot as plt +import matplotlib +import torch + +from transformers import AutoTokenizer + +from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel + +matplotlib.rcParams.update({"font.size": 11, "figure.dpi": 150}) + +GB = 1024**3 +OUTPUT_PATH = Path(__file__).parent / "text_generation_latency.png" + +MODELS = [ + ("mamba-130m", "state-spaces/mamba-130m"), + ("mamba-370m", "state-spaces/mamba-370m"), +] + +REPEATS = 3 +PROMPT = "The capital of France is" + + +def bench_model(name, pretrained, tokenizer, dtype, genlen): + """Benchmark a single pretrained model, returning latency and throughput.""" + print(f" Loading {name}...") + gc.collect() + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + + t0 = time.time() + model = MambaLMHeadModel.from_pretrained(pretrained, device="cuda", dtype=dtype) + model.eval() + load_time = time.time() - t0 + + params = sum(p.numel() for p in model.parameters()) / 1e6 + vram_model = torch.cuda.memory_allocated() / GB + + input_ids = torch.tensor([tokenizer.encode(PROMPT)], device="cuda") + prompt_len = input_ids.shape[1] + max_length = prompt_len + genlen + + # Warmup + with torch.no_grad(): + model.generate(input_ids=input_ids, max_length=prompt_len + 10, temperature=0.7) + + # Timed runs + torch.cuda.synchronize() + latencies = [] + for _ in range(REPEATS): + torch.cuda.synchronize() + t0 = time.time() + out = model.generate( + input_ids=input_ids, + max_length=max_length, + temperature=0.7, + top_k=50, + top_p=0.9, + ) + torch.cuda.synchronize() + latencies.append((time.time() - t0) * 1000) + + avg_latency_ms = sum(latencies) / len(latencies) + tokens_gen = out.shape[1] - prompt_len + tok_per_sec = tokens_gen / (avg_latency_ms / 1000) + peak_vram = torch.cuda.max_memory_allocated() / GB + + del model, out, input_ids + gc.collect() + torch.cuda.empty_cache() + + return { + "params_m": params, + "load_time_s": load_time, + "vram_model_gb": vram_model, + "peak_vram_gb": peak_vram, + "latency_ms": avg_latency_ms, + "tok_per_sec": tok_per_sec, + "tokens_gen": tokens_gen, + } + + +def collect_results(dtype, genlen): + tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b") + results = {} + for name, pretrained in MODELS: + results[name] = bench_model(name, pretrained, tokenizer, dtype, genlen) + return results + + +def plot_results(results, genlen): + names = list(results.keys()) + colors = ["#4C78A8", "#F58518", "#E45756", "#72B7B2"] + + fig, axes = plt.subplots(2, 2, figsize=(12, 9)) + fig.suptitle( + f"Text Generation Latency — {torch.cuda.get_device_name(0)}\n" + f"prompt={PROMPT!r}, genlen={genlen}, {REPEATS} runs | " + f"PyTorch {torch.__version__} | CUDA {torch.version.cuda}", + fontsize=12, fontweight="bold", + ) + + def add_bar_labels(ax, bars, values, fmt="{:.1f}"): + for bar, v in zip(bars, values): + ax.text(bar.get_x() + bar.get_width() / 2, + bar.get_height() + max(values) * 0.03, + fmt.format(v), ha="center", va="bottom", fontsize=10) + + # 1) Generation throughput (tok/s) + ax = axes[0, 0] + vals = [results[n]["tok_per_sec"] for n in names] + bars = ax.bar(names, vals, color=colors[:len(names)]) + ax.set_ylabel("Tokens / sec") + ax.set_title("Generation Throughput (higher is better)") + add_bar_labels(ax, bars, vals, "{:.1f}") + + # 2) End-to-end latency (ms) + ax = axes[0, 1] + vals = [results[n]["latency_ms"] for n in names] + bars = ax.bar(names, vals, color=colors[:len(names)]) + ax.set_ylabel("ms") + ax.set_title(f"Total Latency for {genlen} tokens (lower is better)") + add_bar_labels(ax, bars, vals, "{:.0f}ms") + + # 3) VRAM usage + ax = axes[1, 0] + x_pos = range(len(names)) + w = 0.35 + vram_model = [results[n]["vram_model_gb"] for n in names] + vram_peak = [results[n]["peak_vram_gb"] for n in names] + ax.bar([p - w / 2 for p in x_pos], vram_model, w, label="Model weights", color="#4C78A8") + ax.bar([p + w / 2 for p in x_pos], vram_peak, w, label="Peak (generation)", color="#E45756") + vram_total = torch.cuda.get_device_properties(0).total_memory / GB + ax.axhline(y=vram_total, color="red", linestyle="--", alpha=0.4, label=f"VRAM limit ({vram_total:.1f} GB)") + ax.set_ylabel("GB") + ax.set_title("VRAM Usage") + ax.set_xticks(list(x_pos)) + ax.set_xticklabels(names) + ax.legend(fontsize=9) + + # 4) Model size + ax = axes[1, 1] + vals = [results[n]["params_m"] for n in names] + bars = ax.bar(names, vals, color=colors[:len(names)]) + ax.set_ylabel("Parameters (M)") + ax.set_title("Model Size") + add_bar_labels(ax, bars, vals, "{:.0f}M") + + plt.tight_layout() + plt.savefig(OUTPUT_PATH, bbox_inches="tight") + print(f"\nSaved: {OUTPUT_PATH}") + + +def main(): + parser = argparse.ArgumentParser(description="Visual benchmark of text generation latency") + parser.add_argument("--genlen", type=int, default=100, help="Number of tokens to generate") + parser.add_argument("--dtype", type=str, default="bfloat16", choices=["float16", "bfloat16"]) + args = parser.parse_args() + + dtype = getattr(torch, args.dtype) + + print(f"GPU: {torch.cuda.get_device_name(0)}") + print(f"Prompt: {PROMPT!r}") + print(f"Generation length: {args.genlen} tokens, dtype: {args.dtype}") + print(f"Averaging over {REPEATS} runs\n") + + results = collect_results(dtype, args.genlen) + + print("\n--- Results ---") + for name, r in results.items(): + print(f" {name:12s} params={r['params_m']:.0f}M " + f"latency={r['latency_ms']:.0f}ms " + f"throughput={r['tok_per_sec']:.1f} tok/s " + f"vram_peak={r['peak_vram_gb']:.2f}GB") + + plot_results(results, args.genlen) + + +if __name__ == "__main__": + main()