From cc98427e7c19490b4024b41db5ed9838d2c9a940 Mon Sep 17 00:00:00 2001 From: Polina Binder Date: Wed, 10 Jun 2026 23:12:48 +0000 Subject: [PATCH 1/2] sae: causal feature-steering primitive (delta-clamp hook) Add sae/steering.py: clamp_hook(sae, {feature_idx: value}) + a steer() context manager. A forward hook on the SAE's layer that re-encodes the activation, overrides features in code-space, and adds decode(clamped) - decode(original) back -- the delta approach, so the SAE reconstruction error cancels and only the clamped feature's decoder contribution moves the activation (value=0 ablates; negative reverses). Model-agnostic (needs only encode_pre_act/decode/top_k); measure the effect by running the model with vs without the hook. Ported from the CodonFM steering notebooks. CPU test in sae/tests/test_steering.py. Co-Authored-By: Claude Opus 4.8 Signed-off-by: Polina Binder --- .../sae/src/sae/steering.py | 77 +++++++++++++++++++ .../sae/tests/test_steering.py | 68 ++++++++++++++++ 2 files changed, 145 insertions(+) create mode 100644 bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/steering.py create mode 100644 bionemo-recipes/interpretability/sparse_autoencoders/sae/tests/test_steering.py diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/steering.py b/bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/steering.py new file mode 100644 index 0000000000..c061e38533 --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/steering.py @@ -0,0 +1,77 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Causal feature steering for SAEs — clamp features in code-space, inject only the delta.

A forward hook on the layer the SAE was trained on: it re-encodes the layer output through
the SAE, overrides chosen features in code-space, decodes, and adds the **delta** back to the
activation. Because we add ``decode(clamped) - decode(original)`` (not the recon itself), the
SAE's reconstruction error cancels and only the clamped feature's decoder contribution moves
the activation. Model-agnostic: needs only the SAE (``encode_pre_act`` / ``decode`` / ``top_k``)
and the module to hook. Measure the effect (e.g. ΔP of a target token) by running the model
with vs. without the hook.
"""

from contextlib import contextmanager
from typing import Dict

import torch


def clamp_hook(sae, clamps: Dict[int, float]):
    """Build a forward hook that clamps ``{feature_idx: value}`` via the delta method.

    The hook adds ``decode(clamped_codes) - decode(original_codes)`` to the hooked module's
    output, so the SAE reconstruction error cancels. ``value=0`` ablates a feature; a negative
    value reverses its decoder direction. Works whether the module returns a tensor or a tuple
    whose first element is the hidden state.

    Args:
        sae: A trained SAE exposing ``encode_pre_act(x) -> (pre_act, info)``, ``decode(codes, info)``,
            and ``top_k``.
        clamps: Map of feature index -> absolute code value to force at every position. Keys are
            coerced with ``int`` and values with ``float``. An empty map yields a hook that
            returns the module output untouched.

    Returns:
        A ``register_forward_hook``-compatible ``hook(module, inputs, output)``.

    Raises:
        ValueError: If any feature index is negative. A negative index would otherwise wrap
            around and silently clamp a different feature.
        IndexError: Raised by the returned hook (not by this builder) when a feature index is
            not smaller than the width of the SAE's code vector.
    """
    items = [(int(f), float(v)) for f, v in clamps.items()]
    negative = sorted(f for f, _ in items if f < 0)
    if negative:
        raise ValueError(f"feature indices must be non-negative, got {negative}")
    # Largest requested index; compared against the code width once it is known (inside the hook).
    max_feature = max((f for f, _ in items), default=-1)

    def hook(module, inputs, output):
        if not items:
            # Nothing to clamp: skip the encode/top-k/decode work and leave the output as-is.
            return output

        is_tuple = isinstance(output, tuple)
        h = output[0] if is_tuple else output
        # Flatten every leading dimension into rows and run the SAE in float32.
        h_flat = h.reshape(-1, h.shape[-1]).float()
        with torch.no_grad():
            pre_act, info = sae.encode_pre_act(h_flat)
            codes = torch.relu(pre_act)
            n_features = codes.shape[-1]
            if max_feature >= n_features:
                raise IndexError(
                    f"clamp feature index {max_feature} is out of range for an SAE with {n_features} features"
                )
            # Keep only the top-k activations per row; every other code is zero.
            kvals, kidx = torch.topk(codes, sae.top_k, dim=-1)
            codes_orig = torch.zeros_like(codes).scatter(-1, kidx, kvals)
            codes_clamped = codes_orig.clone()
            for f, v in items:
                codes_clamped[:, f] = v
            # Difference of two decodes: whatever both reconstructions share cancels out.
            delta = sae.decode(codes_clamped, info) - sae.decode(codes_orig, info)
        # Add the delta to the (float32) activation, then restore the original dtype and shape.
        h_out = (h_flat + delta).to(h.dtype).reshape(h.shape)
        return (h_out, *output[1:]) if is_tuple else h_out

    return hook


@contextmanager
def steer(module, sae, clamps: Dict[int, float]):
    """Register the clamp hook on ``module`` for the duration of the ``with`` block, then remove it.

    Args:
        module: Object exposing ``register_forward_hook`` (e.g. a ``torch.nn.Module``).
        sae: SAE passed through to :func:`clamp_hook`.
        clamps: Map of feature index -> code value passed through to :func:`clamp_hook`.

    Yields:
        The handle returned by ``module.register_forward_hook``. Callers that ignore it
        (``with steer(...):``) are unaffected; the hook is removed on exit even if the body raises.
    """
    handle = module.register_forward_hook(clamp_hook(sae, clamps))
    try:
        yield handle
    finally:
        handle.remove()


# ---------------------------------------------------------------------------
# Patch-series text that shares this physical line with the module above in the
# collapsed source (header and first license lines of the next file), kept verbatim:
#
# diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/sae/tests/test_steering.py b/bionemo-recipes/interpretability/sparse_autoencoders/sae/tests/test_steering.py
# new file mode 100644
# index 0000000000..a3ac2c5d3e
# --- /dev/null
# +++ b/bionemo-recipes/interpretability/sparse_autoencoders/sae/tests/test_steering.py
# @@ -0,0 +1,68 @@
#
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-Apache2
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""CPU tests for sae.steering: the delta-clamp hook adds exactly decode(clamped) - decode(orig)."""

