From 31dd54749fdbfd2b5273cf996361a7248ce5e90a Mon Sep 17 00:00:00 2001 From: Polina Binder Date: Tue, 23 Jun 2026 05:55:58 +0000 Subject: [PATCH 1/3] sae: probing primitives on the migrated layout + a CPU CI lane Re-lands #1629 (sae.eval.probing: AUROC / domain-F1 / linear probes + ActivationBuffer) onto the post-#1633 top-level layout, and adds a dedicated CPU workflow (ubuntu-latest, no model/GPU) that runs the model-agnostic probing tests. Separate from the evo2 GPU lane; the tensor-parallel sae tests (torchrun/multi-GPU) are out of scope here. Validated: tests/test_probing.py -> 6 passed (CPU). Co-Authored-By: Claude Opus 4.8 Signed-off-by: Polina Binder --- .github/workflows/unit-tests-sae.yaml | 51 ++++ .../sae/src/sae/eval/__init__.py | 26 ++ .../sae/src/sae/eval/probing.py | 277 ++++++++++++++++++ .../sae/tests/test_probing.py | 142 +++++++++ 4 files changed, 496 insertions(+) create mode 100644 .github/workflows/unit-tests-sae.yaml create mode 100644 interpretability/sparse_autoencoders/sae/src/sae/eval/probing.py create mode 100644 interpretability/sparse_autoencoders/sae/tests/test_probing.py diff --git a/.github/workflows/unit-tests-sae.yaml b/.github/workflows/unit-tests-sae.yaml new file mode 100644 index 0000000000..3baab5d6d4 --- /dev/null +++ b/.github/workflows/unit-tests-sae.yaml @@ -0,0 +1,51 @@ +name: "BioNeMo SAE Library CI" + +# CPU unit tests for the model-agnostic `sae` library (interpretability/sparse_autoencoders/sae). +# Separate from the evo2 GPU lane: these need no model and no GPU — just CPU torch + numpy — so +# they run on a cheap ubuntu-latest runner. Scoped to the model-agnostic tests; the +# tensor-parallel tests (test_tp_*) need torchrun/multi-GPU and are intentionally not run here. + +on: + push: + branches: + - "pull-request/[0-9]+" + - "dependabot/**" + paths: + - "interpretability/sparse_autoencoders/sae/**" + - ".github/workflows/unit-tests-sae.yaml" + merge_group: + types: [checks_requested] + schedule: + - cron: "0 9 * * *" # Runs at 9 AM UTC daily (2 AM MST) + +defaults: + run: + shell: bash -x -e -u -o pipefail {0} + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + +jobs: + sae-cpu-tests: + runs-on: ubuntu-latest + name: "sae-lib-tests (cpu)" + steps: + - name: Checkout repository + uses: actions/checkout@v4 + with: + sparse-checkout: interpretability/sparse_autoencoders/sae + sparse-checkout-cone-mode: false + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.12" + + - name: Install (CPU torch + test deps) + working-directory: interpretability/sparse_autoencoders/sae + run: pip install --extra-index-url https://download.pytorch.org/whl/cpu -e '.[dev]' + + - name: Run model-agnostic probing tests + working-directory: interpretability/sparse_autoencoders/sae + run: pytest -v tests/test_probing.py diff --git a/interpretability/sparse_autoencoders/sae/src/sae/eval/__init__.py b/interpretability/sparse_autoencoders/sae/src/sae/eval/__init__.py index 1039045c8a..208c9ef3cd 100644 --- a/interpretability/sparse_autoencoders/sae/src/sae/eval/__init__.py +++ b/interpretability/sparse_autoencoders/sae/src/sae/eval/__init__.py @@ -22,6 +22,20 @@ compute_loss_recovered, evaluate_loss_recovered, ) +from .probing import ( + ActivationBuffer, + annotate_features, + auroc_all, + auroc_vec, + best_single_train_test, + decode_eval, + domain_f1, + fit_logreg, + fit_softmax, + macro_auroc, + split_indices, + standardize, +) from .reconstruction import ( ReconstructionMetrics, compute_reconstruction_metrics, @@ -31,16 +45,28 @@ __all__ = [ + "ActivationBuffer", "DeadLatentStats", "DeadLatentTracker", "EvalResults", "LossRecoveredResult", "ReconstructionMetrics", "SparsityMetrics", + "annotate_features", + "auroc_all", + "auroc_vec", + "best_single_train_test", "compute_loss_recovered", "compute_reconstruction_metrics", + "decode_eval", + "domain_f1", "evaluate_loss_recovered", "evaluate_reconstruction", "evaluate_sae", "evaluate_sparsity", + "fit_logreg", + "fit_softmax", + "macro_auroc", + "split_indices", + "standardize", ] diff --git a/interpretability/sparse_autoencoders/sae/src/sae/eval/probing.py b/interpretability/sparse_autoencoders/sae/src/sae/eval/probing.py new file mode 100644 index 0000000000..ec36394451 --- /dev/null +++ b/interpretability/sparse_autoencoders/sae/src/sae/eval/probing.py @@ -0,0 +1,277 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Model-agnostic SAE feature-probing metrics + the activation-buffer artifact. + +Everything here is a pure function of a probing buffer (per-token feature codes, +an optional dense-residual twin, per-token labels, optional instance IDs). Recipe +drivers (e.g. Evo2) only produce the buffer; all scoring lives here so it is shared +and reusable. Companions in this package: loss_recovered (fidelity), reconstruction, +sparsity, dead_latents. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Dict, Optional + +import numpy as np +import torch + + +# ───────────────────────────────────────────────────────────── artifact +@dataclass +class ActivationBuffer: + """A probing buffer: SAE codes (+ optional dense twin), per-token labels, instance IDs.""" + + codes: np.ndarray # [N, F] float16 SAE feature activations + labels: np.ndarray # [N, L] bool + label_names: list + dense: Optional[np.ndarray] = None # [N, H] float16 raw layer residual (dense twin) + instances: Optional[Dict[str, np.ndarray]] = None # {concept: [N] int32, -1 outside} + + def save(self, path: str) -> None: + """Write codes, labels, names (+ optional dense twin / instance ids) to an .npz.""" + d = {"codes": self.codes, "labels": self.labels, "label_names": np.array(self.label_names)} + if self.dense is not None: + d["dense"] = self.dense + for k, v in (self.instances or {}).items(): + d[f"inst_{k}"] = v + np.savez(path, **d) + + @classmethod + def load(cls, path: str) -> "ActivationBuffer": + """Load an ActivationBuffer from an .npz written by save(). + + Warning: + Uses ``allow_pickle=True`` (the per-concept instance dict is an object array); + only load buffers from trusted sources. + """ + z = np.load(path, allow_pickle=True) + inst = {k[5:]: z[k] for k in z.files if k.startswith("inst_")} + return cls( + codes=z["codes"], + labels=z["labels"], + label_names=list(z["label_names"]), + dense=z["dense"] if "dense" in z.files else None, + instances=inst or None, + ) + + @property + def name_idx(self): + """Map each label name to its column index in ``labels``.""" + return {n: i for i, n in enumerate(self.label_names)} + + +def split_indices(n, test_frac=0.4, seed=0): + """Deterministic train/test split of ``range(n)``; returns (train_idx, test_idx).""" + perm = torch.randperm(n, generator=torch.Generator().manual_seed(seed)) + nte = int(n * test_frac) + return perm[nte:], perm[:nte] # train, test + + +def standardize(X, tr): + """Return (mean, std) of ``X`` over the train rows ``tr`` (std floored by 1e-6).""" + mu, sd = X[tr].mean(0), X[tr].std(0) + 1e-6 + return mu, sd + + +# ───────────────────────────────────────────────────────────── AUROC +@torch.no_grad() +def auroc_all(X, Y, chunk=1024): + """X [N,F], Y [N,L] bool -> AUROC [F,L] via vectorized rank statistic.""" + N, F = X.shape + L = Y.shape[1] + y = Y.float() + npos = y.sum(0) + nneg = N - npos + valid = (npos > 0) & (nneg > 0) + denom = (npos * nneg).clamp_min(1.0) + half = npos * (npos + 1) / 2.0 + out = torch.full((F, L), 0.5, device=X.device) + for c0 in range(0, F, chunk): + c1 = min(c0 + chunk, F) + ranks = X[:, c0:c1].float().argsort(0).argsort(0).float() + 1.0 + au = (y.t() @ ranks - half[:, None]) / denom[:, None] + out[c0:c1] = au.t() + out[:, ~valid] = 0.5 + return out + + +@torch.no_grad() +def auroc_vec(scores, y): + """AUROC of a single score vector against boolean labels ``y`` (0.5 if degenerate).""" + n = scores.numel() + npos = int(y.sum()) + nneg = n - npos + if npos == 0 or nneg == 0: + return 0.5 + ranks = scores.argsort().argsort().float() + 1.0 + return float((ranks[y].sum() - npos * (npos + 1) / 2) / (npos * nneg)) + + +@torch.no_grad() +def best_single_train_test(Xtr, ytr, Xte, yte, chunk=2048): + """Pick the best single dim on TRAIN, report ITS AUROC on TEST (no winner's curse).""" + + def per_feat(X, y): + n = X.shape[0] + npos = int(y.sum()) + nneg = n - npos + if npos == 0 or nneg == 0: + return None + yf = y.float() + F = X.shape[1] + out = torch.empty(F, device=X.device) + for c0 in range(0, F, chunk): + ranks = X[:, c0 : c0 + chunk].float().argsort(0).argsort(0).float() + 1.0 + out[c0 : c0 + chunk] = (yf @ ranks - npos * (npos + 1) / 2) / (npos * nneg) + return out + + a_tr = per_feat(Xtr, ytr) + if a_tr is None: + return float("nan") + f = int(torch.maximum(a_tr, 1 - a_tr).argmax()) + flip = bool(a_tr[f] < 0.5) + a_te = auroc_vec(Xte[:, f].float(), yte) + return float(1 - a_te if flip else a_te) + + +@torch.no_grad() +def annotate_features(codes, labels, label_names, min_auroc: float = 0.8, chunk: int = 1024): + """Assign each feature the concept it best separates (by AUROC) -> the feature->label table. + + The persistence half of probing: turns a buffer (codes + concept labels) into per-feature + annotations. For each feature, takes the concept with the highest AUROC and keeps it only if + that AUROC >= ``min_auroc`` (unconfident features stay unlabeled). + + Args: + codes: [N, F] feature activations. + labels: [N, L] bool concept masks. + label_names: length-L concept names. + min_auroc: keep a feature's annotation only if its best AUROC clears this. + chunk: feature chunk size for ``auroc_all``. + + Returns: + ``[{"feature_id": int, "label": str, "auroc": float}]`` sorted by feature_id. + """ + au = auroc_all(codes, labels, chunk=chunk) # [F, L] + best = au.max(dim=1) + names = list(label_names) + out = [] + for f in range(au.shape[0]): + score = float(best.values[f]) + if score >= min_auroc: + out.append({"feature_id": int(f), "label": str(names[int(best.indices[f])]), "auroc": round(score, 4)}) + return out + + +# ───────────────────────────────────────────────────────────── linear probes +def fit_logreg(Xtr, ytr, steps=400, lr=0.05, wd=1e-2): + """Fit a logistic-regression probe (Adam + BCE-with-logits); returns (w, b).""" + w = torch.zeros(Xtr.shape[1], device=Xtr.device, requires_grad=True) + b = torch.zeros(1, device=Xtr.device, requires_grad=True) + opt = torch.optim.Adam([w, b], lr=lr, weight_decay=wd) + lossf = torch.nn.BCEWithLogitsLoss() + with torch.enable_grad(): + for _ in range(steps): + opt.zero_grad() + lossf(Xtr @ w + b, ytr).backward() + opt.step() + return w.detach(), b.detach() + + +def fit_softmax(Xtr, ytr, nclass, steps=400, lr=0.05, wd=1e-2): + """Fit a multinomial-softmax probe (Adam + cross-entropy); returns (W, b).""" + W = torch.zeros(Xtr.shape[1], nclass, device=Xtr.device, requires_grad=True) + b = torch.zeros(nclass, device=Xtr.device, requires_grad=True) + opt = torch.optim.Adam([W, b], lr=lr, weight_decay=wd) + lossf = torch.nn.CrossEntropyLoss() + with torch.enable_grad(): + for _ in range(steps): + opt.zero_grad() + lossf(Xtr @ W + b, ytr).backward() + opt.step() + return W.detach(), b.detach() + + +@torch.no_grad() +def macro_auroc(logits, y, nclass): + """Macro-averaged one-vs-rest AUROC over ``nclass``; returns (mean_auroc, n_classes_scored).""" + aucs = [] + for c in range(nclass): + yc = y == c + npos = int(yc.sum()) + if npos == 0 or npos == len(y): + continue + ranks = logits[:, c].argsort().argsort().float() + 1.0 + aucs.append(float((ranks[yc].sum() - npos * (npos + 1) / 2) / (npos * (len(y) - npos)))) + return (sum(aucs) / max(1, len(aucs))), len(aucs) + + +def decode_eval(Xtr, ytr, Xte, yte, nclass, **kw): + """Fit a softmax probe on train; return (accuracy, macro_auroc, n_classes) on test.""" + W, b = fit_softmax(Xtr, ytr, nclass, **kw) + logits = Xte @ W + b + acc = float((logits.argmax(1) == yte).float().mean()) + mauc, ncls = macro_auroc(logits, yte, nclass) + return acc, mauc, ncls + + +# ───────────────────────────────────────────────────────────── domain-adjusted F1 +@torch.no_grad() +def domain_f1(codes, fmax, concept_mask, inst_ids, thresholds=(0.15, 0.3, 0.5, 0.6, 0.8), chunk=1024): + """InterPLM domain-adjusted F1 per feature: precision-per-position, recall-per-instance. + + codes [P,F] (>=0), fmax [F], concept_mask [P] bool, inst_ids [P] int (-1 outside). + Returns (best_f1[F], best_threshold[F]) over the threshold sweep. + """ + _, F = codes.shape + dev = codes.device + valid = inst_ids >= 0 + uniq = torch.unique(inst_ids[valid]) + n_inst = len(uniq) + if n_inst == 0: + return torch.zeros(F, device=dev), torch.zeros(F, device=dev) + # size = max instance id + 2: +1 to index by the max id itself, +1 headroom so a -1 + # sentinel never indexes out of bounds when remapped. + remap = torch.full((int(inst_ids.max().item()) + 2,), -1, device=dev, dtype=torch.long) + remap[uniq.long()] = torch.arange(n_inst, device=dev) + inst_c = torch.where(valid, remap[inst_ids.long()], torch.full_like(inst_ids, -1, dtype=torch.long)) + best_f1 = torch.zeros(F, device=dev) + best_t = torch.zeros(F, device=dev) + for c0 in range(0, F, chunk): + c1 = min(c0 + chunk, F) + cn = codes[:, c0:c1] / fmax[c0:c1].clamp_min(1e-6) + C = c1 - c0 + cb = torch.zeros(C, device=dev) + ct = torch.zeros(C, device=dev) + for t in thresholds: + fire = cn > t + firing = fire.sum(0).float() + prec = torch.where( + firing > 0, (fire & concept_mask[:, None]).sum(0).float() / firing, torch.zeros(C, device=dev) + ) + bucket = torch.zeros(n_inst, C, device=dev) + vm = inst_c >= 0 + bucket.index_reduce_(0, inst_c[vm], fire[vm].float(), "amax", include_self=False) + recall = (bucket > 0).sum(0).float() / n_inst + f1 = torch.where((prec + recall) > 0, 2 * prec * recall / (prec + recall), torch.zeros(C, device=dev)) + upd = f1 > cb + cb = torch.where(upd, f1, cb) + ct = torch.where(upd, torch.full_like(ct, t), ct) + best_f1[c0:c1] = cb + best_t[c0:c1] = ct + return best_f1, best_t diff --git a/interpretability/sparse_autoencoders/sae/tests/test_probing.py b/interpretability/sparse_autoencoders/sae/tests/test_probing.py new file mode 100644 index 0000000000..d38868a962 --- /dev/null +++ b/interpretability/sparse_autoencoders/sae/tests/test_probing.py @@ -0,0 +1,142 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""CPU correctness tests for sae.eval.probing (no model / no GPU). + +One strong test per non-trivial metric: each checks the result against an independent +reference (a definitional oracle or a hand-computed value) rather than a loose sanity bound. +The trivial standardize helper is exercised transitively (decode_eval test); split_indices +folds into the buffer roundtrip. +""" + +import numpy as np +import torch +from sae.eval.probing import ( + ActivationBuffer, + annotate_features, + auroc_all, + best_single_train_test, + decode_eval, + domain_f1, + split_indices, +) + + +def _auroc_ref(scores: torch.Tensor, y: torch.Tensor) -> float: + """Definitional AUROC oracle: P(score+ > score-) over all positive/negative pairs. + + Computed by brute-force pair comparison — independent of the argsort rank-sum used by + auroc_all, so agreement validates that implementation (randn inputs => no ties). + """ + pos, neg = scores[y], scores[~y] + return float((pos[:, None] > neg[None, :]).float().mean()) + + +def test_auroc_all_matches_definition(): + """auroc_all matches the pairwise-definition AUROC for every (feature, label).""" + torch.manual_seed(0) + n, f, ell = 200, 6, 3 + x = torch.randn(n, f) + y = torch.randn(n, ell) > 0 + au = auroc_all(x, y) # [F, L] + for fi in range(f): + for li in range(ell): + assert abs(float(au[fi, li]) - _auroc_ref(x[:, fi], y[:, li])) < 1e-6 + + +def test_best_single_reports_flipped_test_auroc(): + """best_single picks the most-separating TRAIN feature and reports ITS test AUROC, + flipping a feature that separates by firing on the negatives (no winner's curse).""" + torch.manual_seed(0) + y = torch.cat([torch.zeros(10), torch.ones(10)]).bool() + # 'anti' fires on the y=0 class (train AUROC ~0 -> selected via 1-AUROC, flip=True); + # it stays anti-correlated on test, so the reported (flipped) test AUROC is ~1. + anti_tr = torch.cat([torch.ones(10), torch.zeros(10)]) + torch.randn(20) * 0.01 + anti_te = torch.cat([torch.ones(10), torch.zeros(10)]) + torch.randn(20) * 0.01 + xtr = torch.stack([anti_tr, torch.randn(20)], 1) # 2nd feature is noise + xte = torch.stack([anti_te, torch.randn(20)], 1) + assert best_single_train_test(xtr, y, xte, y.clone()) > 0.9 + + +def test_domain_f1_matches_hand_computed(): + """domain_f1 = precision-per-position, recall-per-instance, best over the threshold sweep. + + Two binary features over 6 positions, 2 annotation instances ({0,1} and {4}): + feat0 fires at an extra non-concept position -> prec 3/4, recall 2/2 -> F1 = 6/7 + feat1 fires exactly on concept positions -> prec 1, recall 2/2 -> F1 = 1 + """ + codes = torch.tensor([[1, 1], [1, 1], [1, 0], [0, 0], [1, 1], [0, 0]], dtype=torch.float) + fmax = codes.max(0).values + concept_mask = torch.tensor([1, 1, 0, 0, 1, 0], dtype=torch.bool) + inst_ids = torch.tensor([0, 0, -1, -1, 1, -1]) + f1, _ = domain_f1(codes, fmax, concept_mask, inst_ids) + assert abs(float(f1[0]) - 6 / 7) < 1e-4 + assert abs(float(f1[1]) - 1.0) < 1e-4 + + +def test_decode_eval_recovers_separable_classes(): + """The softmax decoder (fit_softmax + macro_auroc) separates separable classes and not noise.""" + torch.manual_seed(0) + dim, nclass = 8, 3 + centers = torch.eye(nclass, dim) * 6.0 + + def make(per): + ys = torch.arange(nclass).repeat_interleave(per) + return centers[ys] + torch.randn(len(ys), dim), ys + + xtr, ytr = make(40) + xte, yte = make(20) + acc, mauc, ncls = decode_eval(xtr, ytr, xte, yte, nclass, steps=400, lr=0.1) + assert acc > 0.9 and mauc > 0.9 and ncls == 3 + + # random features/labels -> no better than chance (1/3) + xr, yr = torch.randn(120, dim), torch.randint(0, nclass, (120,)) + acc_rand, _, _ = decode_eval(xr[:90], yr[:90], xr[90:], yr[90:], nclass, steps=400, lr=0.1) + assert acc_rand < 0.6 + + +def test_annotate_features_assigns_best_concept_above_threshold(): + """Each feature gets the concept it best separates; unconfident features stay unlabeled.""" + torch.manual_seed(0) + n = 200 + labels = torch.stack([torch.arange(n) % 2 == 0, torch.arange(n) < n // 2], 1) # [N, 2]: 'even', 'first_half' + detector = labels[:, 0].float() + torch.randn(n) * 0.01 # cleanly tracks 'even' + noise = torch.randn(n) # tracks nothing + codes = torch.stack([detector, noise], 1) # [N, 2 features] + ann = annotate_features(codes, labels, ["even", "first_half"], min_auroc=0.9) + assert {a["feature_id"]: a["label"] for a in ann} == {0: "even"} # feature 1 (noise) excluded + assert ann[0]["auroc"] > 0.99 + + +def test_buffer_roundtrip_and_split(tmp_path): + """ActivationBuffer save/load preserves codes/labels/names/dense/instances; split is a partition.""" + rng = np.random.default_rng(0) + codes = rng.random((10, 4)).astype(np.float16) + labels = np.tile(np.array([True, False, True]), (10, 1)) + dense = rng.random((10, 8)).astype(np.float16) + instances = {"exon": np.array([0, 0, -1, 1, 1, -1, 2, 2, 2, -1], np.int32)} + buf = ActivationBuffer(codes, labels, ["a", "b", "c"], dense=dense, instances=instances) + path = str(tmp_path / "buf.npz") + buf.save(path) + + lo = ActivationBuffer.load(path) + assert np.array_equal(lo.codes, codes) + assert np.array_equal(lo.dense, dense) + assert np.array_equal(lo.instances["exon"], instances["exon"]) + assert lo.name_idx["c"] == 2 + + tr, te = split_indices(100, test_frac=0.4, seed=0) + s_tr, s_te = set(tr.tolist()), set(te.tolist()) + assert s_tr.isdisjoint(s_te) and (s_tr | s_te) == set(range(100)) and len(s_te) == 40 From 13a06900cba5ce7c016d4d5bf2747308e49a4fde Mon Sep 17 00:00:00 2001 From: Polina Binder Date: Tue, 23 Jun 2026 06:06:03 +0000 Subject: [PATCH 2/3] fix(sae): tie-correct AUROC ranks (average ranks) + tie/sparse tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit auroc_all / auroc_vec / best_single / macro_auroc ranked via argsort().argsort(), giving tied values arbitrary distinct ranks. SAE codes are sparse (heavy zero-mass), so that biased the AUROC on the real data distribution — and the oracle test only covered randn (no ties). Switch to average (Mann-Whitney) ranks via a vectorized searchsorted helper (keeps the all-features-at-once speed that motivates hand-rolling), make the oracle tie-aware, and add sparse-tie + constant-feature tests. Also documents why these metrics are hand-rolled. tests/test_probing.py -> 8 passed (CPU). Co-Authored-By: Claude Opus 4.8 Signed-off-by: Polina Binder --- .../sae/src/sae/eval/probing.py | 35 ++++++++++++++++--- .../sae/tests/test_probing.py | 31 +++++++++++++--- 2 files changed, 58 insertions(+), 8 deletions(-) diff --git a/interpretability/sparse_autoencoders/sae/src/sae/eval/probing.py b/interpretability/sparse_autoencoders/sae/src/sae/eval/probing.py index ec36394451..d8cd2dc095 100644 --- a/interpretability/sparse_autoencoders/sae/src/sae/eval/probing.py +++ b/interpretability/sparse_autoencoders/sae/src/sae/eval/probing.py @@ -20,6 +20,13 @@ drivers (e.g. Evo2) only produce the buffer; all scoring lives here so it is shared and reusable. Companions in this package: loss_recovered (fidelity), reconstruction, sparsity, dead_latents. + +These metrics are hand-rolled rather than pulled from scikit-learn / scipy on purpose: +``auroc_all`` ranks every feature against every concept in one vectorized, GPU-capable pass +(``sklearn.roc_auc_score`` is a per-(feature, concept) CPU loop — infeasible at ~65k features), +and it keeps this library's dependencies minimal. Correctness isn't traded away for it: ranks are +tie-averaged (the Mann-Whitney convention), so the result equals the reference AUROC even on the +heavy zero-mass of sparse SAE codes — pinned by the oracle tests. """ from __future__ import annotations @@ -88,6 +95,26 @@ def standardize(X, tr): return mu, sd +def _average_ranks(x: torch.Tensor) -> torch.Tensor: + """Tie-corrected (average) ranks along dim 0, 1-indexed. + + Tied values share their mean rank — the convention the Mann-Whitney/AUROC rank-sum assumes. + Plain ``argsort().argsort()`` hands the heavy zero-mass of sparse SAE codes arbitrary distinct + ranks, which biases the AUROC; this averages each tie instead. Vectorized via ``searchsorted`` + so it keeps the all-features-at-once speed: ``lt``/``le`` count values strictly-less / <= each + value, the tie group spans ordinal ranks ``lt+1..le``, whose mean is ``(lt+le+1)/2``. + """ + x = x.float() + xs = x.sort(dim=0).values + if x.dim() == 1: + lt = torch.searchsorted(xs, x, right=False) + le = torch.searchsorted(xs, x, right=True) + else: # searchsorted ranks along the last dim; transpose so the token axis (dim 0) is searched + lt = torch.searchsorted(xs.t().contiguous(), x.t().contiguous(), right=False).t() + le = torch.searchsorted(xs.t().contiguous(), x.t().contiguous(), right=True).t() + return (lt + le + 1).float() / 2.0 + + # ───────────────────────────────────────────────────────────── AUROC @torch.no_grad() def auroc_all(X, Y, chunk=1024): @@ -103,7 +130,7 @@ def auroc_all(X, Y, chunk=1024): out = torch.full((F, L), 0.5, device=X.device) for c0 in range(0, F, chunk): c1 = min(c0 + chunk, F) - ranks = X[:, c0:c1].float().argsort(0).argsort(0).float() + 1.0 + ranks = _average_ranks(X[:, c0:c1]) au = (y.t() @ ranks - half[:, None]) / denom[:, None] out[c0:c1] = au.t() out[:, ~valid] = 0.5 @@ -118,7 +145,7 @@ def auroc_vec(scores, y): nneg = n - npos if npos == 0 or nneg == 0: return 0.5 - ranks = scores.argsort().argsort().float() + 1.0 + ranks = _average_ranks(scores) return float((ranks[y].sum() - npos * (npos + 1) / 2) / (npos * nneg)) @@ -136,7 +163,7 @@ def per_feat(X, y): F = X.shape[1] out = torch.empty(F, device=X.device) for c0 in range(0, F, chunk): - ranks = X[:, c0 : c0 + chunk].float().argsort(0).argsort(0).float() + 1.0 + ranks = _average_ranks(X[:, c0 : c0 + chunk]) out[c0 : c0 + chunk] = (yf @ ranks - npos * (npos + 1) / 2) / (npos * nneg) return out @@ -216,7 +243,7 @@ def macro_auroc(logits, y, nclass): npos = int(yc.sum()) if npos == 0 or npos == len(y): continue - ranks = logits[:, c].argsort().argsort().float() + 1.0 + ranks = _average_ranks(logits[:, c]) aucs.append(float((ranks[yc].sum() - npos * (npos + 1) / 2) / (npos * (len(y) - npos)))) return (sum(aucs) / max(1, len(aucs))), len(aucs) diff --git a/interpretability/sparse_autoencoders/sae/tests/test_probing.py b/interpretability/sparse_autoencoders/sae/tests/test_probing.py index d38868a962..f0860987da 100644 --- a/interpretability/sparse_autoencoders/sae/tests/test_probing.py +++ b/interpretability/sparse_autoencoders/sae/tests/test_probing.py @@ -35,13 +35,16 @@ def _auroc_ref(scores: torch.Tensor, y: torch.Tensor) -> float: - """Definitional AUROC oracle: P(score+ > score-) over all positive/negative pairs. + """Definitional AUROC oracle: P(s+ > s-) + 0.5·P(s+ == s-) over all positive/negative pairs. - Computed by brute-force pair comparison — independent of the argsort rank-sum used by - auroc_all, so agreement validates that implementation (randn inputs => no ties). + The 0.5·tie term is the Mann-Whitney convention. Computed by brute-force pair comparison — + independent of the rank-sum in auroc_all — so agreement validates that implementation, + including on tied (sparse) inputs. """ pos, neg = scores[y], scores[~y] - return float((pos[:, None] > neg[None, :]).float().mean()) + gt = (pos[:, None] > neg[None, :]).float().mean() + eq = (pos[:, None] == neg[None, :]).float().mean() + return float(gt + 0.5 * eq) def test_auroc_all_matches_definition(): @@ -56,6 +59,26 @@ def test_auroc_all_matches_definition(): assert abs(float(au[fi, li]) - _auroc_ref(x[:, fi], y[:, li])) < 1e-6 +def test_auroc_all_handles_ties(): + """On sparse codes (heavily tied at 0) — the regime plain argsort ranks get wrong — auroc_all + still matches the tie-aware oracle. ~85% exact zeros, the rest positive, per feature.""" + torch.manual_seed(0) + n, f, ell = 300, 5, 2 + x = torch.where(torch.rand(n, f) < 0.15, torch.rand(n, f), torch.zeros(n, f)) # mostly exact 0 + y = torch.rand(n, ell) > 0.5 + au = auroc_all(x, y) + for fi in range(f): + for li in range(ell): + assert abs(float(au[fi, li]) - _auroc_ref(x[:, fi], y[:, li])) < 1e-5 + + +def test_auroc_all_constant_feature_is_half(): + """A constant (all-tied) feature scores exactly 0.5 — every positive/negative pair is a tie.""" + x = torch.zeros(50, 1) + y = (torch.arange(50) < 20).unsqueeze(1) # 20 pos / 30 neg, both present + assert abs(float(auroc_all(x, y)[0, 0]) - 0.5) < 1e-6 + + def test_best_single_reports_flipped_test_auroc(): """best_single picks the most-separating TRAIN feature and reports ITS test AUROC, flipping a feature that separates by firing on the negatives (no winner's curse).""" From 896b2532de609724cb4332b8d3425ea1eef6eb86 Mon Sep 17 00:00:00 2001 From: Polina Binder Date: Tue, 23 Jun 2026 06:11:56 +0000 Subject: [PATCH 3/3] test(sae): cover degenerate-label AUROC, auroc_vec ties, and the buffer None paths - a never/always-firing concept -> AUROC 0.5 (the valid-mask branch; realistic for rare concepts) - auroc_vec directly (was only tested transitively via best_single) on tied scores - ActivationBuffer with no dense twin / no instances (the Optional -> None save/load paths) Co-Authored-By: Claude Opus 4.8 Signed-off-by: Polina Binder --- .../sae/tests/test_probing.py | 33 +++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/interpretability/sparse_autoencoders/sae/tests/test_probing.py b/interpretability/sparse_autoencoders/sae/tests/test_probing.py index f0860987da..f1db77294a 100644 --- a/interpretability/sparse_autoencoders/sae/tests/test_probing.py +++ b/interpretability/sparse_autoencoders/sae/tests/test_probing.py @@ -27,6 +27,7 @@ ActivationBuffer, annotate_features, auroc_all, + auroc_vec, best_single_train_test, decode_eval, domain_f1, @@ -163,3 +164,35 @@ def test_buffer_roundtrip_and_split(tmp_path): tr, te = split_indices(100, test_frac=0.4, seed=0) s_tr, s_te = set(tr.tolist()), set(te.tolist()) assert s_tr.isdisjoint(s_te) and (s_tr | s_te) == set(range(100)) and len(s_te) == 40 + + +def test_auroc_all_degenerate_label_is_half(): + """A concept that never fires (all-False) or always fires (all-True) -> AUROC 0.5, not garbage. + + Realistic for rare genomic concepts (a buffer/split can contain only negatives). Exercises the + valid = (npos>0) & (nneg>0) guard. + """ + x = torch.randn(40, 3) + y = torch.zeros(40, 2, dtype=torch.bool) + y[:, 1] = True # col 0: never fires (npos=0); col 1: always fires (nneg=0) + au = auroc_all(x, y) + assert torch.allclose(au, torch.full_like(au, 0.5)) + + +def test_auroc_vec_matches_oracle_with_ties(): + """auroc_vec (the single-vector AUROC behind best_single) matches the tie-aware oracle.""" + scores = torch.tensor([0.0, 0.0, 0.0, 1.0, 1.0, 2.0]) + y = torch.tensor([False, True, False, True, False, True]) + assert abs(auroc_vec(scores, y) - _auroc_ref(scores, y)) < 1e-6 + + +def test_buffer_roundtrip_without_dense_or_instances(tmp_path): + """ActivationBuffer with no dense twin and no instances round-trips (the Optional -> None paths).""" + codes = np.zeros((5, 3), np.float16) + labels = np.ones((5, 2), bool) + buf = ActivationBuffer(codes, labels, ["a", "b"]) # dense=None, instances=None + path = str(tmp_path / "b.npz") + buf.save(path) + lo = ActivationBuffer.load(path) + assert lo.dense is None and lo.instances is None + assert np.array_equal(lo.codes, codes) and lo.label_names == ["a", "b"]