sae: shared probing primitives (eval metrics + ActivationBuffer)#1629
polinabinder1 wants to merge 3 commits into
📝 Walkthrough
This PR adds a comprehensive SAE feature-probing evaluation module (
Changes
SAE Probing Evaluation Suite
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks | ✅ 5 passed
@coderabbitai review
✅ Action performed: Review finished.
🧹 Nitpick comments (2)
bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/eval/probing.py (2)
54-65: 💤 Low value
`allow_pickle=True` poses a deserialization risk if loading untrusted files.
This is acceptable for internal artifacts but worth documenting. If these buffers might come from external sources, consider validating provenance or using a safer serialization format.
     @classmethod
     def load(cls, path: str) -> "ActivationBuffer":
    -    """Load an ActivationBuffer from an .npz written by save()."""
    +    """Load an ActivationBuffer from an .npz written by save().
    +
    +    Warning:
    +        Uses allow_pickle=True; only load files from trusted sources.
    +    """
         z = np.load(path, allow_pickle=True)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/eval/probing.py` around lines 54 - 65, The load method in ActivationBuffer uses np.load(..., allow_pickle=True) which is unsafe for untrusted files; change load to avoid allow_pickle=True by default (use allow_pickle=False) or add an explicit parameter (e.g., allow_pickle: bool = False) and fail with a clear error if pickled objects are required, and update the ActivationBuffer.load docstring to document the deserialization risk and the need to validate provenance when loading external files; ensure references to ActivationBuffer.load and the local variable z are used to implement and surface the safer behavior.
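A minimal sketch of the safer alternative described in the prompt above. The real `ActivationBuffer.load` body is not shown in this thread, so the standalone `load_buffer` function and the `codes` / `labels` / `meta` array names are hypothetical. It defaults to `allow_pickle=False` and raises a clear error when a file needs pickling.

```python
import os
import tempfile

import numpy as np


def load_buffer(path: str, allow_pickle: bool = False) -> dict:
    """Load arrays from an .npz; pickled objects are rejected unless the caller opts in."""
    try:
        with np.load(path, allow_pickle=allow_pickle) as z:
            # Materialize every array while the archive is still open.
            return {name: z[name] for name in z.files}
    except ValueError as err:
        # NumPy raises ValueError when an object array is read with allow_pickle=False.
        raise ValueError(
            f"{path} contains pickled objects; pass allow_pickle=True only for trusted files"
        ) from err


with tempfile.TemporaryDirectory() as tmp:
    # Plain numeric arrays round-trip without pickle.
    plain = os.path.join(tmp, "plain.npz")
    np.savez(plain, codes=np.zeros((4, 3), dtype=np.float32), labels=np.array([0, 1, 1, 0]))
    buf = load_buffer(plain)

    # An object array is stored via pickle and is refused by default.
    pickled = os.path.join(tmp, "pickled.npz")
    np.savez(pickled, meta=np.array([{"concept": "exon"}], dtype=object))
    try:
        load_buffer(pickled)
        rejected = False
    except ValueError:
        rejected = True
```

If every field of the buffer can be stored as a numeric or fixed-width string array, the opt-in path is never needed.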
243-245: 💤 Low value
Consider adding a comment explaining the `+2` sizing for the remap tensor.
The `+2` accounts for 0-indexing and ensures negative indexing (-1) wraps to a valid buffer position. While correct, this is subtle:
    - remap = torch.full((int(inst_ids.max().item()) + 2,), -1, device=dev, dtype=torch.long)
    + # +2: one for 0-indexing, one so that -1 wraps to a valid (unused) slot
    + remap = torch.full((int(inst_ids.max().item()) + 2,), -1, device=dev, dtype=torch.long)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/eval/probing.py` around lines 243 - 245, Add an inline comment above the remap creation explaining why the size is int(inst_ids.max().item()) + 2: we need +1 for 0-based indexing of the maximum id and an extra slot so that using -1 as a sentinel (when indexing remap with potentially -1 inst_ids) will wrap to a valid buffer position instead of raising an out-of-bounds error; reference the remap tensor and the subsequent remap[uniq.long()] / remap[inst_ids.long()] usage (and the torch.full default -1) so readers understand the sentinel handling.
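The sentinel handling discussed above can be illustrated with a small sketch. It uses NumPy as a stand-in for the `torch` tensors in `probing.py`, since negative indices wrap the same way in both. The `inst_ids` values are made up, and the convention that `-1` marks tokens outside any instance is an assumption drawn from the review comment.

```python
import numpy as np

# Hypothetical per-token instance ids; -1 marks tokens outside any instance.
inst_ids = np.array([7, 7, -1, 3, 3, 3, -1, 9])

uniq = np.unique(inst_ids[inst_ids >= 0])  # sorted distinct ids: 3, 7, 9

# +2: one slot for 0-based indexing of the max id, plus one spare slot at the
# end so that index -1 wraps onto a position that is never overwritten.
remap = np.full(int(inst_ids.max()) + 2, -1, dtype=np.int64)
remap[uniq] = np.arange(len(uniq))         # 3 -> 0, 7 -> 1, 9 -> 2
compact = remap[inst_ids]                  # -1 reads the spare slot and stays -1

# With only +1, index -1 wraps onto the max id's slot and silently aliases it.
too_small = np.full(int(inst_ids.max()) + 1, -1, dtype=np.int64)
too_small[uniq] = np.arange(len(uniq))
aliased = too_small[inst_ids]

print(compact.tolist())  # [1, 1, -1, 0, 0, 0, -1, 2]
print(aliased.tolist())  # [1, 1, 2, 0, 0, 0, 2, 2]
```

In this sketch the undersized buffer does not raise an error. It maps the sentinel tokens onto the instance with the largest id, which is harder to notice than an exception.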
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: 23ddf87a-6a45-46a2-8264-db968ee016e5
📒 Files selected for processing (3)
bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/eval/__init__.py
bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/eval/probing.py
bionemo-recipes/interpretability/sparse_autoencoders/sae/tests/test_probing.py
Addressed the two nitpicks in
Re-lands #1629 (sae.eval.probing: AUROC / domain-F1 / linear probes + ActivationBuffer) onto the post-#1633 top-level layout, and adds a dedicated CPU workflow (ubuntu-latest, no model/GPU) that runs the model-agnostic probing tests. Separate from the evo2 GPU lane; the tensor-parallel sae tests (torchrun/multi-GPU) are out of scope here.
Validated: tests/test_probing.py -> 6 passed (CPU).
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: Polina Binder <pbinder@nvidia.com>
auroc_all / auroc_vec / best_single / macro_auroc ranked scores via argsort().argsort(), which gives tied values arbitrary distinct ranks. SAE codes are sparse (heavy zero-mass), so this biased the AUROC on the real data distribution, and the oracle test only covered randn (no ties). Switch to average (Mann-Whitney) ranks via a vectorized searchsorted helper (keeps the all-features-at-once speed that motivates hand-rolling), make the oracle tie-aware, and add sparse-tie + constant-feature tests. Also documents why these metrics are hand-rolled.
tests/test_probing.py -> 8 passed (CPU).
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: Polina Binder <pbinder@nvidia.com>
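To make the tie issue in this commit concrete, here is a pure-Python sketch of average-rank (Mann-Whitney) AUROC checked against the pairwise definition. It is not the package's vectorized implementation. The function names are hypothetical, and `bisect` stands in for the `searchsorted` helper the commit message mentions.

```python
from bisect import bisect_left, bisect_right


def average_ranks(scores):
    """1-based ranks where tied values share the mean of the ranks they span."""
    ordered = sorted(scores)
    # bisect_left counts values strictly below s, bisect_right counts values <= s;
    # the tie group spans ranks lo+1 .. hi, whose mean is (lo + hi + 1) / 2.
    return [(bisect_left(ordered, s) + bisect_right(ordered, s) + 1) / 2 for s in scores]


def auroc_rank(scores, labels):
    """Mann-Whitney AUROC from average ranks; 0.5 if only one class is present."""
    n_pos = sum(labels)
    n_neg = len(labels) - n_pos
    if n_pos == 0 or n_neg == 0:
        return 0.5
    ranks = average_ranks(scores)
    pos_rank_sum = sum(r for r, y in zip(ranks, labels) if y)
    return (pos_rank_sum - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)


def auroc_pairwise(scores, labels):
    """Reference oracle: P(pos > neg) + 0.5 * P(pos == neg) over all pos/neg pairs."""
    pos = [s for s, y in zip(scores, labels) if y]
    neg = [s for s, y in zip(scores, labels) if not y]
    wins = sum((p > n) + 0.5 * (p == n) for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


# Sparse, tie-heavy scores in the style of SAE codes: mostly zeros.
scores = [0.0, 0.0, 0.0, 0.0, 0.7, 0.0, 1.2, 0.0]
labels = [0, 0, 1, 0, 1, 0, 1, 1]

print(auroc_rank(scores, labels))      # 0.75
print(auroc_pairwise(scores, labels))  # 0.75
```

With `argsort().argsort()`, the six tied zeros would receive ranks 1 to 6 in whatever order the sort happens to leave them, so the rank sum of the two zero-scored positives would depend on their position in the array, not on their value.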
57837ec to 13a0690
…er None paths
- a never/always-firing concept -> AUROC 0.5 (the valid-mask branch; realistic for rare concepts)
- auroc_vec directly (was only tested transitively via best_single) on tied scores
- ActivationBuffer with no dense twin / no instances (the Optional -> None save/load paths)
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: Polina Binder <pbinder@nvidia.com>
Re-lands #1630 on the post-#1633 layout, on top of the rebased #1629: the DNA label producers (scripts/{labelers,annot_tracks,euk_windows}.py) that emit per-token concept labels (genes/exons/motifs) to fill #1629's ActivationBuffer, + biopython dep (genetic code in labelers.py).
Validated: tests/{test_labelers,test_annot_tracks}.py -> 8 passed (CPU).
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: Polina Binder <pbinder@nvidia.com>
Re-lands #1636 on the post-#1633 layout, on top of rebased #1630: the harness/CLI (scripts/{evo2_buffer,probe,probe_loss_recovered}.py) that runs the model to build an ActivationBuffer (#1629) from #1630's labels and emits the probing metrics.
Syntax-checked; the GPU extract->score smoke is a follow-up (no unit tests in this PR yet).
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: Polina Binder <pbinder@nvidia.com>
Summary
Shared, model-agnostic SAE probing primitives in the `sae` package (sibling of `loss_recovered` / `sparsity` / `dead_latents`): scoring metrics + per-feature annotation, all pure functions of codes + labels.
Contents: `sae.eval.probing`
- `ActivationBuffer` (codes + optional dense twin + per-token labels + instance ids)
- `auroc_all`, `auroc_vec`, `best_single_train_test`
- `fit_logreg` / `fit_softmax` / `macro_auroc` / `decode_eval`
- `domain_f1` (precision-per-nt, recall-per-instance)
- `annotate_features` (per-feature best concept by AUROC → the annotation table)
How to use
`pytest sae/tests/test_probing.py` (CPU, no model).
Why hand-rolled (not sklearn)
GPU-vectorized over the whole ~32k-feature dictionary in one pass; sklearn is CPU/per-`(scores,label)`, and CodonFM, which used it, had to subsample to ≤5k features. Package stays `torch`+`numpy`. Each metric is validated against an independent reference in the tests (pairwise-AUROC oracle, hand-computed `domain_f1`, etc.).
Base of: #1630 (eval labels) → #1636 (harness).
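As an illustration of the `domain_f1` entry in the contents above (precision-per-nt, recall-per-instance), here is a pure-Python sketch under stated assumptions: activations are already thresholded to booleans, and `-1` in the instance ids marks tokens outside the concept. The function name and signature are hypothetical and may not match the package's `domain_f1`.

```python
def domain_f1_sketch(fires, inst_ids):
    """Return (precision, recall, f1) from per-token precision and per-instance recall.

    fires:    per-token booleans, True where the feature is active.
    inst_ids: per-token instance id, -1 for tokens outside the concept.
    """
    n_fire = sum(fires)
    instances = {i for i in inst_ids if i >= 0}
    if n_fire == 0 or not instances:
        return 0.0, 0.0, 0.0
    # Precision: of the tokens where the feature fires, the fraction inside any instance.
    precision = sum(1 for f, i in zip(fires, inst_ids) if f and i >= 0) / n_fire
    # Recall: the fraction of instances hit by at least one firing token.
    hit = {i for f, i in zip(fires, inst_ids) if f and i >= 0}
    recall = len(hit) / len(instances)
    if precision + recall == 0:
        return precision, recall, 0.0
    return precision, recall, 2 * precision * recall / (precision + recall)


# Hand-computed toy case: three instances (0, 1, 2) over ten tokens.
inst_ids = [0, 0, 0, -1, 1, 1, -1, 2, 2, -1]
fires = [True, False, False, True, False, True, False, False, False, False]
precision, recall, f1 = domain_f1_sketch(fires, inst_ids)
```

Here the feature fires on three tokens, two of them inside an instance (precision 2/3), and touches two of the three instances (recall 2/3). Counting recall per instance means a feature that marks only one position of each gene or exon still gets full recall.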
On external libraries (checked, not a win)
We evaluated replacing the hand-rolled metrics with `sklearn`/`torchmetrics`, function by function:
- `auroc_all`: no library computes a vectorized `[features × labels]` AUROC matrix on GPU; `sklearn.roc_auc_score`/`torchmetrics` are CPU and per-`(scores, label)`, so a ~32k-feature dictionary becomes a 32k-iteration CPU loop. Kept.
- `domain_f1`, `best_single_train_test`, `annotate_features`: bespoke (instance-F1, winner's-curse, per-feature assignment); no library equivalent.
- `fit_logreg`/`fit_softmax`/`decode_eval` (~62 lines): the only `sklearn`-replaceable code, but they fit on the `[N≈50k, F≈32k]` SAE-code matrix, which is exactly where CodonFM tried `sklearn.LogisticRegression` and had to subsample to ≤5k features. Swapping reintroduces that coverage loss and a runtime dep. Net regression.
- `ActivationBuffer`/`split_indices`/`standardize`: `np.savez` + 7-line helpers; nothing to gain.
Conclusion: kept the package `torch`+`numpy`-only. The metrics are standard formulas (Mann–Whitney rank-AUROC, Adam BCE/softmax, InterPLM instance-F1) vectorized for full-dictionary GPU scale, and each is validated against an independent reference in the tests.
Expected output
`pytest sae/tests/test_probing.py` → 6 passed: AUROC vs the pairwise-definition oracle, `domain_f1` vs a hand-computed reference, `best_single` winner's-curse flip, `decode_eval` separability, `annotate_features` best-concept, buffer roundtrip.