import torch
from sae.architectures import TopKSAE
from sae.steering import clamp_hook, steer
from torch import nn


def _sae():
    """Seed the torch RNG with 0, then build the small TopK SAE shared by every test."""
    torch.manual_seed(0)
    model = TopKSAE(input_dim=8, hidden_dim=16, top_k=4, normalize_input=False)
    return model


def test_no_clamp_is_a_noop():
    """Steering with an empty clamp map must hand back the input activation."""
    sae = _sae()
    identity = nn.Identity()
    acts = torch.randn(5, 8)
    with steer(identity, sae, {}):
        steered = identity(acts)
    assert torch.allclose(steered, acts, atol=1e-5)


def test_clamp_adds_decoder_delta():
    """Forcing feature 3 to 5.0 moves the activation by decode(clamped) - decode(orig)."""
    sae = _sae()
    identity = nn.Identity()
    acts = torch.randn(5, 8)

    # Recompute by hand what the hook is expected to add.
    with torch.no_grad():
        pre_act, info = sae.encode_pre_act(acts.float())
        relu_codes = torch.relu(pre_act)
        top_vals, top_idx = torch.topk(relu_codes, sae.top_k, dim=-1)
        sparse_codes = torch.zeros_like(relu_codes).scatter(-1, top_idx, top_vals)
        forced_codes = sparse_codes.clone()
        forced_codes[:, 3] = 5.0
        shift = sae.decode(forced_codes, info) - sae.decode(sparse_codes, info)
        expected = acts + shift

    with steer(identity, sae, {3: 5.0}):
        steered = identity(acts)
    assert torch.allclose(steered, expected, atol=1e-4)


def test_tuple_output_first_element_steered_rest_preserved():
    """For a tuple-returning module only element 0 is steered; the rest passes through."""

    class TupleModule(nn.Module):
        def forward(self, x):
            return (x, "meta")

    sae = _sae()
    acts = torch.randn(3, 8)
    module = TupleModule()

    # Register the raw hook directly (not via ``steer``) and remove it after one forward.
    handle = module.register_forward_hook(clamp_hook(sae, {0: 2.0}))
    result = module(acts)
    handle.remove()

    assert isinstance(result, tuple)
    hidden, meta = result[0], result[1]
    assert meta == "meta"
    assert hidden.shape == acts.shape
    assert not torch.allclose(hidden, acts)  # the clamp moved it


