From e1e20c3ed5eeedacda49fbad32c9fc37a7767057 Mon Sep 17 00:00:00 2001 From: Polina Binder Date: Tue, 23 Jun 2026 05:23:03 +0000 Subject: [PATCH 1/4] Re-land #1637 (FastAPI server + CLI) onto migrated #1622 (new layout) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit #1622 was migrated to the new top-level layout (interpretability/sparse_autoencoders/…, no bionemo-recipes/ prefix) and squash-replayed, so #1637 is layered on by hand rather than rebased. - Clean adds: src/evo2_sae/{server,cli}.py, scripts/launch_inference.sh, tests/test_{cli,server}.py. - pyproject: add pandas/fastapi/uvicorn/anyio. - tests/conftest.py: keep #1622's 1B GPU fixtures + bionemo.common loader; append the serve-layer FakeEngine + fake_engine fixture. - core.py (semantic merge): keep #1622's _sanitize_steering (all CPU sanitize tests) and fold in the explicit non-finite-strength guard (no min/max arg-order reliance); add the shared annotate() + parse_clamp_spec() (CLI strings ⇄ API dicts) and feed parse_clamp_spec in front of _sanitize_steering; add _is_unrecoverable_cuda + flip the engine not-ready on an unrecoverable CUDA fault in generate(). (Kept _sanitize_steering rather than swapping in _normalize_clamps — non-redundant and preserves #1622's sampler hardening + tests.) - test_steering.py: keep #1622's sanitize + GPU tests; add the _is_unrecoverable_cuda test. Preserved from #1622: clamp_hook canonical encode/decode, TopKSAE-only _load_sae (no ReLU), bionemo.common(/core fallback) loader. Validated in the evo2_megatron venv: CPU 38 passed, GPU test_steering 13 passed on the 1B (ran, not skipped). 
Co-Authored-By: Claude Opus 4.8 (1M context) Signed-off-by: Polina Binder --- .../recipes/evo2/pyproject.toml | 9 +- .../recipes/evo2/scripts/launch_inference.sh | 37 ++++ .../recipes/evo2/src/evo2_sae/cli.py | 188 +++++++++++++++++ .../recipes/evo2/src/evo2_sae/core.py | 128 +++++++++--- .../recipes/evo2/src/evo2_sae/server.py | 195 ++++++++++++++++++ .../recipes/evo2/tests/conftest.py | 75 +++++++ .../recipes/evo2/tests/test_cli.py | 117 +++++++++++ .../recipes/evo2/tests/test_server.py | 117 +++++++++++ .../recipes/evo2/tests/test_steering.py | 9 +- 9 files changed, 843 insertions(+), 32 deletions(-) create mode 100755 interpretability/sparse_autoencoders/recipes/evo2/scripts/launch_inference.sh create mode 100644 interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/cli.py create mode 100644 interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/server.py create mode 100644 interpretability/sparse_autoencoders/recipes/evo2/tests/test_cli.py create mode 100644 interpretability/sparse_autoencoders/recipes/evo2/tests/test_server.py diff --git a/interpretability/sparse_autoencoders/recipes/evo2/pyproject.toml b/interpretability/sparse_autoencoders/recipes/evo2/pyproject.toml index 5132547927..1945e3455b 100644 --- a/interpretability/sparse_autoencoders/recipes/evo2/pyproject.toml +++ b/interpretability/sparse_autoencoders/recipes/evo2/pyproject.toml @@ -12,12 +12,15 @@ dependencies = [ "sae", "torch>=2.0", "numpy>=1.20", + "pandas>=2.0", # cli.py `batch` -> parquet "pyarrow>=23.0.0", + "fastapi>=0.110", # server.py + "uvicorn>=0.29", # serve CLI + "anyio>=4.0", # server.py: threadpool concurrency cap (ships with fastapi, declared explicitly) ] -# The `evo2_sae` package (src/) holds the live inference engine + steering hook; -# scripts/ (extract, train) are standalone entry points alongside it. The FastAPI -# server + CLI (and their fastapi/uvicorn deps) are added by the serve PR (#1637). 
+# The `evo2_sae` package (src/) holds the live inference engine + steering hook + the FastAPI +# server and CLI; scripts/ (extract, train) are standalone entry points alongside it. [tool.setuptools.packages.find] where = ["src"] diff --git a/interpretability/sparse_autoencoders/recipes/evo2/scripts/launch_inference.sh b/interpretability/sparse_autoencoders/recipes/evo2/scripts/launch_inference.sh new file mode 100755 index 0000000000..26768a4c46 --- /dev/null +++ b/interpretability/sparse_autoencoders/recipes/evo2/scripts/launch_inference.sh @@ -0,0 +1,37 @@ +#!/bin/bash +# Launch the Evo2 SAE inference engine. One engine, four modes: +# +# ./launch_inference.sh serve # live HTTP server on :8001 (viz backend) +# ./launch_inference.sh encode --sequence ATGC... # annotate ONE sequence -> top features +# ./launch_inference.sh batch --fasta in.fa --out out.parquet # MANY sequences -> parquet +# ./launch_inference.sh generate --prompt ATGC... --clamp 29244:300 # steer + generate DNA +# +# Steering loop: `encode` a sequence to find an active feature id, then +# `generate --clamp ID:STRENGTH` (strength ~2-3x the feature's max_activation; repeat --clamp). +# +# Config via env. Required: VENV, EVO2_CKPT_DIR, SAE_CKPT_PATH. Optional (have defaults): +# FEATURE_ANNOTATIONS, EMBEDDING_LAYER (26), MAX_SEQ_LEN (8192), DEVICE, PORT, CUDA_VISIBLE_DEVICES. +# +# Requires the evo2_megatron recipe venv (provides bionemo.evo2 + megatron). +set -euo pipefail + +HERE="$(cd "$(dirname "$0")" && pwd)" +RECIPE_DIR="$(cd "$HERE/.." && pwd)" # recipes/evo2 — so the evo2_sae package imports + +# Required (no hardcoded defaults — supply your own paths via env): +VENV="${VENV:?Set VENV to the evo2_megatron recipe .venv (provides bionemo.evo2 + megatron)}" +export EVO2_CKPT_DIR="${EVO2_CKPT_DIR:?Set EVO2_CKPT_DIR to an Evo2 MBridge checkpoint directory}" +export SAE_CKPT_PATH="${SAE_CKPT_PATH:?Set SAE_CKPT_PATH to a trained SAE checkpoint (.pt)}" +# Optional: feature-label parquet (empty = features are unlabeled). 
Layer defaults to 26. +export FEATURE_ANNOTATIONS="${FEATURE_ANNOTATIONS:-}" +export EMBEDDING_LAYER="${EMBEDDING_LAYER:-26}" + +if [[ ! -x "$VENV/bin/python" ]]; then + echo "ERROR: evo2_megatron venv not found at $VENV (build it with the recipe's .ci_build.sh)" >&2 + exit 1 +fi + +source "$VENV/bin/activate" +cd "$RECIPE_DIR" +export PYTHONPATH="$RECIPE_DIR/src${PYTHONPATH:+:$PYTHONPATH}" +exec python -m evo2_sae.cli "$@" diff --git a/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/cli.py b/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/cli.py new file mode 100644 index 0000000000..1a910d6cde --- /dev/null +++ b/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/cli.py @@ -0,0 +1,188 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Evo2 SAE inference CLI — one engine, four modes. + + serve : start the FastAPI server (one sequence at a time, interactive) + encode : annotate ONE sequence -> top features (stdout JSON) + batch : run a FASTA of MANY sequences -> parquet of per-sequence top features + generate: generate DNA, optionally steering SAE features (stdout JSON) + +They all build the same `Evo2SAE` engine; config comes from flags or env +(EVO2_CKPT_DIR / SAE_CKPT_PATH / FEATURE_ANNOTATIONS / EMBEDDING_LAYER). 
+""" + +from __future__ import annotations + +import argparse +import json +import os + + +def _add_common(p: argparse.ArgumentParser) -> None: + """Register the shared inference arguments (checkpoints, layer, device) on a parser. + + Defaults come from env vars (``EVO2_CKPT_DIR``, ``SAE_CKPT_PATH``, ``FEATURE_ANNOTATIONS``, + ``EMBEDDING_LAYER``, ``DEVICE``, ``MAX_SEQ_LEN``); pass the flags to override. No hardcoded + paths — the checkpoints must be supplied via flag or env. + + Args: + p: The argparse parser (or subparser) to add the shared arguments to. + + Returns: + None. Mutates ``p`` in place. + """ + p.add_argument("--evo2-ckpt-dir", default=os.environ.get("EVO2_CKPT_DIR")) + p.add_argument("--sae-ckpt-path", default=os.environ.get("SAE_CKPT_PATH")) + p.add_argument("--feature-annotations", default=os.environ.get("FEATURE_ANNOTATIONS")) + p.add_argument("--layer", type=int, default=os.environ.get("EMBEDDING_LAYER", "26")) + p.add_argument("--device", default=os.environ.get("DEVICE", "cuda")) + p.add_argument("--max-seq-len", type=int, default=os.environ.get("MAX_SEQ_LEN", "8192")) + + +def _engine(args): + """Construct an Evo2SAE engine from parsed CLI args. + + Args: + args: Parsed argparse namespace with ``evo2_ckpt_dir``, ``sae_ckpt_path``, ``layer``, + ``device``, ``max_seq_len``, ``feature_annotations``. + + Returns: + An (unloaded) ``Evo2SAE`` instance — call ``.load()`` before use. 
+ """ + from .core import Evo2SAE + + return Evo2SAE( + evo2_ckpt_dir=args.evo2_ckpt_dir, + sae_ckpt_path=args.sae_ckpt_path, + layer=args.layer, + device=args.device, + max_seq_len=args.max_seq_len, + feature_annotations=args.feature_annotations, + ) + + +def main(): + """Parse args and dispatch to the serve / encode / batch subcommand.""" + ap = argparse.ArgumentParser(description="Evo2 SAE inference (serve | encode | batch | generate)") + sub = ap.add_subparsers(dest="cmd", required=True) + + ps = sub.add_parser("serve", help="start the FastAPI inference server") + _add_common(ps) + ps.add_argument("--host", default="0.0.0.0") + ps.add_argument("--port", type=int, default=os.environ.get("PORT", "8001")) + + pe = sub.add_parser("encode", help="annotate ONE sequence -> top features (JSON)") + _add_common(pe) + pe.add_argument("--sequence", required=True) + pe.add_argument("--organism", default="None (raw DNA)") + pe.add_argument("--top-k", type=int, default=8) + + pb = sub.add_parser("batch", help="MANY sequences (FASTA) -> parquet of per-sequence top features") + _add_common(pb) + pb.add_argument("--fasta", required=True) + pb.add_argument("--out", required=True) + pb.add_argument("--top-k", type=int, default=16) + pb.add_argument("--batch-size", type=int, default=8) + + pg = sub.add_parser("generate", help="generate DNA, optionally steering SAE features") + _add_common(pg) + pg.add_argument("--prompt", default="", help="DNA to seed; steering applies to the continuation") + pg.add_argument("--organism", default="None (raw DNA)") + pg.add_argument( + "--clamp", + action="append", + default=[], + metavar="FEATURE_ID[:STRENGTH]", + help="clamp a feature on the continuation; repeatable (e.g. --clamp 29244:300). 
" + "Find feature ids with `encode`.", + ) + pg.add_argument("--n-tokens", type=int, default=120) + pg.add_argument("--temperature", type=float, default=1.0) + pg.add_argument("--top-k", type=int, default=0) + pg.add_argument("--compare-baseline", action="store_true", help="also generate unsteered, for comparison") + + args = ap.parse_args() + + if args.cmd == "serve": + import uvicorn + + from .server import build_app + + uvicorn.run(build_app(_engine(args)), host=args.host, port=args.port, log_level="info") + return + + from . import core + + eng = _engine(args).load() + + if args.cmd == "encode": + try: + dna, _tag, codes, tag_len = core.annotate(eng, args.sequence, args.organism) + except ValueError as e: + raise SystemExit(str(e)) + feats = eng.top_features(codes, tag_len=tag_len, k=args.top_k) + print( + json.dumps( + {"sequence": dna, "organism": args.organism, "bases": len(dna), "top_features": feats}, indent=2 + ) + ) + + elif args.cmd == "batch": + import pandas as pd + + from .fasta import read_fasta + + ids, seqs = [], [] + for sid, seq in read_fasta(args.fasta): + ids.append(sid) + seqs.append(seq) + print(f"[batch] {len(seqs)} sequences from {args.fasta}; encoding (batch_size={args.batch_size})…") + codes_list = eng.encode_batch(seqs, batch_size=args.batch_size) + rows = [] + for sid, codes in zip(ids, codes_list): + for rank, ft in enumerate(eng.top_features(codes, k=args.top_k)): + rows.append({"sequence_id": sid, "bp": int(codes.shape[0]), "rank": rank, **ft}) + df = pd.DataFrame(rows) + df.to_parquet(args.out, index=False) + print(f"[batch] wrote {len(df)} rows for {len(seqs)} sequences -> {args.out}") + + elif args.cmd == "generate": + try: + out = eng.generate( + prompt=args.prompt, + organism=args.organism, + features=args.clamp, # raw "ID[:STRENGTH]" strings; core.parse_clamp_spec normalizes + n_tokens=args.n_tokens, + temperature=args.temperature, + top_k=args.top_k, + compare_baseline=args.compare_baseline, + ) + except ValueError as e: + 
raise SystemExit(str(e)) + result = { + "prompt": out["prompt"], + "organism": out["organism"], + "steered": out["steered"], + "features": out["features"], + "sequence": out["generation"]["sequence"], + } + if out.get("baseline"): + result["baseline_sequence"] = out["baseline"]["sequence"] + print(json.dumps(result, indent=2)) + + +if __name__ == "__main__": + main() diff --git a/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/core.py b/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/core.py index 19b88048a1..0c0e4d571a 100644 --- a/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/core.py +++ b/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/core.py @@ -36,6 +36,7 @@ from __future__ import annotations import logging +import math import os import re import sys @@ -79,6 +80,64 @@ def clean_dna(seq: str) -> str: return _VALID_BASES.sub("", (seq or "").upper()) +def annotate(engine, sequence: str, organism: str = "None (raw DNA)", tag: Optional[str] = None): + """Shared encode path for the CLI ``encode`` and the server ``/annotate`` (topk). + + Cleans the sequence, resolves the phylo tag, encodes once, and computes the tag length + (the leading tag tokens to skip, ignored if it would drop the whole sequence). Returns + ``(dna, resolved_tag, codes, tag_len)``; callers add their own top-k / per-base presentation. + + Raises ``ValueError`` on an empty/non-DNA sequence or an unknown organism — the server maps + these to HTTP 400, the CLI to a clean exit. Takes the engine as an argument (rather than + living on ``Evo2SAE``) so it composes with any object exposing ``resolve_tag``/``encode``. 
+ """ + dna = clean_dna(sequence) + if not dna: + raise ValueError("No valid nucleotides in sequence") + resolved_tag = engine.resolve_tag(organism, tag) + if resolved_tag is None: + raise ValueError(f"Unknown organism '{organism}' and no custom tag") + full = resolved_tag + dna + # Reject over-length input rather than letting encode() silently truncate to max_seq_len — + # the per-base `activations` would then be shorter than `bases` and the viz would misalign. + if len(full) > engine.max_seq_len: + raise ValueError(f"sequence too long: {len(full)} > max_seq_len ({engine.max_seq_len})") + codes = engine.encode(full) + if codes.shape[0] != len(full): # belt-and-suspenders: encode must be 1:1 with the input here + raise ValueError("encoded length != input length (tokenizer truncated)") + tag_len = len(resolved_tag) if codes.shape[0] >= len(resolved_tag) else 0 + return dna, resolved_tag, codes, tag_len + + +def parse_clamp_spec(spec) -> dict: + """Normalize one feature-clamp spec into ``{"feature_id": int, "strength": float}``. + + Accepts a ``"FEATURE_ID[:STRENGTH]"`` string (CLI ``--clamp``; strength defaults to 1.0, + e.g. ``29244:300`` or ``29244``), or a mapping with ``feature_id``/``strength`` (server + ``FeatureClamp``). Single source of truth so the CLI and the API parse clamps identically. + Magnitude/finiteness/range are enforced later by ``_sanitize_steering``. + + Raises ``ValueError`` on a malformed string (caller decides whether that's a 400 or an exit). 
+ """ + if isinstance(spec, str): + fid, sep, strength = spec.partition(":") + try: + return {"feature_id": int(fid), "strength": float(strength) if (sep and strength) else 1.0} + except ValueError: + raise ValueError(f"invalid clamp {spec!r}: expected FEATURE_ID[:STRENGTH] with numeric values") + return {"feature_id": int(spec["feature_id"]), "strength": float(spec.get("strength", 1.0))} + + +def _is_unrecoverable_cuda(e: Exception) -> bool: + """True for CUDA errors that poison the process context (device-side assert / sticky CUDA error). + + These don't clear on the next request — the worker must be recycled — so the caller marks the + engine not-ready rather than retrying into a stream of 500s. + """ + s = str(e).lower() + return isinstance(e, RuntimeError) and ("device-side assert" in s or "cuda error" in s) + + def _sanitize_steering(features, n_features, temperature, top_k): """Validate/normalize steering inputs (pure, no GPU); raise on bad input. @@ -96,10 +155,12 @@ def _sanitize_steering(features, n_features, temperature, top_k): bad = sorted({int(f["feature_id"]) for f in features if not (0 <= int(f["feature_id"]) < n_features)}) if bad: raise ValueError(f"feature_id(s) {bad} out of range [0, {n_features})") - clamps = { - int(f["feature_id"]): max(-MAX_CLAMP_STRENGTH, min(MAX_CLAMP_STRENGTH, float(f.get("strength", 1.0)))) - for f in features - } + clamps = {} + for f in features: + s = float(f.get("strength", 1.0)) + if not math.isfinite(s): # NaN/±inf would blow the logits to NaN -> neutralize explicitly + s = 0.0 + clamps[int(f["feature_id"])] = max(-MAX_CLAMP_STRENGTH, min(MAX_CLAMP_STRENGTH, s)) temperature = float(temperature) top_k = max(0, int(top_k)) # negative top_k is an invalid sampler arg -> 0 (no filtering) if temperature <= 0: # greedy — avoid the sampler's logits/temperature division (NaN) @@ -384,7 +445,9 @@ def generate( from bionemo.evo2.run import infer as INF - features = features or [] + # Accept CLI "ID[:STRENGTH]" strings and 
server FeatureClamp dicts identically, then + # validate/cap below via _sanitize_steering. + features = [parse_clamp_spec(f) for f in (features or [])] resolved_tag = self.resolve_tag(organism, tag) if resolved_tag is None: raise ValueError(f"Unknown organism '{organism}' and no custom tag") @@ -401,30 +464,39 @@ def generate( # 0 each trigger CUDA device-side asserts that wedge the server (see _sanitize_steering). clamps, fids, temperature, top_k = _sanitize_steering(features, self.n_features, temperature, top_k) - with self._lock: - comp = self._ensure_engine() - hook_layer = unwrap_model(comp.model).decoder.layers[self.layer] - from sae.steering import clamp_hook - - feat_meta = [{"id": fid, "label": self.labels.get(fid), "strength": s} for fid, s in clamps.items()] - - def _run(steer: bool) -> str: - handle = ( - hook_layer.register_forward_hook(clamp_hook(self.sae, clamps, decode_only=True)) - if (steer and clamps) - else None - ) - try: - out = INF.generate( - comp, [full_prompt], max_new_tokens=n_tokens, temperature=temperature, top_k=top_k + try: + with self._lock: + comp = self._ensure_engine() + hook_layer = unwrap_model(comp.model).decoder.layers[self.layer] + from sae.steering import clamp_hook + + feat_meta = [{"id": fid, "label": self.labels.get(fid), "strength": s} for fid, s in clamps.items()] + + def _run(steer: bool) -> str: + handle = ( + hook_layer.register_forward_hook(clamp_hook(self.sae, clamps, decode_only=True)) + if (steer and clamps) + else None ) - return clean_dna(INF._unwrap_result(out[0]).generated_text) - finally: - if handle is not None: - handle.remove() - - main_dna = _run(steer=True) - base_dna = _run(steer=False) if (compare_baseline and clamps) else None + try: + out = INF.generate( + comp, [full_prompt], max_new_tokens=n_tokens, temperature=temperature, top_k=top_k + ) + return clean_dna(INF._unwrap_result(out[0]).generated_text) + finally: + if handle is not None: + handle.remove() + + main_dna = _run(steer=True) + base_dna = 
_run(steer=False) if (compare_baseline and clamps) else None + except Exception as e: + # A CUDA device-side assert poisons the context: every later request would 500 until + # restart. Mark the engine not-ready so /health flips to 503 and an orchestrator + # recycles the pod, instead of serving a permanently-wedged worker. + if _is_unrecoverable_cuda(e): + self.ready = False + logger.exception("unrecoverable CUDA error in generate() — marking engine not-ready") + raise resp = { "prompt": dna, diff --git a/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/server.py b/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/server.py new file mode 100644 index 0000000000..b1660494cc --- /dev/null +++ b/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/server.py @@ -0,0 +1,195 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""FastAPI server over the Evo2SAE engine — the live backend the viz talks to. + +Endpoints: /health, /features, /annotate (per-base activations for a pasted +sequence), /generate (autoregressive generation + optional SAE-feature clamp). +This is a thin layer; all model work lives in `core.Evo2SAE`. 
+""" + +from __future__ import annotations + +import logging +import os +from contextlib import asynccontextmanager +from typing import Optional + +import anyio +from fastapi import FastAPI, HTTPException, Response +from fastapi.responses import JSONResponse +from pydantic import BaseModel + +from . import core +from .core import Evo2SAE + + +logger = logging.getLogger("evo2_sae_infer.server") + + +class AnnotateRequest(BaseModel): + """Request body for /annotate (top-k feature scan or an explicit feature pick).""" + + sequence: str + organism: str = "None (raw DNA)" + tag: Optional[str] = None + mode: str = "topk" # "topk" | "pick" + k: int = 8 + feature_ids: Optional[list[int]] = None + + +class FeatureClamp(BaseModel): + """A single SAE-feature steering clamp (feature id + target strength).""" + + feature_id: int + strength: float = 1.0 + + +class GenerateRequest(BaseModel): + """Request body for /generate (autoregressive generation + optional SAE-feature clamps).""" + + prompt: str = "" + organism: str = "None (raw DNA)" + tag: Optional[str] = None + features: list[FeatureClamp] = [] + n_tokens: int = 120 + temperature: float = 1.0 + top_k: int = 0 + compare_baseline: bool = False + + +def build_app(engine: Evo2SAE) -> FastAPI: + """Build the FastAPI app; the engine is loaded once in the lifespan handler.""" + # One GPU (the engine serializes model calls with a lock), so cap how many sync requests run + # at once: excess requests wait for a worker instead of piling up dozens of parked threads. + # NOTE: generation is bounded only by the context window now, so a single /generate can run + # long — under concurrent load requests queue behind it. Sync endpoints run in Starlette's + # AnyIO threadpool (default 40); shrink it. Tune MAX_CONCURRENCY. 
+ max_concurrency = int(os.getenv("MAX_CONCURRENCY", "8")) + + @asynccontextmanager + async def lifespan(app: FastAPI): + anyio.to_thread.current_default_thread_limiter().total_tokens = max_concurrency + try: + engine.load() + logger.info("engine ready") + except Exception: + logger.exception("engine startup failed — /health stays not-ready") + yield + + app = FastAPI(title="Evo2 SAE inference", lifespan=lifespan) + + # No CORS middleware: the dashboard always reaches the backend same-origin (Vite proxies + # /api -> :8001), so cross-origin is never used. CORS is browser-only and not an access + # control anyway — scripts ignore it; SSO + the limits below are what gate this endpoint. + + # Reject oversized request bodies up front (a multi-MB sequence would be read into memory + # before per-field validation could reject it). Default 16 MiB; override with MAX_BODY_BYTES. + max_body = int(os.getenv("MAX_BODY_BYTES", str(16 * 1024 * 1024))) + + @app.middleware("http") + async def _limit_body(request, call_next): + cl = request.headers.get("content-length") + if cl is not None and cl.isdigit() and int(cl) > max_body: + return JSONResponse({"detail": f"request body too large (> {max_body} bytes)"}, status_code=413) + return await call_next(request) + + def _require_ready(): + if not engine.ready: + raise HTTPException(503, "Backend not ready") + + @app.get("/health") + def health(response: Response): + if not engine.ready: + response.status_code = 503 # readiness probes shed this pod until load finishes (body still informative) + return { + "ready": bool(engine.ready), + "layer": engine.layer, + "n_features": engine.n_features, + "n_labels": len(engine.labels), + "organisms": list(engine.organism_tags.keys()), + "organism_tags": engine.organism_tags, + "device": engine.device, + "max_seq_len": engine.max_seq_len, # context budget — UI caps generation length to this + } + + @app.get("/features") + def features(): + _require_ready() + rows = [ + {"id": int(f), "label": 
lab, "natural_peak": engine.peaks.get(int(f))} for f, lab in engine.labels.items() + ] + rows.sort(key=lambda r: r["id"]) + return rows + + @app.post("/annotate") + def annotate(req: AnnotateRequest): + _require_ready() + try: + dna, tag, codes, tag_len = core.annotate(engine, req.sequence, req.organism, req.tag) + except ValueError as e: + raise HTTPException(413 if "too long" in str(e) else 400, str(e)) + full = tag + dna + if req.mode not in ("pick", "topk"): + raise HTTPException(400, f"Invalid mode {req.mode!r}: must be 'pick' or 'topk'") + if req.mode == "pick": + if not req.feature_ids: + raise HTTPException(400, "mode='pick' requires feature_ids") + chosen = [int(i) for i in req.feature_ids] + else: + k = max(1, min(int(req.k), 64)) + chosen = [ft["feature_id"] for ft in engine.top_features(codes, tag_len=tag_len, k=k)] + feats = [] + for fid in chosen: + col = codes[:, fid] + feats.append( + { + "feature_id": fid, + "label": engine.labels.get(fid), + "max_activation": float(col[tag_len:].max().item()) + if codes.shape[0] > tag_len + else float(col.max().item()), + "activations": [round(float(v), 4) for v in col.tolist()], + } + ) + return { + "sequence": dna, + "organism": req.organism, + "tag": tag, + "tag_len": tag_len, + "bases": list(full), + "n_tokens": codes.shape[0], + "layer": engine.layer, + "features": feats, + } + + @app.post("/generate") + def generate(req: GenerateRequest): + _require_ready() + try: + return engine.generate( + prompt=req.prompt, + organism=req.organism, + tag=req.tag, + features=[f.model_dump() for f in req.features], + n_tokens=req.n_tokens, + temperature=req.temperature, + top_k=req.top_k, + compare_baseline=req.compare_baseline, + ) + except ValueError as e: + raise HTTPException(413 if "too long" in str(e) else 400, str(e)) + + return app diff --git a/interpretability/sparse_autoencoders/recipes/evo2/tests/conftest.py b/interpretability/sparse_autoencoders/recipes/evo2/tests/conftest.py index 95fbc2dd68..76119b48fc 100644 
--- a/interpretability/sparse_autoencoders/recipes/evo2/tests/conftest.py +++ b/interpretability/sparse_autoencoders/recipes/evo2/tests/conftest.py @@ -116,3 +116,78 @@ def sae_ckpt_path(tmp_path_factory) -> str: def embedding_layer() -> int: """Layer whose residual stream the SAE reads/steers (1B has 25 layers; default 19).""" return int(os.environ.get("EMBEDDING_LAYER", "19")) + + +# --------------------------------------------------------------------------------------------- +# CPU fixtures for the serve layer (#1637). `FakeEngine` is the one mock the CLI tests +# (test_cli.py) and the server contract tests (test_server.py) both drive — no model, CPU-only — +# so the two suites stay in lockstep on the engine surface the server/CLI touch. +# --------------------------------------------------------------------------------------------- +from evo2_sae import core # noqa: E402 + + +class FakeEngine: + """Minimal stand-in for Evo2SAE exposing only what the CLI + server touch.""" + + def __init__(self): + self.ready = True + self.layer = 19 + self.n_features = 4 + self.labels = {0: "feat0", 1: "feat1"} + self.peaks = {0: 0.5} + self.organism_tags = {"None (raw DNA)": "", "Human": "|tag|"} + self.device = "cpu" + self.sae_ckpt_path = "fake.pt" + self.max_seq_len = 8192 + self.gen_kwargs = None # records the last generate() call for CLI assertions + self.last_k = None # records the k top_features() was last called with + + def load(self): + self.ready = True + return self # the CLI uses `_engine(args).load()` + + def resolve_tag(self, organism, tag): + return tag if tag is not None else self.organism_tags.get(organism) + + def encode(self, full): + codes = torch.zeros(len(full), self.n_features) + codes[:, 0] = 1.0 # feature 0 fires everywhere + return codes + + def encode_batch(self, seqs, batch_size=8): + return [self.encode(s) for s in seqs] + + def top_features(self, codes, tag_len=0, k=8): + self.last_k = k # so tests can assert the server clamped k into [1, 64] + 
feats = [{"feature_id": i, "label": self.labels.get(i), "max_activation": 1.0} for i in range(self.n_features)] + return feats[:k] + + def generate(self, **kw): + self.gen_kwargs = kw + # Mirror the real engine: clamps normalize through the shared parser (a malformed --clamp + # raises), out-of-range ids are rejected (the wedge guard), and a seedless request is too. + specs = [core.parse_clamp_spec(f) for f in (kw.get("features") or [])] + bad = [s["feature_id"] for s in specs if not (0 <= s["feature_id"] < self.n_features)] + if bad: + raise ValueError(f"feature_id(s) {bad} out of range [0, {self.n_features})") + prompt = core.clean_dna(kw.get("prompt", "")) + if not prompt and kw.get("organism") == "None (raw DNA)" and not kw.get("tag"): + raise ValueError("need a seed") + resp = { + "prompt": prompt, + "organism": kw.get("organism"), + "tag": kw.get("tag"), + "steered": bool(specs), + "features": specs, + "generation": {"sequence": "ACGT", "activations": {0: [1.0, 1.0, 1.0, 1.0]}}, + "baseline": None, + } + if kw.get("compare_baseline") and specs: + resp["baseline"] = {"sequence": "TTTT", "activations": {}} + return resp + + +@pytest.fixture +def fake_engine(): + """A fresh CPU FakeEngine instance.""" + return FakeEngine() diff --git a/interpretability/sparse_autoencoders/recipes/evo2/tests/test_cli.py b/interpretability/sparse_autoencoders/recipes/evo2/tests/test_cli.py new file mode 100644 index 0000000000..6910e1518e --- /dev/null +++ b/interpretability/sparse_autoencoders/recipes/evo2/tests/test_cli.py @@ -0,0 +1,117 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""CLI contract tests — the `encode | batch | generate` subcommands, CPU-only. + +A mocked engine (no model) drives `cli.main()` so these run in CI alongside test_server.py. +They lock the stdout JSON shapes, the clean-exit error handling (no tracebacks), and the +shared clamp/annotate paths in `core` that the CLI and the server now both go through. +Real model inference is covered by test_steering.py. +""" + +import json + +import pandas as pd +import pytest +from evo2_sae import cli, core + + +# FakeEngine lives in conftest.py — shared with test_server.py so both suites mock the same surface. 
+ + +@pytest.fixture +def fake(monkeypatch, fake_engine): + """Patch the CLI's engine factory so every subcommand runs against the shared CPU fake.""" + monkeypatch.setattr(cli, "_engine", lambda args: fake_engine) + return fake_engine + + +def run(monkeypatch, *argv): + """Invoke `cli.main()` with the given argv (no leading prog name).""" + monkeypatch.setattr("sys.argv", ["evo2-sae", *argv]) + cli.main() + + +# ----------------------------------------------------------------------------- encode +def test_encode_outputs_top_features(fake, monkeypatch, capsys): + run(monkeypatch, "encode", "--sequence", "ACGTACGT") + out = json.loads(capsys.readouterr().out) + assert out["bases"] == 8 + assert out["top_features"][0]["feature_id"] == 0 + + +def test_encode_rejects_non_dna(fake, monkeypatch): + with pytest.raises(SystemExit): # clean exit, not a traceback + run(monkeypatch, "encode", "--sequence", "ZZZZ") + + +def test_encode_rejects_unknown_organism(fake, monkeypatch): + with pytest.raises(SystemExit): + run(monkeypatch, "encode", "--sequence", "ACGT", "--organism", "Martian") + + +# --------------------------------------------------------------------------- generate +def test_generate_outputs_sequence(fake, monkeypatch, capsys): + run(monkeypatch, "generate", "--prompt", "ACGT") + out = json.loads(capsys.readouterr().out) + assert out["sequence"] == "ACGT" + assert out["steered"] is False + + +def test_generate_passes_raw_clamp_strings(fake, monkeypatch, capsys): + run(monkeypatch, "generate", "--prompt", "ACGT", "--clamp", "0:5", "--clamp", "1") + # The CLI hands raw "ID[:STRENGTH]" strings straight to the engine; core normalizes them. 
+ assert fake.gen_kwargs["features"] == ["0:5", "1"] + out = json.loads(capsys.readouterr().out) + assert out["features"] == [{"feature_id": 0, "strength": 5.0}, {"feature_id": 1, "strength": 1.0}] + + +def test_generate_rejects_malformed_clamp(fake, monkeypatch): + with pytest.raises(SystemExit): + run(monkeypatch, "generate", "--prompt", "ACGT", "--clamp", "notanumber") + + +# ------------------------------------------------------------------------------ batch +def test_batch_writes_parquet(fake, monkeypatch, capsys, tmp_path): + fasta = tmp_path / "in.fa" + fasta.write_text(">a\nACGTACGT\n>b\nTTTT\n") + out = tmp_path / "out.parquet" + run(monkeypatch, "batch", "--fasta", str(fasta), "--out", str(out)) + df = pd.read_parquet(out) + assert set(df["sequence_id"]) == {"a", "b"} + assert {"sequence_id", "bp", "rank", "feature_id"} <= set(df.columns) + + +# ------------------------------------------------------- shared core helpers (CLI + API) +def test_parse_clamp_spec_string_and_dict(): + assert core.parse_clamp_spec("29244:300") == {"feature_id": 29244, "strength": 300.0} + assert core.parse_clamp_spec("7") == {"feature_id": 7, "strength": 1.0} # default strength + assert core.parse_clamp_spec({"feature_id": 7, "strength": 2.5}) == {"feature_id": 7, "strength": 2.5} + assert core.parse_clamp_spec({"feature_id": 7}) == {"feature_id": 7, "strength": 1.0} + + +def test_parse_clamp_spec_rejects_garbage(): + with pytest.raises(ValueError): + core.parse_clamp_spec("x:y") + + +def test_annotate_skips_tag_and_rejects_bad_input(fake_engine): + dna, tag, codes, tag_len = core.annotate(fake_engine, "ACGT", organism="Human") + assert dna == "ACGT" and tag == "|tag|" and tag_len == len("|tag|") + assert codes.shape == (len(tag) + len(dna), fake_engine.n_features) + with pytest.raises(ValueError): + core.annotate(fake_engine, "ZZZZ") # non-DNA + with pytest.raises(ValueError): + core.annotate(fake_engine, "ACGT", organism="Martian") # unknown organism diff --git 
a/interpretability/sparse_autoencoders/recipes/evo2/tests/test_server.py b/interpretability/sparse_autoencoders/recipes/evo2/tests/test_server.py new file mode 100644 index 0000000000..61340b5dfb --- /dev/null +++ b/interpretability/sparse_autoencoders/recipes/evo2/tests/test_server.py @@ -0,0 +1,117 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Server contract tests — the API the feature-explorer viz consumes. + +A mocked engine (no model, CPU-only) drives the FastAPI app so these run in CI and lock the +response shapes + error codes the dashboard depends on: /health, /features, /annotate (per-base +activations), /generate. Real model inference is covered by test_steering.py. +""" + +import pytest +from evo2_sae.server import build_app +from fastapi.testclient import TestClient + + +# FakeEngine lives in conftest.py — shared with test_cli.py so both suites mock the same surface. 
+ + +@pytest.fixture +def client(fake_engine): + with TestClient(build_app(fake_engine)) as c: + yield c + + +def test_health(client): + r = client.get("/health") + assert r.status_code == 200 # 200 only when ready + b = r.json() + assert b["ready"] is True and b["layer"] == 19 + assert "None (raw DNA)" in b["organisms"] + + +def test_annotate_rejects_too_long(client, fake_engine): + seq = "A" * (fake_engine.max_seq_len + 1) # exceeds the context budget + assert client.post("/annotate", json={"sequence": seq}).status_code == 413 + + +def test_features(client): + rows = client.get("/features").json() + assert {"id", "label", "natural_peak"} <= set(rows[0]) + + +def test_annotate_returns_per_base_activations(client): + b = client.post("/annotate", json={"sequence": "ACGTACGT", "organism": "None (raw DNA)"}).json() + assert {"sequence", "features", "bases", "tag_len", "layer", "n_tokens"} <= set(b) + assert b["features"][0]["activations"] # the per-base track the viz plots + + +def test_annotate_rejects_non_dna(client): + assert client.post("/annotate", json={"sequence": "ZZZZ"}).status_code == 400 + + +def test_annotate_pick_mode(client): + b = client.post("/annotate", json={"sequence": "ACGTACGT", "mode": "pick", "feature_ids": [1]}).json() + assert [f["feature_id"] for f in b["features"]] == [1] + assert b["features"][0]["activations"] # per-base track returned for the picked feature + + +def test_annotate_pick_requires_ids(client): + assert client.post("/annotate", json={"sequence": "ACGT", "mode": "pick"}).status_code == 400 + + +def test_annotate_rejects_invalid_mode(client): + assert client.post("/annotate", json={"sequence": "ACGT", "mode": "bogus"}).status_code == 400 + + +def test_annotate_clamps_k_into_range(client, fake_engine): + client.post("/annotate", json={"sequence": "ACGT", "k": 999}) + assert fake_engine.last_k == 64 # upper bound + client.post("/annotate", json={"sequence": "ACGT", "k": 0}) + assert fake_engine.last_k == 1 # lower bound + + +def 
test_generate_returns_sequence(client): + b = client.post("/generate", json={"prompt": "ACGT", "organism": "None (raw DNA)"}).json() + assert b["generation"]["sequence"] + + +def test_generate_rejects_out_of_range_feature(client): + r = client.post("/generate", json={"prompt": "ACGT", "features": [{"feature_id": 999}]}) + assert r.status_code == 400 # the wedge guard, surfaced to the client + + +def test_generate_compare_baseline(client): + b = client.post( + "/generate", + json={"prompt": "ACGT", "features": [{"feature_id": 0, "strength": 5.0}], "compare_baseline": True}, + ).json() + assert b["baseline"]["sequence"] # unsteered comparison returned alongside the steered one + + +def test_rejects_oversized_body(monkeypatch, fake_engine): + monkeypatch.setenv("MAX_BODY_BYTES", "100") + with TestClient(build_app(fake_engine)) as c: + assert c.post("/annotate", json={"sequence": "ACGT" * 100}).status_code == 413 + + +def test_endpoints_503_until_ready(fake_engine): + fake_engine.ready = False + fake_engine.load = lambda: None # startup leaves it not-ready + with TestClient(build_app(fake_engine)) as c: + assert c.get("/health").status_code == 503 # readiness probe sheds the pod + assert c.get("/features").status_code == 503 + assert c.post("/annotate", json={"sequence": "ACGT"}).status_code == 503 + assert c.post("/generate", json={"prompt": "ACGT", "organism": "None (raw DNA)"}).status_code == 503 diff --git a/interpretability/sparse_autoencoders/recipes/evo2/tests/test_steering.py b/interpretability/sparse_autoencoders/recipes/evo2/tests/test_steering.py index f18c8ad020..49ffffa474 100644 --- a/interpretability/sparse_autoencoders/recipes/evo2/tests/test_steering.py +++ b/interpretability/sparse_autoencoders/recipes/evo2/tests/test_steering.py @@ -28,7 +28,7 @@ import pytest import torch from evo2_sae import Evo2SAE -from evo2_sae.core import MAX_CLAMP_STRENGTH, _sanitize_steering +from evo2_sae.core import MAX_CLAMP_STRENGTH, _is_unrecoverable_cuda, 
_sanitize_steering # The delta-clamp math + decode-only/prefill behavior is covered against the production @@ -104,6 +104,13 @@ def test_sanitize_clamps_negative_top_k(): assert top_k == 0 +def test_is_unrecoverable_cuda(): + """Only sticky CUDA faults flip the engine not-ready; ordinary errors propagate normally.""" + assert _is_unrecoverable_cuda(RuntimeError("CUDA error: device-side assert triggered")) + assert not _is_unrecoverable_cuda(RuntimeError("shape mismatch")) + assert not _is_unrecoverable_cuda(ValueError("bad input")) + + # --------------------------------------------------------------------- GPU: real generation _PROMPT = "ATGGCCGAATTCGGCACGAGGACGTGCTGAAAGCTAGCTAGGCTAACCGGTTACGTGCAT" _ORG = "Human" From ac828a4e0954c808ea93b2146e57df5b8367c47a Mon Sep 17 00:00:00 2001 From: Polina Binder Date: Tue, 23 Jun 2026 05:37:00 +0000 Subject: [PATCH 2/4] =?UTF-8?q?fix(serve):=20address=20review=20=E2=80=94?= =?UTF-8?q?=20pick-id=20range=20check,=20real=20/generate=20413,=20int=20p?= =?UTF-8?q?ort,=20fake=20shape?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 1. /annotate pick mode now range-checks user-supplied feature_ids -> 400 (was: out-of-range IndexError -> 500, negative id silently indexed the wrong feature via torch negative-index). + test_annotate_pick_rejects_out_of_range_id. 2. core.generate rejects an over-context prompt ("too long" -> server 413), instead of letting tokenize() silently truncate it — makes the /generate 413 branch live and matches /annotate. + test_generate_rejects_overlong_prompt. 3. cli.py: int() the env-var defaults (PORT/EMBEDDING_LAYER/MAX_SEQ_LEN) — argparse type= only coerces command-line values, so `serve` was handing uvicorn a str port. 4. conftest FakeEngine.generate now returns features keyed {id, label, strength} (the real feat_meta shape the dashboard consumes), not {feature_id, strength}; test_cli updated so the contract test pins the real API shape. 5. 
Note body-size limit is advisory (Content-Length only; chunked/lying bypasses). 6. Note the CUDA-wedge guard depends on a readiness-based recycler (else 503 until manual restart). Validated in the evo2_megatron venv: CPU 40 passed (was 38), GPU unaffected. Co-Authored-By: Claude Opus 4.8 (1M context) Signed-off-by: Polina Binder --- .../recipes/evo2/src/evo2_sae/cli.py | 10 +++++++--- .../recipes/evo2/src/evo2_sae/core.py | 9 ++++++++- .../recipes/evo2/src/evo2_sae/server.py | 8 ++++++++ .../sparse_autoencoders/recipes/evo2/tests/conftest.py | 8 +++++++- .../sparse_autoencoders/recipes/evo2/tests/test_cli.py | 6 +++++- .../recipes/evo2/tests/test_core.py | 6 ++++++ .../recipes/evo2/tests/test_server.py | 9 +++++++++ 7 files changed, 50 insertions(+), 6 deletions(-) diff --git a/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/cli.py b/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/cli.py index 1a910d6cde..a164ec33ea 100644 --- a/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/cli.py +++ b/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/cli.py @@ -47,9 +47,11 @@ def _add_common(p: argparse.ArgumentParser) -> None: p.add_argument("--evo2-ckpt-dir", default=os.environ.get("EVO2_CKPT_DIR")) p.add_argument("--sae-ckpt-path", default=os.environ.get("SAE_CKPT_PATH")) p.add_argument("--feature-annotations", default=os.environ.get("FEATURE_ANNOTATIONS")) - p.add_argument("--layer", type=int, default=os.environ.get("EMBEDDING_LAYER", "26")) + # int() the env defaults explicitly: argparse's type= only coerces values passed on the command + # line, never the default — so an env-sourced (or absent) value would otherwise stay a str. 
+ p.add_argument("--layer", type=int, default=int(os.environ.get("EMBEDDING_LAYER", "26"))) p.add_argument("--device", default=os.environ.get("DEVICE", "cuda")) - p.add_argument("--max-seq-len", type=int, default=os.environ.get("MAX_SEQ_LEN", "8192")) + p.add_argument("--max-seq-len", type=int, default=int(os.environ.get("MAX_SEQ_LEN", "8192"))) def _engine(args): @@ -82,7 +84,9 @@ def main(): ps = sub.add_parser("serve", help="start the FastAPI inference server") _add_common(ps) ps.add_argument("--host", default="0.0.0.0") - ps.add_argument("--port", type=int, default=os.environ.get("PORT", "8001")) + ps.add_argument( + "--port", type=int, default=int(os.environ.get("PORT", "8001")) + ) # int: uvicorn.run needs an int port pe = sub.add_parser("encode", help="annotate ONE sequence -> top features (JSON)") _add_common(pe) diff --git a/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/core.py b/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/core.py index 0c0e4d571a..36f8106270 100644 --- a/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/core.py +++ b/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/core.py @@ -455,6 +455,10 @@ def generate( full_prompt = resolved_tag + dna if not full_prompt: raise ValueError("Provide a prompt or pick an organism (need >=1 token to seed)") + # Reject an over-context prompt rather than silently truncating it in tokenize() (parity + # with annotate; the server maps "too long" -> 413). Raise MAX_SEQ_LEN to allow longer. + if len(full_prompt) > self.max_seq_len: + raise ValueError(f"prompt too long: {len(full_prompt)} > max_seq_len ({self.max_seq_len})") # Cap to the engine's configured context budget (prompt + generation must fit max_seq_len), # not an arbitrary constant. Raise MAX_SEQ_LEN at launch to generate longer — the 7B is the # long-context (1M) model, so it's memory-bound, not architecture-bound (OOD past training len). 
@@ -492,7 +496,10 @@ def _run(steer: bool) -> str: except Exception as e: # A CUDA device-side assert poisons the context: every later request would 500 until # restart. Mark the engine not-ready so /health flips to 503 and an orchestrator - # recycles the pod, instead of serving a permanently-wedged worker. + # recycles the pod, instead of serving a permanently-wedged worker. NOTE: this only + # recovers if something restarts the worker on a failed /health — without a + # readiness-based recycler it serves 503 until a human restarts (fail-closed, but a + # single bad request can take the replica down). Confirm the deployment recycles on 503. if _is_unrecoverable_cuda(e): self.ready = False logger.exception("unrecoverable CUDA error in generate() — marking engine not-ready") diff --git a/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/server.py b/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/server.py index b1660494cc..4e8c9ef0fe 100644 --- a/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/server.py +++ b/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/server.py @@ -97,6 +97,9 @@ async def lifespan(app: FastAPI): # Reject oversized request bodies up front (a multi-MB sequence would be read into memory # before per-field validation could reject it). Default 16 MiB; override with MAX_BODY_BYTES. + # NOTE: advisory — this trusts the Content-Length header, so a chunked request (no length) or a + # lying header bypasses it; it guards well-behaved clients, not a hard cap. Fine behind SSO; + # real enforcement would count streamed bytes. 
max_body = int(os.getenv("MAX_BODY_BYTES", str(16 * 1024 * 1024))) @app.middleware("http") @@ -148,6 +151,11 @@ def annotate(req: AnnotateRequest): if not req.feature_ids: raise HTTPException(400, "mode='pick' requires feature_ids") chosen = [int(i) for i in req.feature_ids] + # Pick ids are user-supplied; an out-of-range id would IndexError (500) and a negative + # one would silently index the wrong feature via torch negative-indexing. Reject -> 400. + bad = sorted({i for i in chosen if not (0 <= i < engine.n_features)}) + if bad: + raise HTTPException(400, f"feature_id(s) {bad} out of range [0, {engine.n_features})") else: k = max(1, min(int(req.k), 64)) chosen = [ft["feature_id"] for ft in engine.top_features(codes, tag_len=tag_len, k=k)] diff --git a/interpretability/sparse_autoencoders/recipes/evo2/tests/conftest.py b/interpretability/sparse_autoencoders/recipes/evo2/tests/conftest.py index 76119b48fc..077bad41ee 100644 --- a/interpretability/sparse_autoencoders/recipes/evo2/tests/conftest.py +++ b/interpretability/sparse_autoencoders/recipes/evo2/tests/conftest.py @@ -173,12 +173,18 @@ def generate(self, **kw): prompt = core.clean_dna(kw.get("prompt", "")) if not prompt and kw.get("organism") == "None (raw DNA)" and not kw.get("tag"): raise ValueError("need a seed") + # Match the real engine's response: features are feat_meta dicts keyed {id, label, strength} + # (NOT feature_id) — this is the shape the dashboard consumes through /generate. 
+ feat_meta = [ + {"id": s["feature_id"], "label": self.labels.get(s["feature_id"]), "strength": s["strength"]} + for s in specs + ] resp = { "prompt": prompt, "organism": kw.get("organism"), "tag": kw.get("tag"), "steered": bool(specs), - "features": specs, + "features": feat_meta, "generation": {"sequence": "ACGT", "activations": {0: [1.0, 1.0, 1.0, 1.0]}}, "baseline": None, } diff --git a/interpretability/sparse_autoencoders/recipes/evo2/tests/test_cli.py b/interpretability/sparse_autoencoders/recipes/evo2/tests/test_cli.py index 6910e1518e..d541d3eeed 100644 --- a/interpretability/sparse_autoencoders/recipes/evo2/tests/test_cli.py +++ b/interpretability/sparse_autoencoders/recipes/evo2/tests/test_cli.py @@ -75,7 +75,11 @@ def test_generate_passes_raw_clamp_strings(fake, monkeypatch, capsys): # The CLI hands raw "ID[:STRENGTH]" strings straight to the engine; core normalizes them. assert fake.gen_kwargs["features"] == ["0:5", "1"] out = json.loads(capsys.readouterr().out) - assert out["features"] == [{"feature_id": 0, "strength": 5.0}, {"feature_id": 1, "strength": 1.0}] + # generate() echoes feat_meta keyed {id, label, strength} — the shape the viz consumes. + assert out["features"] == [ + {"id": 0, "label": "feat0", "strength": 5.0}, + {"id": 1, "label": "feat1", "strength": 1.0}, + ] def test_generate_rejects_malformed_clamp(fake, monkeypatch): diff --git a/interpretability/sparse_autoencoders/recipes/evo2/tests/test_core.py b/interpretability/sparse_autoencoders/recipes/evo2/tests/test_core.py index d39f293165..a353c97632 100644 --- a/interpretability/sparse_autoencoders/recipes/evo2/tests/test_core.py +++ b/interpretability/sparse_autoencoders/recipes/evo2/tests/test_core.py @@ -94,3 +94,9 @@ def test_generate_rejects_empty_prompt(): # "None (raw DNA)" has an empty tag, so an empty prompt leaves nothing to seed generation. 
with pytest.raises(ValueError): _engine().generate(prompt="", organism="None (raw DNA)") + + +def test_generate_rejects_overlong_prompt(): + # An over-context prompt is rejected (server -> 413), not silently truncated by tokenize(). + with pytest.raises(ValueError, match="too long"): + _engine(max_seq_len=16).generate(prompt="A" * 32, organism="None (raw DNA)") diff --git a/interpretability/sparse_autoencoders/recipes/evo2/tests/test_server.py b/interpretability/sparse_autoencoders/recipes/evo2/tests/test_server.py index 61340b5dfb..a3309c63e9 100644 --- a/interpretability/sparse_autoencoders/recipes/evo2/tests/test_server.py +++ b/interpretability/sparse_autoencoders/recipes/evo2/tests/test_server.py @@ -72,6 +72,15 @@ def test_annotate_pick_requires_ids(client): assert client.post("/annotate", json={"sequence": "ACGT", "mode": "pick"}).status_code == 400 +def test_annotate_pick_rejects_out_of_range_id(client, fake_engine): + # user-supplied pick ids: an over-range id must 400 (not 500/IndexError), a negative one must + # 400 (not silently index the wrong feature via torch negative-indexing). + over = client.post("/annotate", json={"sequence": "ACGT", "mode": "pick", "feature_ids": [fake_engine.n_features]}) + assert over.status_code == 400 + neg = client.post("/annotate", json={"sequence": "ACGT", "mode": "pick", "feature_ids": [-1]}) + assert neg.status_code == 400 + + def test_annotate_rejects_invalid_mode(client): assert client.post("/annotate", json={"sequence": "ACGT", "mode": "bogus"}).status_code == 400 From ea3c316e9971350d8cacaeddba99c1f7292ffc4f Mon Sep 17 00:00:00 2001 From: Polina Binder Date: Tue, 23 Jun 2026 05:42:11 +0000 Subject: [PATCH 3/4] feat(serve): host-independent recovery from a CUDA wedge (exit + restart loop) A device-side assert poisons the process's CUDA context (unrecoverable in-process), so ready=False alone only recovers under a readiness-based recycler. 
Add restart-on-exit recovery, which almost every host provides: - core.generate: on an unrecoverable CUDA fault, if EXIT_ON_CUDA_WEDGE=1, os._exit(1) the worker (after ready=False). Default unset -> just fail-closed at 503 (safe for library/CLI/test use). - launch_inference.sh: for `serve`, export EXIT_ON_CUDA_WEDGE=1 and wrap in a restart loop (respawn on crash/wedge exit; stop on clean exit / Ctrl-C 130 / SIGTERM 143). Recovery now works with no external orchestrator (and composes with docker --restart / systemd / k8s). Co-Authored-By: Claude Opus 4.8 (1M context) Signed-off-by: Polina Binder --- .../recipes/evo2/scripts/launch_inference.sh | 17 +++++++++++++++++ .../recipes/evo2/src/evo2_sae/core.py | 16 ++++++++++------ 2 files changed, 27 insertions(+), 6 deletions(-) diff --git a/interpretability/sparse_autoencoders/recipes/evo2/scripts/launch_inference.sh b/interpretability/sparse_autoencoders/recipes/evo2/scripts/launch_inference.sh index 26768a4c46..e01083c8f2 100755 --- a/interpretability/sparse_autoencoders/recipes/evo2/scripts/launch_inference.sh +++ b/interpretability/sparse_autoencoders/recipes/evo2/scripts/launch_inference.sh @@ -34,4 +34,21 @@ fi source "$VENV/bin/activate" cd "$RECIPE_DIR" export PYTHONPATH="$RECIPE_DIR/src${PYTHONPATH:+:$PYTHONPATH}" + +# One-shot modes (encode/batch/generate) just run once. For `serve`, a bad request can trip a +# CUDA device-side assert that poisons the process context — unrecoverable in-process, restart is +# the only cure. So tell the engine to exit the worker on that fault (EXIT_ON_CUDA_WEDGE) and +# respawn it here, making recovery independent of the host (no k8s/liveness probe required). +# Restart on a crash/wedge exit; stop on a clean exit (0) or a signal (Ctrl-C 130 / SIGTERM 143). +if [[ "${1:-}" == "serve" ]]; then + export EXIT_ON_CUDA_WEDGE=1 + while true; do + python -m evo2_sae.cli "$@" && rc=0 || rc=$? 
+ if [[ "$rc" -eq 0 || "$rc" -eq 130 || "$rc" -eq 143 ]]; then + exit "$rc" # clean shutdown / Ctrl-C / SIGTERM — don't respawn + fi + echo "[launch] serve exited ($rc); restarting in 2s…" >&2 + sleep 2 + done +fi exec python -m evo2_sae.cli "$@" diff --git a/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/core.py b/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/core.py index 36f8106270..b355b33403 100644 --- a/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/core.py +++ b/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/core.py @@ -494,15 +494,19 @@ def _run(steer: bool) -> str: main_dna = _run(steer=True) base_dna = _run(steer=False) if (compare_baseline and clamps) else None except Exception as e: - # A CUDA device-side assert poisons the context: every later request would 500 until - # restart. Mark the engine not-ready so /health flips to 503 and an orchestrator - # recycles the pod, instead of serving a permanently-wedged worker. NOTE: this only - # recovers if something restarts the worker on a failed /health — without a - # readiness-based recycler it serves 503 until a human restarts (fail-closed, but a - # single bad request can take the replica down). Confirm the deployment recycles on 503. + # A CUDA device-side assert poisons the context for the whole process — unrecoverable + # in-process (re-init won't clear it), so restart is the only cure. Mark not-ready so + # /health flips to 503; and if EXIT_ON_CUDA_WEDGE is set (launch_inference.sh sets it + # for `serve`, where it runs under a restart loop), exit the worker so ANY restart-on- + # exit supervisor (the launch loop / docker --restart / systemd / k8s) respawns it — + # host-independent recovery, not dependent on a readiness probe. Default (unset): just + # fail-closed at 503 (safe for library/CLI/test use, which must not kill the process). 
if _is_unrecoverable_cuda(e): self.ready = False logger.exception("unrecoverable CUDA error in generate() — marking engine not-ready") + if os.environ.get("EXIT_ON_CUDA_WEDGE") == "1": + logger.critical("EXIT_ON_CUDA_WEDGE=1 — exiting the worker for the supervisor to restart") + os._exit(1) raise resp = { From a819d98437d3e7284b20684ad55baeaba9409a10 Mon Sep 17 00:00:00 2001 From: Polina Binder Date: Tue, 23 Jun 2026 06:00:45 +0000 Subject: [PATCH 4/4] fix(serve): venv-agnostic launch, signal-safe restart loop, /generate 413 test MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - launch_inference.sh: stop managing the venv — assume it's already active (Docker: on PATH; bare metal: source the evo2_megatron .venv first, like the tests). Drops the messy VENV= passing; adds a clear "bionemo.evo2 not importable" preflight. - Restart loop signal fix (was a graceful-shutdown regression): run the worker in the background and `wait`, with a trap that forwards SIGTERM/SIGINT to it (uvicorn graceful shutdown) and stops the loop — so `docker stop`/k8s on PID 1 no longer orphans the worker. Adds a 10-restart cap + backoff so a persistent crash (e.g. port already bound) doesn't loop forever. Smoke-tested: SIGTERM stops in ~1s, not the worker's full lifetime. - /generate 413 now pinned at the server layer: FakeEngine raises "too long" past max_seq_len and test_generate_rejects_too_long drives POST /generate -> 413 (was only covered via test_core). - Reframe the CUDA-wedge comment: it's PURELY DEFENSIVE — _sanitize_steering neutralizes every client-reachable assert trigger, so a wedge implies a hardware/driver fault, not a crafted request (exit+restart is not a remote DoS). New triggers must extend _sanitize_steering. 
Co-Authored-By: Claude Opus 4.8 (1M context) Signed-off-by: Polina Binder --- .../recipes/evo2/scripts/launch_inference.sh | 60 +++++++++++++------ .../recipes/evo2/src/evo2_sae/core.py | 21 ++++--- .../recipes/evo2/tests/conftest.py | 2 + .../recipes/evo2/tests/test_server.py | 5 ++ 4 files changed, 62 insertions(+), 26 deletions(-) diff --git a/interpretability/sparse_autoencoders/recipes/evo2/scripts/launch_inference.sh b/interpretability/sparse_autoencoders/recipes/evo2/scripts/launch_inference.sh index e01083c8f2..dc9476b41b 100755 --- a/interpretability/sparse_autoencoders/recipes/evo2/scripts/launch_inference.sh +++ b/interpretability/sparse_autoencoders/recipes/evo2/scripts/launch_inference.sh @@ -12,43 +12,65 @@ # Config via env. Required: EVO2_CKPT_DIR, SAE_CKPT_PATH. Optional (have defaults): # FEATURE_ANNOTATIONS, EMBEDDING_LAYER (26), DEVICE, PORT, CUDA_VISIBLE_DEVICES. # -# Requires the evo2_megatron recipe venv (provides bionemo.evo2 + megatron). +# Run INSIDE the evo2_megatron venv (it provides bionemo.evo2 + megatron) — same venv the tests +# use. In the Docker image it's already active (on PATH), so just run this. On bare metal, activate +# it first: source /.venv/bin/activate set -euo pipefail HERE="$(cd "$(dirname "$0")" && pwd)" RECIPE_DIR="$(cd "$HERE/.." && pwd)" # recipes/evo2 — so the evo2_sae package imports # Required (no hardcoded defaults — supply your own paths via env): -VENV="${VENV:?Set VENV to the evo2_megatron recipe .venv (provides bionemo.evo2 + megatron)}" export EVO2_CKPT_DIR="${EVO2_CKPT_DIR:?Set EVO2_CKPT_DIR to an Evo2 MBridge checkpoint directory}" export SAE_CKPT_PATH="${SAE_CKPT_PATH:?Set SAE_CKPT_PATH to a trained SAE checkpoint (.pt)}" # Optional: feature-label parquet (empty = features are unlabeled). Layer defaults to 26. export FEATURE_ANNOTATIONS="${FEATURE_ANNOTATIONS:-}" export EMBEDDING_LAYER="${EMBEDDING_LAYER:-26}" -if [[ ! 
-x "$VENV/bin/python" ]]; then - echo "ERROR: evo2_megatron venv not found at $VENV (build it with the recipe's .ci_build.sh)" >&2 +export PYTHONPATH="$RECIPE_DIR/src${PYTHONPATH:+:$PYTHONPATH}" # find evo2_sae without a pip install + +# We assume the evo2_megatron venv is already active (Docker: on PATH; bare metal: you sourced it). +# Fail with a clear message rather than a deep ImportError if it isn't. +if ! python -c "import bionemo.evo2" 2>/dev/null; then + echo "ERROR: bionemo.evo2 not importable — activate the evo2_megatron venv first" >&2 + echo " (source /.venv/bin/activate) or run inside the Docker image." >&2 exit 1 fi -source "$VENV/bin/activate" -cd "$RECIPE_DIR" -export PYTHONPATH="$RECIPE_DIR/src${PYTHONPATH:+:$PYTHONPATH}" - -# One-shot modes (encode/batch/generate) just run once. For `serve`, a bad request can trip a -# CUDA device-side assert that poisons the process context — unrecoverable in-process, restart is -# the only cure. So tell the engine to exit the worker on that fault (EXIT_ON_CUDA_WEDGE) and -# respawn it here, making recovery independent of the host (no k8s/liveness probe required). -# Restart on a crash/wedge exit; stop on a clean exit (0) or a signal (Ctrl-C 130 / SIGTERM 143). +# One-shot modes (encode/batch/generate) just run once -> exec, signals pass straight through. +# For `serve`, a hardware/driver fault can trip a CUDA device-side assert that poisons the context +# (unrecoverable in-process — restart is the only cure). So tell the engine to exit the worker on +# that fault (EXIT_ON_CUDA_WEDGE) and respawn it here, making recovery host-independent. +# +# The worker runs in the BACKGROUND and we `wait` on it so this supervisor stays responsive to +# signals: a trap forwards SIGTERM/SIGINT to the worker (triggering uvicorn's graceful shutdown) +# and stops the respawn loop — important when this script is PID 1 under `docker stop`/k8s, where +# bash would otherwise not forward the signal and orphan the worker.
Crash/wedge -> respawn with +# backoff, capped so a persistent failure (e.g. port already bound) doesn't loop forever. if [[ "${1:-}" == "serve" ]]; then export EXIT_ON_CUDA_WEDGE=1 - while true; do - python -m evo2_sae.cli "$@" && rc=0 || rc=$? - if [[ "$rc" -eq 0 || "$rc" -eq 130 || "$rc" -eq 143 ]]; then - exit "$rc" # clean shutdown / Ctrl-C / SIGTERM — don't respawn + child="" + stop=0 + trap 'stop=1; [[ -n "$child" ]] && kill -TERM "$child" 2>/dev/null || true' TERM INT + rc=0 + fails=0 + while [[ "$stop" -eq 0 ]]; do + python -m evo2_sae.cli "$@" & + child=$! + wait "$child" && rc=0 || rc=$? + child="" + # stop on a clean exit or when a signal asked us to (143 SIGTERM / 130 SIGINT belt-and-suspenders). + [[ "$stop" -eq 1 || "$rc" -eq 0 || "$rc" -eq 143 || "$rc" -eq 130 ]] && break + fails=$((fails + 1)) + if [[ "$fails" -ge 10 ]]; then + echo "[launch] serve exited ($rc) $fails times — giving up; fix the cause." >&2 + break fi - echo "[launch] serve exited ($rc); restarting in 2s…" >&2 - sleep 2 + backoff=$([[ "$fails" -lt 5 ]] && echo 2 || echo 10) + echo "[launch] serve exited ($rc); restart $fails/10 in ${backoff}s…" >&2 + sleep "$backoff" & + wait $! 2>/dev/null || true # interruptible: a signal during backoff stops the loop promptly done + exit "$rc" fi exec python -m evo2_sae.cli "$@" diff --git a/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/core.py b/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/core.py index b355b33403..f69af993a3 100644 --- a/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/core.py +++ b/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/core.py @@ -494,13 +494,20 @@ def _run(steer: bool) -> str: main_dna = _run(steer=True) base_dna = _run(steer=False) if (compare_baseline and clamps) else None except Exception as e: - # A CUDA device-side assert poisons the context for the whole process — unrecoverable - # in-process (re-init won't clear it), so restart is the only cure. 
Mark not-ready so - # /health flips to 503; and if EXIT_ON_CUDA_WEDGE is set (launch_inference.sh sets it - # for `serve`, where it runs under a restart loop), exit the worker so ANY restart-on- - # exit supervisor (the launch loop / docker --restart / systemd / k8s) respawns it — - # host-independent recovery, not dependent on a readiness probe. Default (unset): just - # fail-closed at 503 (safe for library/CLI/test use, which must not kill the process). + # PURELY DEFENSIVE: the known client-reachable CUDA-assert triggers are all neutralized + # earlier by _sanitize_steering (out-of-range id, magnitude cap, non-finite strength, + # temperature<=0, negative top_k), so a wedge here implies a genuine hardware/driver + # fault, NOT a crafted request — it is not client-inducible (so exit+restart is not a + # remote DoS vector). If that ever stops holding, a new trigger must be added to + # _sanitize_steering, not handled by leaning on this path. + # + # When it does happen, the device-side assert poisons the CUDA context for the whole + # process — unrecoverable in-process (re-init won't clear it), so restart is the only + # cure. Mark not-ready so /health flips to 503; and if EXIT_ON_CUDA_WEDGE is set + # (launch_inference.sh sets it for `serve`, which runs under a restart loop), exit the + # worker so ANY restart-on-exit supervisor (the launch loop / docker --restart / + # systemd / k8s) respawns it — host-independent recovery, no readiness probe required. + # Default (unset): fail-closed at 503 (safe for library/CLI/test use — must not exit). 
if _is_unrecoverable_cuda(e): self.ready = False logger.exception("unrecoverable CUDA error in generate() — marking engine not-ready") diff --git a/interpretability/sparse_autoencoders/recipes/evo2/tests/conftest.py b/interpretability/sparse_autoencoders/recipes/evo2/tests/conftest.py index 077bad41ee..da2cfdaad3 100644 --- a/interpretability/sparse_autoencoders/recipes/evo2/tests/conftest.py +++ b/interpretability/sparse_autoencoders/recipes/evo2/tests/conftest.py @@ -173,6 +173,8 @@ def generate(self, **kw): prompt = core.clean_dna(kw.get("prompt", "")) if not prompt and kw.get("organism") == "None (raw DNA)" and not kw.get("tag"): raise ValueError("need a seed") + if len(prompt) > self.max_seq_len: # mirror the real over-context reject (server -> 413) + raise ValueError(f"prompt too long: {len(prompt)} > max_seq_len ({self.max_seq_len})") # Match the real engine's response: features are feat_meta dicts keyed {id, label, strength} # (NOT feature_id) — this is the shape the dashboard consumes through /generate. feat_meta = [ diff --git a/interpretability/sparse_autoencoders/recipes/evo2/tests/test_server.py b/interpretability/sparse_autoencoders/recipes/evo2/tests/test_server.py index a3309c63e9..57d74bb3c7 100644 --- a/interpretability/sparse_autoencoders/recipes/evo2/tests/test_server.py +++ b/interpretability/sparse_autoencoders/recipes/evo2/tests/test_server.py @@ -102,6 +102,11 @@ def test_generate_rejects_out_of_range_feature(client): assert r.status_code == 400 # the wedge guard, surfaced to the client +def test_generate_rejects_too_long(client, fake_engine): + seq = "A" * (fake_engine.max_seq_len + 1) # exceeds the context budget -> 413 (parity w/ annotate) + assert client.post("/generate", json={"prompt": seq}).status_code == 413 + + def test_generate_compare_baseline(client): b = client.post( "/generate",