Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions .github/workflows/unit-tests-sae.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
name: "BioNeMo SAE Library CI"

# CPU unit tests for the model-agnostic `sae` library (interpretability/sparse_autoencoders/sae).
# Separate from the evo2 GPU lane: these need no model and no GPU — just CPU torch + numpy — so
# they run on a cheap ubuntu-latest runner. Scoped to the model-agnostic tests; the
# tensor-parallel tests (test_tp_*) need torchrun/multi-GPU and are intentionally not run here.

on:
  push:
    branches:
      - "pull-request/[0-9]+"
      - "dependabot/**"
    # Only pushes that touch the library or this workflow trigger a run.
    paths:
      - "interpretability/sparse_autoencoders/sae/**"
      - ".github/workflows/unit-tests-sae.yaml"
  merge_group:
    types: [checks_requested]
  schedule:
    - cron: "0 9 * * *"  # Runs at 9 AM UTC daily (2 AM MST)

# The job only reads the repository; keep the GITHUB_TOKEN at least privilege.
permissions:
  contents: read

defaults:
  run:
    # -x traces commands; -e/-u/pipefail make any failing or unset-variable command fail the step.
    shell: bash -x -e -u -o pipefail {0}

# This workflow has no `pull_request` trigger, so the PR-number term is empty and the
# group key falls back to github.ref; a newer run on the same ref cancels the older one.
concurrency:
  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
  cancel-in-progress: true

jobs:
  sae-cpu-tests:
    runs-on: ubuntu-latest
    name: "sae-lib-tests (cpu)"
    # Bound the job so a hung install/test cannot hold the runner for the 360-minute default.
    timeout-minutes: 30
    steps:
      # Only the sae package directory is needed, so skip checking out the rest of the repo.
      - name: Checkout repository
        uses: actions/checkout@v4
        with:
          sparse-checkout: interpretability/sparse_autoencoders/sae
          sparse-checkout-cone-mode: false

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: "3.12"

      # The extra index makes the PyTorch CPU wheels available alongside PyPI.
      - name: Install (CPU torch + test deps)
        working-directory: interpretability/sparse_autoencoders/sae
        run: pip install --extra-index-url https://download.pytorch.org/whl/cpu -e '.[dev]'

      - name: Run model-agnostic probing tests
        working-directory: interpretability/sparse_autoencoders/sae
        run: pytest -v tests/test_probing.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,20 @@
compute_loss_recovered,
evaluate_loss_recovered,
)
from .probing import (
ActivationBuffer,
annotate_features,
auroc_all,
auroc_vec,
best_single_train_test,
decode_eval,
domain_f1,
fit_logreg,
fit_softmax,
macro_auroc,
split_indices,
standardize,
)
from .reconstruction import (
ReconstructionMetrics,
compute_reconstruction_metrics,
Expand All @@ -31,16 +45,28 @@


# Public API of this package: the names exported by a star-import of the package.
# Every entry must be bound in this module (via the relative imports above).
# Entries are kept in ASCII-sorted order, so capitalized names precede lowercase ones.
__all__ = [
"ActivationBuffer",
"DeadLatentStats",
"DeadLatentTracker",
"EvalResults",
"LossRecoveredResult",
"ReconstructionMetrics",
"SparsityMetrics",
"annotate_features",
"auroc_all",
"auroc_vec",
"best_single_train_test",
"compute_loss_recovered",
"compute_reconstruction_metrics",
"decode_eval",
"domain_f1",
"evaluate_loss_recovered",
"evaluate_reconstruction",
"evaluate_sae",
"evaluate_sparsity",
"fit_logreg",
"fit_softmax",
"macro_auroc",
"split_indices",
"standardize",
]
Loading
Loading