# ---------------------------------------------------------------------------
# Patch-series text that shares this physical line with the test module above in the
# collapsed source (mail header of the second patch), kept verbatim:
#
# From 3583c00296fb5b8451950c0c000cada42499aab8 Mon Sep 17 00:00:00 2001
# From: Polina Binder
# Date: Wed, 10 Jun 2026 23:56:07 +0000
# Subject: [PATCH 2/2] evo2: SAE steering harness (steer.py) on sae.steering
#
# Add
recipes/evo2/scripts/steer.py: encode a sequence -> its active features, then clamp a target feature across a strength sweep (dose-response) and apply the same clamp to control features (selectivity), comparing each steered continuation to baseline. Uses sae.steering.clamp_hook registered on the Evo2 decoder layer; loads the model/SAE via Evo2SAE (lazy import). GPU harness (run on H100), not a CPU test. Note: clamp_hook steers prefill+decode; the decode-only variant is Evo2SAE._clamp_hook, unifying them is a follow-up. Co-Authored-By: Claude Opus 4.8 Signed-off-by: Polina Binder --- .../recipes/evo2/scripts/steer.py | 121 ++++++++++++++++++ 1 file changed, 121 insertions(+) create mode 100644 bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/steer.py diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/steer.py b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/steer.py new file mode 100644 index 0000000000..5cbf20a8e0 --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/steer.py @@ -0,0 +1,121 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r"""Evo2 SAE steering harness — clamp features and measure the causal effect on generation. 
# (Module docstring, continued. Its opening line sits on the preceding physical line of the patch.)
#
# Uses ``sae.steering.clamp_hook`` (the shared delta-clamp) registered on the Evo2 decoder layer
# the SAE was trained on. Workflow: encode a sequence to find its active features, then for a
# **target** feature sweep the clamp strength (dose-response) and for **control** features apply
# the same clamp (selectivity), each time comparing the steered continuation to the baseline.
#
# GPU harness — run on an H100 with the inference engine available; this is not a CPU unit test.
#
#     python steer.py --evo2-ckpt-dir --sae-checkpoint --layer 26 \
#         --sequence ATGGCC... --feature 29244 --controls 12345,54321 --strengths 0,50,100,200
#
# Note: ``sae.steering.clamp_hook`` clamps on *every* forward (prefill + decode), so it steers
# the prompt as well as the continuation. The decode-only ("continuation-only") variant lives in
# ``evo2_sae_infer.core.Evo2SAE._clamp_hook``; unifying the two onto ``sae.steering`` (with a
# ``decode_only`` flag) is a planned follow-up.

from __future__ import annotations

import argparse
import sys
from contextlib import nullcontext
from pathlib import Path


_HERE = Path(__file__).resolve().parent
sys.path.insert(0, str(_HERE))
sys.path.insert(0, str(_HERE.parent))
# Three levels up, then sae/src. Chained ``.parent`` stops at the filesystem root, whereas
# ``parents[2]`` raises IndexError when the script sits in a shallow directory.
sys.path.insert(0, str(_HERE.parent.parent.parent / "sae" / "src"))


def _divergence(a: str, b: str) -> tuple[int, float]:
    """Compare two strings position by position over their shared length.

    Args:
        a: First string.
        b: Second string.

    Returns:
        ``(first, fraction)``: ``first`` is the index of the first differing character, or the
        shared length ``min(len(a), len(b))`` when none differs; ``fraction`` is the number of
        differing positions divided by the shared length (``0.0`` when either string is empty).
        Characters beyond the shorter string are ignored.
    """
    n = min(len(a), len(b))
    mismatches = [i for i, (x, y) in enumerate(zip(a, b)) if x != y]
    first = mismatches[0] if mismatches else n
    return first, len(mismatches) / max(1, n)


def _parse_csv(text: str, cast):
    """Split ``text`` on commas, drop blank entries, and convert each remaining entry with ``cast``."""
    return [cast(part) for part in text.split(",") if part.strip()]


def main(argv=None):
    """Encode a sequence, then steer a target feature (dose-response) + control features (selectivity).

    Args:
        argv: Optional argument list handed to ``argparse``; ``None`` (the default) makes
            ``argparse`` read ``sys.argv[1:]``.

    Raises:
        SystemExit: On malformed or empty ``--strengths`` / malformed ``--controls`` (argparse
            usage error, raised before any model is loaded), or when no ``--feature`` was given
            and none of the printed top features has a label.
    """
    p = argparse.ArgumentParser(description="Evo2 SAE steering harness (clamp -> continuation effect).")
    p.add_argument("--evo2-ckpt-dir", required=True)
    p.add_argument("--sae-checkpoint", required=True)
    p.add_argument("--layer", type=int, required=True)
    p.add_argument("--sequence", required=True)
    p.add_argument("--organism", default="None (raw DNA)")
    p.add_argument("--feature", type=int, default=None, help="Target feature id (default: top labeled feature).")
    p.add_argument("--controls", default="", help="Comma-separated control feature ids (selectivity).")
    p.add_argument("--strengths", default="0,50,100,200", help="Comma-separated clamp strengths to sweep.")
    p.add_argument("--n-tokens", type=int, default=60)
    p.add_argument("--device", default="cuda")
    a = p.parse_args(argv)

    # Validate the list arguments before the imports and model load below, so a typo fails fast.
    try:
        controls = _parse_csv(a.controls, int)
        strengths = _parse_csv(a.strengths, float)
    except ValueError as err:
        p.error(f"--controls/--strengths must be comma-separated numbers ({err})")
    if not strengths:
        p.error("--strengths must contain at least one value")

    # Lazy imports: kept out of module import so ``--help`` and argument errors stay fast.
    from sae.steering import steer  # noqa: E402, RUF100

    from bionemo.evo2.run import infer as INF  # noqa: E402, I001, RUF100
    from evo2_sae_infer.core import Evo2SAE, clean_dna  # noqa: E402, RUF100
    from megatron.core.utils import unwrap_model  # noqa: E402, RUF100

    eng = Evo2SAE(a.evo2_ckpt_dir, a.sae_checkpoint, a.layer, device=a.device).load()

    # 1. Encode -> the sequence's most-active features (pick a target if not given).
    codes = eng.encode(a.sequence)
    vals, ids = codes.max(0).values.topk(10)
    print(f"top features on {a.sequence[:24]}...:")
    target = a.feature
    for v, i in zip(vals.tolist(), ids.tolist()):
        lab = eng.labels.get(int(i))
        print(f" feat {int(i):6d} {str(lab):18s} max_act {v:7.2f}")
        if target is None and lab:
            target = int(i)
    if target is None:
        # Without this guard the sweep below would hand ``None`` to the clamp hook as a feature id.
        raise SystemExit("no --feature given and none of the top features has a label; pass --feature explicitly")

    # 2. The Evo2 decoder layer the SAE hooks + a clean (tag + DNA) prompt.
    comp = eng._ensure_engine()
    prompt = (eng.resolve_tag(a.organism, None) or "") + clean_dna(a.sequence)
    layer_mod = unwrap_model(comp.model).decoder.layers[a.layer]

    def gen(clamps):
        # An empty clamp map means "baseline": generate without registering any hook.
        ctx = steer(layer_mod, eng.sae, clamps) if clamps else nullcontext()
        with ctx:
            out = INF.generate(comp, [prompt], max_new_tokens=a.n_tokens, temperature=0.0, top_k=1)
        return clean_dna(INF._unwrap_result(out[0]).generated_text)

    base = gen({})
    print(f"\nbaseline: {base[:60]}")
    print(f"\n=== dose-response: feature {target} ({eng.labels.get(target)}) ===")
    for s in strengths:
        steered = gen({target: s})
        first, diff = _divergence(base, steered)
        print(f" strength {s:7.1f}: diverges@{first:3d} {diff:6.1%} changed {steered[:44]}")

    if controls:
        # Controls are clamped at the last strength of the sweep.
        s = strengths[-1]
        print(f"\n=== selectivity: control features clamped to {s} ===")
        for c in controls:
            steered = gen({c: s})
            first, diff = _divergence(base, steered)
            print(f" control {c:6d} ({str(eng.labels.get(c)):16s}): diverges@{first:3d} {diff:6.1%} changed")


if __name__ == "__main__":
    main()