sae: shared probing primitives (eval metrics + ActivationBuffer) #1629

Open
polinabinder1 wants to merge 3 commits into main from pbinder/sae-interp-primitives

polinabinder1 commented Jun 11, 2026

Collaborator

Summary

Shared, model-agnostic SAE probing primitives in the sae package (sibling of loss_recovered/sparsity/dead_latents): scoring metrics + per-feature annotation, all pure functions of codes + labels.

Contents — sae.eval.probing

  • ActivationBuffer (codes + optional dense twin + per-token labels + instance ids)
  • AUROC: auroc_all, auroc_vec, best_single_train_test
  • decoders: fit_logreg / fit_softmax / macro_auroc / decode_eval
  • domain_f1 (precision-per-nt, recall-per-instance)
  • annotate_features (per-feature best concept by AUROC → the annotation table)

How to use

from sae.eval.probing import auroc_all, annotate_features
au  = auroc_all(codes, labels)                                   # [F, L]
ann = annotate_features(codes, labels, names, min_auroc=0.85)    # [{feature_id, label, auroc}]

pytest sae/tests/test_probing.py — CPU, no model.
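
As a rough illustration of what the annotation step produces, here is a minimal pure-Python sketch that starts from an already-computed `[F, L]` AUROC matrix. `annotate_from_auroc` is a hypothetical name, not the package's function; the real `annotate_features` takes codes and labels, and its tie-breaking, threshold comparison (`>=` is assumed here), and exact output fields may differ.

```python
from typing import Dict, List, Sequence, Union

Row = Dict[str, Union[int, str, float]]


def annotate_from_auroc(
    auroc: Sequence[Sequence[float]],  # [F, L]: one row per feature, one column per label
    names: Sequence[str],              # L label names
    min_auroc: float = 0.85,
) -> List[Row]:
    """Hypothetical helper: keep each feature's best label if it clears min_auroc."""
    table: List[Row] = []
    for feature_id, row in enumerate(auroc):
        # Index of the label with the highest AUROC for this feature.
        best = max(range(len(row)), key=lambda j: row[j])
        if row[best] >= min_auroc:  # assumption: inclusive threshold
            table.append({"feature_id": feature_id, "label": names[best], "auroc": row[best]})
    return table


# Toy matrix: feature 1 has no label above the threshold, so it is dropped.
au = [
    [0.52, 0.91, 0.60],
    [0.50, 0.55, 0.58],
    [0.97, 0.40, 0.88],
]
ann = annotate_from_auroc(au, ["exon", "intron", "promoter"], min_auroc=0.85)
```

Each surviving feature gets exactly one label, so a feature that scores well on several concepts is reported only under its best one.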

Why hand-rolled (not sklearn)

GPU-vectorized over the whole ~32k-feature dictionary in one pass; sklearn is CPU/per-(scores,label) — CodonFM used it and had to subsample to ≤5k features. Package stays torch+numpy. Each metric is validated against an independent reference in the tests (pairwise-AUROC oracle, hand-computed domain_f1, etc.).
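
For readers unfamiliar with the term, a pairwise-AUROC oracle can be as small as the sketch below: it applies the definition directly (the probability that a positive outscores a negative, with ties counted as half). This is an illustrative stand-in, not the test helper from the PR; the function name and the choice to return 0.5 for single-class labels are assumptions.

```python
from typing import Sequence


def auroc_pairwise(scores: Sequence[float], labels: Sequence[bool]) -> float:
    """AUROC by definition: P(pos > neg) + 0.5 * P(pos == neg). O(n_pos * n_neg)."""
    pos = [s for s, y in zip(scores, labels) if y]
    neg = [s for s, y in zip(scores, labels) if not y]
    if not pos or not neg:
        return 0.5  # assumption: single-class labels map to chance level
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5  # a tie counts as half a win
    return wins / (len(pos) * len(neg))


clean = auroc_pairwise([0.1, 0.4, 0.35, 0.8], [False, False, True, True])
tied = auroc_pairwise([0.0, 0.0, 0.0, 1.0], [False, True, False, True])
```

The quadratic loop is far too slow for a full dictionary, which is exactly why it is useful as an independent check on a fast rank-based implementation.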

Base of: #1630 (eval labels) → #1636 (harness).

On external libraries (checked — not a win)

We evaluated replacing the hand-rolled metrics with sklearn/torchmetrics, function by function:

  • auroc_all — no library computes a vectorized [features × labels] AUROC matrix on GPU; sklearn.roc_auc_score / torchmetrics are CPU and per-(scores, label), so a ~32k-feature dictionary becomes a 32k-iteration CPU loop. Kept.
  • domain_f1, best_single_train_test, annotate_features — bespoke (instance-F1, winner's-curse, per-feature assignment); no library equivalent.
  • fit_logreg / fit_softmax / decode_eval (~62 lines) — the only sklearn-replaceable code, but they fit on the [N≈50k, F≈32k] SAE-code matrix, which is exactly where CodonFM tried sklearn.LogisticRegression and had to subsample to ≤5k features. Swapping reintroduces that coverage loss and a runtime dep. Net regression.
  • ActivationBuffer / split_indices / standardize: np.savez + 7-line helpers; nothing to gain.

Conclusion: kept the package torch+numpy-only. The metrics are standard formulas (Mann–Whitney rank-AUROC, Adam BCE/softmax, InterPLM instance-F1) vectorized for full-dictionary GPU scale, and each is validated against an independent reference in the tests.
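
For reference, the Mann–Whitney identity behind a rank-based AUROC can be written as follows. The notation is introduced here for illustration and does not come from the PR: $r_i$ is the 1-based rank of sample $i$ among all samples (tied values share the mean of the ranks they span), $y_i \in \{0, 1\}$ is its label, and $n_+$, $n_-$ are the positive and negative counts.

```latex
\mathrm{AUROC}
  = \frac{\displaystyle\sum_{i:\,y_i = 1} r_i \;-\; \frac{n_+ (n_+ + 1)}{2}}{n_+ \, n_-}
```

The subtracted term is the smallest possible rank sum for the positives, so the numerator counts the positive-negative pairs that are ordered correctly, with ties contributing one half.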

Expected output

  • Library (no script). pytest sae/tests/test_probing.py → 6 passed: AUROC vs the pairwise-definition oracle, domain_f1 vs a hand-computed reference, best_single winner's-curse flip, decode_eval separability, annotate_features best-concept, buffer roundtrip.
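
The winner's-curse point can be made concrete with a toy example: pick the best single feature on the training split, then score that same feature on held-out data. The sketch below is a pure-Python illustration under assumed names and signatures (`best_single_sketch`, row-major `[N][F]` lists); the package's `best_single_train_test` may take different arguments and return different fields.

```python
from typing import List, Sequence, Tuple


def _auroc(scores: Sequence[float], labels: Sequence[bool]) -> float:
    """Pairwise-definition AUROC, ties count as half."""
    pos = [s for s, y in zip(scores, labels) if y]
    neg = [s for s, y in zip(scores, labels) if not y]
    if not pos or not neg:
        return 0.5
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


def best_single_sketch(
    train_x: List[List[float]], train_y: List[bool],
    test_x: List[List[float]], test_y: List[bool],
) -> Tuple[int, float, float]:
    """Select the top feature on train only, then report its AUROC on test."""
    n_features = len(train_x[0])
    train_auc = [_auroc([row[f] for row in train_x], train_y) for f in range(n_features)]
    best = max(range(n_features), key=lambda f: train_auc[f])
    test_auc = _auroc([row[best] for row in test_x], test_y)
    return best, train_auc[best], test_auc


y = [False, False, True, True]
train_x = [[0.1, 0.1], [0.2, 0.6], [0.8, 0.5], [0.9, 0.9]]  # feature 0 wins on train
test_x = [[0.5, 0.1], [0.7, 0.2], [0.6, 0.8], [0.4, 0.9]]   # but not on test
result = best_single_sketch(train_x, y, test_x, y)
```

Here feature 0 looks perfect on the training rows yet scores 0.25 on the test rows, so reporting the training AUROC of the selected feature would badly overstate it.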

Summary by CodeRabbit

  • New Features

    • Added comprehensive evaluation and probing utilities for sparse autoencoders, including AUROC metrics, feature annotation, classifier-based probing, and domain F1 scoring.
    • Introduced a data buffer utility for storing and persisting activation analysis data.
  • Tests

    • Added comprehensive test coverage for new evaluation metrics and utilities.

coderabbitai Bot commented Jun 11, 2026

Contributor

Review Change Stack

Important

Review skipped

Auto reviews are disabled on this repository. Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: dc76b55b-7c86-47b5-a64a-0c3699980ca1

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.

📝 Walkthrough

This PR adds a comprehensive SAE feature-probing evaluation module (probing.py) to enable model-agnostic interpretation of learned features through metrics, classifiers, and annotation tools, along with an ActivationBuffer artifact for persistence and a full test suite validating correctness across all components.

Changes

SAE Probing Evaluation Suite

| Layer / File(s) | Summary |
| --- | --- |
| **ActivationBuffer data structure and persistence**<br>`bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/eval/probing.py` (lines 1–65), `bionemo-recipes/interpretability/sparse_autoencoders/sae/tests/test_probing.py` (lines 123–142) | Dataclass storing SAE feature codes, per-token boolean labels and names, optional dense residuals, and concept-to-instance id mappings; `.save()` serializes to typed `.npz` with per-concept instance arrays; `.load()` reconstructs the dataclass; `.name_idx` property maps label names to column indices. |
| **Dataset utilities and standardization**<br>`bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/eval/probing.py` (lines 73–84) | `split_indices` performs deterministic train/test splitting via seeded `torch.randperm`; `standardize` computes mean and std on training rows with epsilon-clamped std normalization. |
| **AUROC computation and best-feature selection**<br>`bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/eval/probing.py` (lines 86–145), `bionemo-recipes/interpretability/sparse_autoencoders/sae/tests/test_probing.py` (lines 37–71) | `auroc_all` computes full [feature, label] AUROC matrix via chunked rank-statistics; `auroc_vec` handles single-vector AUROC with degenerate-case handling; `best_single_train_test` selects best feature on training set and reports test AUROC without winner's-curse bias; test oracle `_auroc_ref` validates against brute-force reference. |
| **Feature concept annotation via AUROC thresholding**<br>`bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/eval/probing.py` (lines 147–174), `bionemo-recipes/interpretability/sparse_autoencoders/sae/tests/test_probing.py` (lines 110–121) | `annotate_features` derives per-feature best-label annotations by selecting max AUROC across labels and filtering by configurable AUROC threshold; excludes low-information features. |
| **Linear classifier training and macro-AUROC evaluation**<br>`bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/eval/probing.py` (lines 176–226), `bionemo-recipes/interpretability/sparse_autoencoders/sae/tests/test_probing.py` (lines 89–108) | `fit_logreg` trains binary logistic regression; `fit_softmax` trains multinomial softmax; both use Adam with BCE-with-logits and cross-entropy respectively; `macro_auroc` computes macro one-vs-rest AUROC; `decode_eval` orchestrates training and dual metric reporting for test accuracy and macro AUROC. |
| **Domain-adjusted F1 with instance-aware thresholding**<br>`bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/eval/probing.py` (lines 228–270), `bionemo-recipes/interpretability/sparse_autoencoders/sae/tests/test_probing.py` (lines 73–87) | `domain_f1` computes threshold-swept per-feature F1 by normalizing activations per-feature, remapping instance ids, aggregating per-instance firing via `index_reduce_`, combining precision from concept masks with recall from instance aggregation, and selecting best F1 threshold per feature in chunked passes. |
| **Module public API and test setup**<br>`bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/eval/__init__.py` (lines 25–71), `bionemo-recipes/interpretability/sparse_autoencoders/sae/tests/test_probing.py` (lines 1–35) | Imports and re-exports all `probing.py` utilities in `__all__` for public access; test module imports and validates all components. |
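
To make the domain-F1 row above more tangible, here is a single-feature, pure-Python sketch of the idea: precision is measured per token, recall per concept instance, and the best F1 over a small threshold sweep is kept. Everything specific in it is an assumption for illustration (the name `domain_f1_sketch`, max-normalization, the threshold grid, the strict `>` firing rule, and `-1` meaning "no instance"); the PR's `domain_f1` is vectorized over all features and may differ in each of these details.

```python
from typing import List, Sequence


def domain_f1_sketch(
    acts: List[float],          # one feature's activation per token
    in_concept: List[bool],     # per-token concept mask
    inst_ids: List[int],        # per-token instance id, -1 outside any instance
    thresholds: Sequence[float] = (0.0, 0.25, 0.5, 0.75),
) -> float:
    """Best F1 over thresholds: token-level precision, instance-level recall."""
    peak = max(acts)
    if peak <= 0:
        return 0.0  # feature never fires
    norm = [a / peak for a in acts]  # assumption: normalize by the feature's max
    instances = {i for i, c in zip(inst_ids, in_concept) if c and i >= 0}
    best = 0.0
    for t in thresholds:
        fires = [a > t for a in norm]
        n_fire = sum(fires)
        if n_fire == 0 or not instances:
            continue
        # Precision: fraction of firing tokens that lie inside the concept.
        precision = sum(1 for f, c in zip(fires, in_concept) if f and c) / n_fire
        # Recall: fraction of concept instances hit by at least one firing token.
        hit = {i for f, c, i in zip(fires, in_concept, inst_ids) if f and c and i >= 0}
        recall = len(hit) / len(instances)
        if precision + recall > 0:
            best = max(best, 2 * precision * recall / (precision + recall))
    return best


acts = [0.0, 0.9, 0.1, 0.0, 0.8, 0.0, 0.3, 0.0]
in_concept = [False, True, True, False, True, True, False, False]
inst_ids = [-1, 0, 0, -1, 1, 1, -1, -1]
best_f1 = domain_f1_sketch(acts, in_concept, inst_ids)
```

In this toy input, at threshold 0.5 the feature fires on only 2 of the 4 concept tokens, so token-level recall would be 0.5, but it touches both instances, so instance-level recall is 1.0 and the F1 reaches 1.0.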

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Poem

🐰 A warren of metrics, now bundled with care,
AUROC and F1 floating through air,
Buffers that save what the features unfold,
Linear probes seeking wisdom untold,
Domain-aware thresholds, adaptive and keen—
The richest of probing suites ever been seen! 🌟

🚥 Pre-merge checks | ✅ 5
✅ Passed checks (5 passed)
| Check name | Status | Explanation |
| --- | --- | --- |
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |
| Title check | ✅ Passed | The title accurately summarizes the main changes: introduction of shared probing primitives including ActivationBuffer and evaluation metrics for SAE feature probing. |
| Docstring Coverage | ✅ Passed | Docstring coverage is 91.30% which is sufficient. The required threshold is 80.00%. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

Comment @coderabbitai help to get the list of available commands.

copy-pr-bot Bot commented Jun 11, 2026


Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

polinabinder1 changed the title from "sae: shared interpretability primitives (probing + steering)" to "sae: shared probing primitives (eval metrics + ActivationBuffer)" Jun 11, 2026
@polinabinder1

Collaborator Author

@coderabbitai review

coderabbitai Bot commented Jun 12, 2026

Contributor
✅ Action performed

Review finished.

Note: CodeRabbit is an incremental review system and does not re-review already reviewed commits. This command is applicable only when automatic reviews are paused.

@coderabbitai coderabbitai Bot left a comment

Contributor

🧹 Nitpick comments (2)
bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/eval/probing.py (2)

54-65: 💤 Low value

allow_pickle=True poses a deserialization risk if loading untrusted files.

This is acceptable for internal artifacts but worth documenting. If these buffers might come from external sources, consider validating provenance or using a safer serialization format.

     @classmethod
     def load(cls, path: str) -> "ActivationBuffer":
-        """Load an ActivationBuffer from an .npz written by save()."""
+        """Load an ActivationBuffer from an .npz written by save().
+
+        Warning:
+            Uses allow_pickle=True; only load files from trusted sources.
+        """
         z = np.load(path, allow_pickle=True)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In
`@bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/eval/probing.py`
around lines 54 - 65, The load method in ActivationBuffer uses np.load(...,
allow_pickle=True) which is unsafe for untrusted files; change load to avoid
allow_pickle=True by default (use allow_pickle=False) or add an explicit
parameter (e.g., allow_pickle: bool = False) and fail with a clear error if
pickled objects are required, and update the ActivationBuffer.load docstring to
document the deserialization risk and the need to validate provenance when
loading external files; ensure references to ActivationBuffer.load and the local
variable z are used to implement and surface the safer behavior.
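
One way to see what the `allow_pickle` suggestion is about: NumPy only needs pickle for object-dtype arrays, so string metadata stored as a fixed-width unicode array loads fine with `allow_pickle=False`. The sketch below demonstrates that with in-memory buffers. It assumes the pickle requirement comes from object arrays such as label names, which the PR does not confirm; the key names (`label_names`, `codes`) are made up for the example and are not the `ActivationBuffer` file layout.

```python
import io

import numpy as np

names = ["exon", "intron"]

# Object-dtype arrays are stored via pickle, so a pickle-free load rejects them.
buf_obj = io.BytesIO()
np.savez(buf_obj, label_names=np.array(names, dtype=object))
buf_obj.seek(0)
try:
    np.load(buf_obj, allow_pickle=False)["label_names"]
    needs_pickle = False
except ValueError:
    needs_pickle = True  # raised when the object array is read

# Fixed-width unicode and numeric arrays round-trip without pickle.
buf_str = io.BytesIO()
np.savez(
    buf_str,
    label_names=np.array(names, dtype=np.str_),
    codes=np.zeros((4, 3), dtype=np.float32),
)
buf_str.seek(0)
with np.load(buf_str, allow_pickle=False) as z:
    loaded_names = [str(n) for n in z["label_names"]]
    loaded_codes = z["codes"]
```

Whether this is practical for the real buffer depends on how the per-concept instance arrays are stored; ragged data would need one key per concept or a flattened layout to stay pickle-free.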

243-245: 💤 Low value

Consider adding a comment explaining the +2 sizing for the remap tensor.

The +2 accounts for 0-indexing and ensures negative indexing (-1) wraps to a valid buffer position. While correct, this is subtle:

-    remap = torch.full((int(inst_ids.max().item()) + 2,), -1, device=dev, dtype=torch.long)
+    # +2: one for 0-indexing, one so that -1 wraps to a valid (unused) slot
+    remap = torch.full((int(inst_ids.max().item()) + 2,), -1, device=dev, dtype=torch.long)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In
`@bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/eval/probing.py`
around lines 243 - 245, Add an inline comment above the remap creation
explaining why the size is int(inst_ids.max().item()) + 2: we need +1 for
0-based indexing of the maximum id and an extra slot so that using -1 as a
sentinel (when indexing remap with potentially -1 inst_ids) wraps to a dedicated
unused slot that still holds -1, instead of aliasing the entry for the maximum
id; reference the remap tensor and the subsequent remap[uniq.long()] /
remap[inst_ids.long()] usage (and the torch.full fill value of -1) so readers
understand the sentinel handling.
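
The sentinel trick is easy to check with plain Python lists, which share the negative-index wraparound of torch tensors. The sketch below is an illustration only: `build_remap` is a hypothetical helper, and the PR builds its table with `torch.full` and tensor indexing rather than a loop.

```python
from typing import List


def build_remap(inst_ids: List[int]) -> List[int]:
    """Map sparse instance ids to dense 0..K-1; id -1 (no instance) stays -1."""
    max_id = max(inst_ids)
    # Slots 0..max_id for real ids, plus one spare slot at the end for index -1.
    remap = [-1] * (max_id + 2)
    for dense, raw in enumerate(sorted({i for i in inst_ids if i >= 0})):
        remap[raw] = dense
    return remap


inst_ids = [-1, 3, 3, -1, 7, 7, 7, -1]
remap = build_remap(inst_ids)
dense_ids = [remap[i] for i in inst_ids]  # index -1 reads the spare last slot

# With only max_id + 1 slots, index -1 would land on the max id's entry instead.
too_small = [-1] * (max(inst_ids) + 1)
too_small[7] = 1
aliased = too_small[-1]
```

With the extra slot, tokens outside any instance keep the `-1` marker; without it they would silently be assigned to the instance with the largest id.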

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 23ddf87a-6a45-46a2-8264-db968ee016e5

📥 Commits

Reviewing files that changed from the base of the PR and between e407165 and 79df727.

📒 Files selected for processing (3)
  • bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/eval/__init__.py
  • bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/eval/probing.py
  • bionemo-recipes/interpretability/sparse_autoencoders/sae/tests/test_probing.py

@polinabinder1

Collaborator Author

Addressed the two nitpicks in 57837ec7: documented the allow_pickle=True trust caveat on ActivationBuffer.load, and added a comment explaining the +2 remap-tensor sizing (index-by-max-id + sentinel headroom). Tests still green (6 passed).

@polinabinder1 polinabinder1 marked this pull request as ready for review June 12, 2026 05:32
polinabinder1 and others added 2 commits June 23, 2026 05:55
Re-lands #1629 (sae.eval.probing: AUROC / domain-F1 / linear probes + ActivationBuffer) onto
the post-#1633 top-level layout, and adds a dedicated CPU workflow (ubuntu-latest, no model/GPU)
that runs the model-agnostic probing tests. Separate from the evo2 GPU lane; the tensor-parallel
sae tests (torchrun/multi-GPU) are out of scope here.

Validated: tests/test_probing.py -> 6 passed (CPU).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: Polina Binder <pbinder@nvidia.com>
auroc_all / auroc_vec / best_single / macro_auroc ranked via argsort().argsort(), giving tied
values arbitrary distinct ranks. SAE codes are sparse (heavy zero-mass), so that biased the AUROC
on the real data distribution — and the oracle test only covered randn (no ties). Switch to
average (Mann-Whitney) ranks via a vectorized searchsorted helper (keeps the all-features-at-once
speed that motivates hand-rolling), make the oracle tie-aware, and add sparse-tie +
constant-feature tests. Also documents why these metrics are hand-rolled.

tests/test_probing.py -> 8 passed (CPU).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: Polina Binder <pbinder@nvidia.com>
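
The tie problem described in the commit message above can be reproduced on a six-element toy vector. The sketch below contrasts ordinal ranks (the effect of `argsort().argsort()`, shown here with Python's stable sort as one deterministic instance of arbitrary tie-breaking) with average ranks. It is a loop-based illustration with assumed function names, not the vectorized `searchsorted` helper from the commit.

```python
from typing import List, Sequence


def ordinal_ranks(x: Sequence[float]) -> List[float]:
    """1-based ranks where tied values get distinct ranks in input order."""
    order = sorted(range(len(x)), key=lambda i: x[i])
    ranks = [0.0] * len(x)
    for r, i in enumerate(order, start=1):
        ranks[i] = float(r)
    return ranks


def average_ranks(x: Sequence[float]) -> List[float]:
    """1-based ranks where tied values share the mean of the ranks they span."""
    order = sorted(range(len(x)), key=lambda i: x[i])
    ranks = [0.0] * len(x)
    start = 0
    while start < len(order):
        end = start
        while end + 1 < len(order) and x[order[end + 1]] == x[order[start]]:
            end += 1
        mean_rank = (start + end) / 2 + 1  # mean of ranks start+1 .. end+1
        for k in range(start, end + 1):
            ranks[order[k]] = mean_rank
        start = end + 1
    return ranks


def rank_auroc(ranks: Sequence[float], labels: Sequence[bool]) -> float:
    """Mann-Whitney AUROC from ranks; single-class labels give 0.5."""
    n_pos = sum(labels)
    n_neg = len(labels) - n_pos
    if n_pos == 0 or n_neg == 0:
        return 0.5
    r_pos = sum(r for r, y in zip(ranks, labels) if y)
    return (r_pos - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)


codes = [0.0, 0.0, 0.0, 0.0, 0.0, 2.0]  # sparse: mostly zeros
pos_first = [True, False, False, False, False, True]
pos_last = [False, False, False, False, True, True]

biased_first = rank_auroc(ordinal_ranks(codes), pos_first)
biased_last = rank_auroc(ordinal_ranks(codes), pos_last)
fair_first = rank_auroc(average_ranks(codes), pos_first)
fair_last = rank_auroc(average_ranks(codes), pos_last)
```

Both label vectors describe the same situation (one positive among the zeros, one positive at the active value), yet ordinal ranks give 0.5 or 1.0 depending only on where the tied positive sits in the input, while average ranks give 0.75 in both cases.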
@polinabinder1 polinabinder1 force-pushed the pbinder/sae-interp-primitives branch from 57837ec to 13a0690 Compare June 23, 2026 06:06
…er None paths

- a never/always-firing concept -> AUROC 0.5 (the valid-mask branch; realistic for rare concepts)
- auroc_vec directly (was only tested transitively via best_single) on tied scores
- ActivationBuffer with no dense twin / no instances (the Optional -> None save/load paths)

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: Polina Binder <pbinder@nvidia.com>
polinabinder1 added a commit that referenced this pull request Jun 23, 2026
Re-lands #1630 on the post-#1633 layout, on top of the rebased #1629: the DNA label producers
(scripts/{labelers,annot_tracks,euk_windows}.py) that emit per-token concept labels (genes/exons/
motifs) to fill #1629's ActivationBuffer, + biopython dep (genetic code in labelers.py).

Validated: tests/{test_labelers,test_annot_tracks}.py -> 8 passed (CPU).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: Polina Binder <pbinder@nvidia.com>
polinabinder1 added a commit that referenced this pull request Jun 23, 2026
Re-lands #1636 on the post-#1633 layout, on top of rebased #1630: the harness/CLI
(scripts/{evo2_buffer,probe,probe_loss_recovered}.py) that runs the model to build an
ActivationBuffer (#1629) from #1630's labels and emits the probing metrics. Syntax-checked;
the GPU extract->score smoke is a follow-up (no unit tests in this PR yet).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: Polina Binder <pbinder@nvidia.com>