From 3557958bef78f2a7c31f6dde0198bccd6c6e4769 Mon Sep 17 00:00:00 2001 From: Polina Binder Date: Tue, 23 Jun 2026 06:17:52 +0000 Subject: [PATCH 1/6] evo2 SAE eval (1/2): DNA label producers, rebased onto migrated #1629 Re-lands #1630 on the post-#1633 layout, on top of the rebased #1629: the DNA label producers (scripts/{labelers,annot_tracks,euk_windows}.py) that emit per-token concept labels (genes/exons/ motifs) to fill #1629's ActivationBuffer, + biopython dep (genetic code in labelers.py). Validated: tests/{test_labelers,test_annot_tracks}.py -> 8 passed (CPU). Co-Authored-By: Claude Opus 4.8 Signed-off-by: Polina Binder --- .../recipes/evo2/pyproject.toml | 1 + .../recipes/evo2/scripts/annot_tracks.py | 146 +++++++ .../recipes/evo2/scripts/euk_windows.py | 250 ++++++++++++ .../recipes/evo2/scripts/labelers.py | 367 ++++++++++++++++++ .../recipes/evo2/tests/test_annot_tracks.py | 76 ++++ .../recipes/evo2/tests/test_labelers.py | 53 +++ 6 files changed, 893 insertions(+) create mode 100644 interpretability/sparse_autoencoders/recipes/evo2/scripts/annot_tracks.py create mode 100644 interpretability/sparse_autoencoders/recipes/evo2/scripts/euk_windows.py create mode 100644 interpretability/sparse_autoencoders/recipes/evo2/scripts/labelers.py create mode 100644 interpretability/sparse_autoencoders/recipes/evo2/tests/test_annot_tracks.py create mode 100644 interpretability/sparse_autoencoders/recipes/evo2/tests/test_labelers.py diff --git a/interpretability/sparse_autoencoders/recipes/evo2/pyproject.toml b/interpretability/sparse_autoencoders/recipes/evo2/pyproject.toml index f7b23d8eff..ba794c5e74 100644 --- a/interpretability/sparse_autoencoders/recipes/evo2/pyproject.toml +++ b/interpretability/sparse_autoencoders/recipes/evo2/pyproject.toml @@ -13,6 +13,7 @@ dependencies = [ "torch>=2.0", "numpy>=1.20", "pyarrow>=23.0.0", + "biopython>=1.80", # genetic code / translation in labelers.py ] # No package code lives here yet — the recipe is just an entry-point for diff --git a/interpretability/sparse_autoencoders/recipes/evo2/scripts/annot_tracks.py b/interpretability/sparse_autoencoders/recipes/evo2/scripts/annot_tracks.py new file mode 100644 index 0000000000..88994c2113 --- /dev/null +++ b/interpretability/sparse_autoencoders/recipes/evo2/scripts/annot_tracks.py @@ -0,0 +1,146 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r"""Generic interval-track loader for the "user-supplied annotated dataset" eval. + +The user hands in an annotated dataset: a FASTA of sequences + one or more annotation +tracks (BED or GFF) naming intervals — RefSeq genes/exons, Rfam ncRNA, JASPAR TFBS, +ENCODE cCREs, etc. Each interval is one annotation **instance**. This module tiles the +sequences into windows and produces, per concept, a per-token boolean mask + per-token +**global** instance IDs (stable across the windows an interval spans) — exactly the +inputs `sae.eval.probing.domain_f1` (recall-per-instance) and `auroc_all` (per-feature) +consume. No model here; the SAE-encode step lives in the probe CLI (`probe.py domain-eval`). + +This is the generic sibling of `euk_windows.py` (which decomposes RefSeq gene models into +exon/intron/cds). Both feed the same shared scorers. +""" + +from __future__ import annotations + +import gzip +from collections import defaultdict + + +def _open(path): + """Open a path for text reading, transparently handling ``.gz``.""" + return (gzip.open if str(path).endswith(".gz") else open)(path, "rt") + + +def read_fasta_dict(path: str) -> dict[str, str]: + """Read a (multi-record) FASTA into ``{seq_id: sequence}`` (``.gz`` transparent). + + ``seq_id`` is the first whitespace token of the header — matches the chrom/seqid of BED/GFF. + """ + seqs: dict[str, str] = {} + name, parts = None, [] + with _open(path) as fh: + for line in fh: + line = line.rstrip() + if line.startswith(">"): + if name is not None: + seqs[name] = "".join(parts) + tok = line[1:].split() + name, parts = (tok[0] if tok else f"seq_{len(seqs)}"), [] + elif line: + parts.append(line) + if name is not None: + seqs[name] = "".join(parts) + return seqs + + +def _intervals(path, fmt, feature_type=None): + """Yield (seqid, start0, end0) from BED (0-based) or GFF/GTF (1-based -> 0-based half-open). + + GFF rows are optionally filtered to a single column-3 ``feature_type`` (e.g. ``exon``). + """ + chrom_i, start_i, end_i, off = (0, 1, 2, 0) if fmt == "bed" else (0, 3, 4, 1) + with _open(path) as fh: + for line in fh: + if not line.strip() or line[0] == "#" or line.startswith(("track", "browser")): + continue + f = line.split("\t") + if len(f) <= end_i or (feature_type and fmt != "bed" and f[2] != feature_type): + continue + yield f[chrom_i], int(f[start_i]) - off, int(f[end_i]) + + +def load_track(path, feature_type=None, fmt=None): + """Load one annotation track into ``{seqid: [(start0, end0), ...]}`` (0-based half-open, sorted). + + ``fmt`` (``bed``/``gff``) is inferred from the extension; ``feature_type`` filters GFF column 3. + Every interval is one annotation instance. + """ + fmt = fmt or ("gff" if str(path).replace(".gz", "").endswith((".gff", ".gff3", ".gtf")) else "bed") + by_seq = defaultdict(list) + for chrom, s, e in _intervals(path, fmt, feature_type): + if e > s: + by_seq[chrom].append((s, e)) + return {k: sorted(v) for k, v in by_seq.items()} + + +def label_windows(seqs, tracks, seq_len=1024, stride=None, max_tokens=None, min_n_frac=0.5): + """Tile sequences into windows, labeling each position per concept (mask + global instance id). + + Args: + seqs: ``{seqid: dna_str}``. + tracks: ``{concept: {seqid: [(start0, end0), ...]}}`` (e.g. from `load_track`). + seq_len: window length in bp. + stride: step between windows (defaults to non-overlapping = seq_len). + max_tokens: stop once this many positions are emitted (None = all). + min_n_frac: skip windows whose ``N`` fraction exceeds this. + + Returns: + (windows, stats). Each window is ``{"dna": str, "labels": {concept: bool[L]}, + "instances": {concept: int32[L]}}``. Each interval gets one global id, stable across + the windows it spans, so `domain_f1`'s recall-per-instance counts a split interval once. + """ + import numpy as np + + stride = stride or seq_len + concepts = list(tracks.keys()) + # assign a global instance id to every interval, per concept + concept_iv: dict[str, dict[str, list[tuple[int, int, int]]]] = {} + n_inst: dict[str, int] = {} + for concept in concepts: + gid = 0 + cc: dict[str, list[tuple[int, int, int]]] = {} + for seqid, ivs in tracks[concept].items(): + cc[seqid] = [(s, e, (gid := gid + 1) - 1) for (s, e) in ivs] + concept_iv[concept] = cc + n_inst[concept] = gid + + windows, tot = [], 0 + for seqid, dna in seqs.items(): + dna = dna.upper() + N = len(dna) + for w0 in range(0, max(1, N - seq_len + 1), stride): + w1 = min(N, w0 + seq_len) + sub = dna[w0:w1] + L = w1 - w0 + if L < 60 or sub.count("N") > min_n_frac * L: + continue + labels = {c: np.zeros(L, bool) for c in concepts} + inst = {c: np.full(L, -1, np.int32) for c in concepts} + for c in concepts: + for s, e, gid in concept_iv[c].get(seqid, []): + if e <= w0 or s >= w1: + continue + labels[c][max(s, w0) - w0 : min(e, w1) - w0] = True + inst[c][max(s, w0) - w0 : min(e, w1) - w0] = gid + windows.append({"dna": sub, "labels": labels, "instances": inst}) + tot += L + if max_tokens and tot >= max_tokens: + return windows, {"tokens": tot, "n_inst": n_inst, "concepts": concepts} + return windows, {"tokens": tot, "n_inst": n_inst, "concepts": concepts} diff --git a/interpretability/sparse_autoencoders/recipes/evo2/scripts/euk_windows.py b/interpretability/sparse_autoencoders/recipes/evo2/scripts/euk_windows.py new file mode 100644 index 0000000000..0f381df84b --- /dev/null +++ b/interpretability/sparse_autoencoders/recipes/evo2/scripts/euk_windows.py @@ -0,0 +1,250 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r"""Build instance-level exon/intron/CDS-labeled windows from a genome FASTA + GFF3. + +Eukaryotic gene-structure annotation for SAE feature probing. Unlike the +sequence-derived labelers, these labels come from real gene models, and crucially +carry *instance IDs* (which exon / which intron / which gene each position belongs +to) so domain-adjusted F1 can compute recall PER ANNOTATION INSTANCE (a feature +"recalls" an exon if it fires anywhere inside it), not per position. + +For each protein-coding gene we take a representative transcript (longest by total +exon length), tile its span ± flank into windows, and label every position: + exon / intron / cds / utr / intergenic (+ per-position instance IDs for + exon, intron, gene) + +`python euk_windows.py --fasta chr21.fa --gff chr21.gff3 --dry-run` prints +coverage stats without building sequences. +""" + +from __future__ import annotations + +import argparse +import random +from collections import defaultdict + +import numpy as np + + +def _attrs(s): + return dict(kv.split("=", 1) for kv in s.strip().split(";") if "=" in kv) + + +def parse_gff(gff_path): + """Return {gene_id: {strand, tx: {tx_id: {'exon': [(s,e)], 'cds': [(s,e)]}}}} (protein_coding).""" + gene_strand, gene_biotype = {}, {} + tx_gene, tx_biotype = {}, {} + tx_exon = defaultdict(list) + tx_cds = defaultdict(list) + with open(gff_path) as fh: + for line in fh: + if line.startswith("#"): + continue + f = line.rstrip("\n").split("\t") + if len(f) < 9: + continue + typ, s, e, strand, attr = f[2], int(f[3]), int(f[4]), f[6], f[8] + a = _attrs(attr) + if typ == "gene": + gid = a.get("ID", "").replace("gene:", "") + gene_strand[gid] = strand + gene_biotype[gid] = a.get("biotype", "") + elif typ in ("mRNA", "transcript"): + tid = a.get("ID", "").replace("transcript:", "") + tx_gene[tid] = a.get("Parent", "").replace("gene:", "") + tx_biotype[tid] = a.get("biotype", "") + elif typ == "exon": + for tid in a.get("Parent", "").replace("transcript:", "").split(","): + if tid: + tx_exon[tid].append((s, e)) + elif typ == "CDS": + for tid in a.get("Parent", "").replace("transcript:", "").split(","): + if tid: + tx_cds[tid].append((s, e)) + genes = {} + for tid, gid in tx_gene.items(): + if gene_biotype.get(gid) != "protein_coding" or tx_biotype.get(tid) != "protein_coding": + continue + if not tx_exon.get(tid): + continue + genes.setdefault(gid, {"strand": gene_strand.get(gid, "+"), "tx": {}}) + genes[gid]["tx"][tid] = {"exon": sorted(tx_exon[tid]), "cds": sorted(tx_cds.get(tid, []))} + return genes + + +def representative_tx(gene): + """Longest transcript by total exon length.""" + best, best_len = None, -1 + for tid, t in gene["tx"].items(): + ln = sum(e - s + 1 for s, e in t["exon"]) + if ln > best_len: + best, best_len = tid, ln + return best, gene["tx"][best] + + +def _label_window(chrom, w0, w1, gm, N): + """Label a window [w0,w1) using one gene model's intervals (central-gene approx).""" + L = w1 - w0 + pos = np.arange(w0, w1) + lab = {k: np.zeros(L, bool) for k in ("exon", "intron", "cds", "utr", "intergenic")} + inst = {k: np.full(L, -1, np.int32) for k in ("exon", "intron", "gene")} + g_start, g_end = gm["span"] + in_tx = (pos >= g_start - 1) & (pos < g_end) + lab["intergenic"][~in_tx] = True + inst["gene"][in_tx] = gm["gi"] + for (s, e), iid in zip(gm["exons"], gm["exon_ids"]): + m = (pos >= s - 1) & (pos < e) + lab["exon"][m] = True + inst["exon"][m] = iid + for (s, e), iid in zip(gm["introns"], gm["intron_ids"]): + m = (pos >= s - 1) & (pos < e) + lab["intron"][m] = True + inst["intron"][m] = iid + for s, e in gm["cds"]: + lab["cds"][(pos >= s - 1) & (pos < e)] = True + lab["utr"] = lab["exon"] & ~lab["cds"] + return {"dna": chrom[w0:w1], "labels": lab, "instances": inst} + + +def build_windows( # noqa: D103 + fasta, gff, seq_len=1024, max_tokens=300_000, flank=300, seed=0, intergenic_frac=0.12, dry_run=False +): + seqs = [] + with open(fasta) as fh: + for line in fh: + if not line.startswith(">"): + seqs.append(line.strip()) + chrom = "".join(seqs).upper() + N = len(chrom) + genes = parse_gff(gff) + + exon_id, intron_id, gene_id = {}, {}, {} + stats = defaultdict(int) + gene_models, gene_spans = [], [] + for gid, gene in genes.items(): + tid, tx = representative_tx(gene) + exons, cds = tx["exon"], tx["cds"] + if not exons: + continue + g_start, g_end = exons[0][0], exons[-1][1] + introns = [ + (exons[i][1] + 1, exons[i + 1][0] - 1) + for i in range(len(exons) - 1) + if exons[i + 1][0] - 1 >= exons[i][1] + 1 + ] + gi = gene_id.setdefault(gid, len(gene_id)) + eids = [exon_id.setdefault((tid, i), len(exon_id)) for i in range(len(exons))] + iids = [intron_id.setdefault((tid, i), len(intron_id)) for i in range(len(introns))] + gene_models.append( + { + "exons": exons, + "introns": introns, + "cds": cds, + "gi": gi, + "exon_ids": eids, + "intron_ids": iids, + "span": (g_start, g_end), + } + ) + gene_spans.append((g_start, g_end)) + stats["genes"] += 1 + stats["exons"] += len(exons) + stats["introns"] += len(introns) + stats["exon_bp"] += sum(e - s + 1 for s, e in exons) + stats["intron_bp"] += sum(e - s + 1 for s, e in introns) + stats["cds_bp"] += sum(e - s + 1 for s, e in cds) + if dry_run: + return [], dict(stats), 0, N + + rng = random.Random(seed) + # exon-centered windows sampled across ALL genes' exons (diverse + exon/intron balanced) + exon_refs = [(gi, ei) for gi, gm in enumerate(gene_models) for ei in range(len(gm["exons"]))] + rng.shuffle(exon_refs) + windows, tot = [], 0 + budget_genic = int(max_tokens * (1 - intergenic_frac)) + for gi, ei in exon_refs: + if tot >= budget_genic: + break + gm = gene_models[gi] + s, e = gm["exons"][ei] + center = (s - 1 + e) // 2 + w0 = max(0, center - seq_len // 2) + w1 = min(N, w0 + seq_len) + if w1 - w0 < 60: + continue + win = _label_window(chrom, w0, w1, gm, N) + if win["dna"].count("N") > 0.5 * len(win["dna"]): + continue + windows.append(win) + tot += w1 - w0 + # intergenic windows: random spots clear of any gene span (+flank) + spans = sorted(gene_spans) + tries = 0 + while tot < max_tokens and tries < 20000: + tries += 1 + w0 = rng.randint(0, N - seq_len) + w1 = w0 + seq_len + if any(not (w1 < gs - flank or w0 > ge + flank) for gs, ge in spans): + continue + dna = chrom[w0:w1] + if dna.count("N") > 0.5 * seq_len: + continue + lab = {k: np.zeros(seq_len, bool) for k in ("exon", "intron", "cds", "utr", "intergenic")} + lab["intergenic"][:] = True + inst = {k: np.full(seq_len, -1, np.int32) for k in ("exon", "intron", "gene")} + windows.append({"dna": dna, "labels": lab, "instances": inst}) + tot += seq_len + return windows, dict(stats), tot, N + + +def main(): # noqa: D103 + ap = argparse.ArgumentParser() + ap.add_argument("--fasta", required=True) + ap.add_argument("--gff", required=True) + ap.add_argument("--seq-len", type=int, default=1024) + ap.add_argument("--max-tokens", type=int, default=300_000) + ap.add_argument("--flank", type=int, default=300) + ap.add_argument("--dry-run", action="store_true") + args = ap.parse_args() + windows, stats, tot, N = build_windows( + args.fasta, args.gff, args.seq_len, args.max_tokens, args.flank, dry_run=args.dry_run + ) + print(f"chromosome length: {N:,} bp") + print(f"protein-coding genes used: {stats.get('genes', 0):,}") + print(f"exons: {stats.get('exons', 0):,} introns: {stats.get('introns', 0):,}") + if args.dry_run: + print( + f"exon bp: {stats.get('exon_bp', 0):,} intron bp: {stats.get('intron_bp', 0):,} cds bp: {stats.get('cds_bp', 0):,}" + ) + return + print(f"windows built: {len(windows):,} total tokens: {tot:,}") + # coverage over built windows + cov = defaultdict(int) + ninst = {k: set() for k in ("exon", "intron", "gene")} + for w in windows: + for k, m in w["labels"].items(): + cov[k] += int(m.sum()) + for k in ninst: + ids = w["instances"][k] + ninst[k].update(int(x) for x in np.unique(ids) if x >= 0) + print("per-position coverage (of built windows):") + for k in ("exon", "intron", "cds", "utr", "intergenic"): + print(f" {k:11s} {cov[k]:>9,} ({100 * cov[k] / max(1, tot):5.1f}%)") + print(f"instances: exons={len(ninst['exon']):,} introns={len(ninst['intron']):,} genes={len(ninst['gene']):,}") + + +if __name__ == "__main__": + main() diff --git a/interpretability/sparse_autoencoders/recipes/evo2/scripts/labelers.py b/interpretability/sparse_autoencoders/recipes/evo2/scripts/labelers.py new file mode 100644 index 0000000000..3243742d75 --- /dev/null +++ b/interpretability/sparse_autoencoders/recipes/evo2/scripts/labelers.py @@ -0,0 +1,367 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Extensible per-token biological labelers for SAE feature probing. + +Each labeler maps a `SeqContext` (one tokenized sequence) to a per-token boolean +mask of length `T`. The per-feature AUROC probe (`probe_features.py`) asks, for +every label and every SAE feature, how well the feature's activation separates +positive from negative tokens. + +Adding a feature is just writing a function and decorating it: + + @labeler("my_concept") + def _my(ctx): + return some_bool_array_len_T + +`complex=True` flags labelers that are proxies or need real external annotation +(e.g. true gene models) and should be refined later — they're the natural home +for the "more complicated features" we want to add at the end. + +Conventions +----------- +* Tokens 0..tag_len-1 are the phylogenetic-tag prefix; sequence-derived motif / + positional labels are False there (use `_dna_mask`). Sequence-level labels + (`is_prok`) and norm-based labels (`is_sink_token`) may mark tag tokens. +* Byte-level Evo2 tokenization is 1 char = 1 token, so token i in the DNA region + corresponds to base `ctx.dna[i - tag_len]`. +""" + +from __future__ import annotations + +import re +from dataclasses import dataclass +from typing import Optional + +import numpy as np +from Bio.Data import CodonTable +from Bio.Seq import Seq + + +# name -> fn(ctx) -> np.ndarray[bool] of length T +LABELERS: dict[str, callable] = {} +# labelers that are proxies / need real annotations (documented, refine later) +COMPLEX_LABELERS: set[str] = set() + +# Sink-token norm threshold (residual L2). Set by the driver from the data +# (Evo2 7B layer-26 sinks sit ~1638 vs a ~21 median, so this cleanly isolates them). +SINK_NORM_THRESHOLD: float = 100.0 + + +def labeler(name: str, complex: bool = False): + """Register a per-token labeler under `name`.""" + + def deco(fn): + LABELERS[name] = fn + if complex: + COMPLEX_LABELERS.add(name) + return fn + + return deco + + +@dataclass +class SeqContext: + """Everything a labeler needs about one tokenized sequence.""" + + text: str # tag + dna (1 char == 1 token) + tag_len: int # number of leading phylo-tag tokens + dna: str # the DNA region (uppercase ACGTN), len == T - tag_len + kingdom: str # 'prok' | 'euk' + hidden_norm: np.ndarray # [T] residual L2 norm per token + # Gene-structure annotation over DNA positions (filled by a gene caller; None if absent). + cds_mask: Optional[np.ndarray] = None # bool[len(dna)] — within a predicted CDS (either strand) + cds_frame: Optional[np.ndarray] = None # int8[len(dna)] — codon position 0/1/2 within CDS, -1 if not + gene_starts: Optional[np.ndarray] = None # bool[len(dna)] — predicted translation start positions + + @property + def T(self) -> int: # noqa: D102 + return self.tag_len + len(self.dna) + + +def _dna_mask(ctx: SeqContext, dna_bool: np.ndarray) -> np.ndarray: + """Lift a per-DNA-position bool array to a per-token mask (tag tokens False).""" + out = np.zeros(ctx.T, dtype=bool) + out[ctx.tag_len : ctx.tag_len + len(dna_bool)] = dna_bool + return out + + +def _bytes(dna: str) -> np.ndarray: + return np.frombuffer(dna.encode("ascii", "replace"), dtype=np.uint8) + + +# --------------------------------------------------------------------- positional +@labeler("first_100bp") +def _first(ctx): + d = np.zeros(len(ctx.dna), bool) + d[:100] = True + return _dna_mask(ctx, d) + + +@labeler("last_100bp") +def _last(ctx): + d = np.zeros(len(ctx.dna), bool) + if len(d): + d[-100:] = True + return _dna_mask(ctx, d) + + +# --------------------------------------------------------------------- single base +# Evo2 is nucleotide-level (1 token = 1 base), so the most atomic feature is a single-base +# detector. One labeler per nucleotide — fires at every position whose base equals it — +# so probing can surface features that monosemantically track a single base (e.g. base_G). +def _register_base_labelers(): + for base in "ACGT": + labeler(f"base_{base}")(lambda ctx, b=base: _dna_mask(ctx, _bytes(ctx.dna) == ord(b))) + + +_register_base_labelers() + + +# --------------------------------------------------------------------- composition +def _gc_window(dna: str, radius: int = 10) -> np.ndarray: + arr = _bytes(dna) + gc = ((arr == ord("G")) | (arr == ord("C"))).astype(np.float64) + csum = np.concatenate([[0.0], np.cumsum(gc)]) + n = len(gc) + idx = np.arange(n) + lo = np.maximum(0, idx - radius) + hi = np.minimum(n, idx + radius + 1) + return (csum[hi] - csum[lo]) / np.maximum(1, hi - lo) + + +@labeler("gc_high_window") +def _gch(ctx): + return _dna_mask(ctx, _gc_window(ctx.dna) >= 0.60) + + +@labeler("gc_low_window") +def _gcl(ctx): + return _dna_mask(ctx, _gc_window(ctx.dna) <= 0.30) + + +@labeler("homopolymer_window") +def _homo(ctx, k: int = 5): + d, n = ctx.dna, len(ctx.dna) + out = np.zeros(n, bool) + i = 0 + while i < n: + j = i + while j + 1 < n and d[j + 1] == d[i]: + j += 1 + if j - i + 1 >= k: + out[i : j + 1] = True + i = j + 1 + return _dna_mask(ctx, out) + + +@labeler("dinuc_repeat_window") +def _dinuc(ctx, min_reps: int = 3): + d, n = ctx.dna, len(ctx.dna) + out = np.zeros(n, bool) + i = 0 + while i < n - 1: + if d[i] != d[i + 1]: + j = i + while j + 2 < n and d[j + 2] == d[j]: + j += 1 + span = j + 2 - i + if span >= 2 * min_reps: + out[i : j + 2] = True + i = max(j + 1, i + 1) + else: + i += 1 + return _dna_mask(ctx, out) + + +# --------------------------------------------------------------------- motifs +def _starts(dna: str, pattern: str) -> np.ndarray: + out = np.zeros(len(dna), bool) + for m in re.finditer(pattern, dna): + out[m.start()] = True + return out + + +def _spans(dna: str, pattern: str) -> np.ndarray: + out = np.zeros(len(dna), bool) + for m in re.finditer(pattern, dna): + out[m.start() : m.end()] = True + return out + + +# Consensus motifs: (name, matcher, regex) — `_starts` marks the match start, `_spans` the whole match. +_MOTIFS = [ + ("motif_ATG", _starts, r"ATG"), + ("motif_stop", _starts, r"TAA|TAG|TGA"), + ("motif_TATA", _spans, r"TATA[AT]A"), + ("motif_RBS_SD", _spans, r"AGGAGG"), # Shine-Dalgarno ribosome-binding site +] +for _name, _match, _pat in _MOTIFS: + labeler(_name)(lambda ctx, m=_match, p=_pat: _dna_mask(ctx, m(ctx.dna, p))) + + +# --------------------------------------------------- complex / consensus (refine later) +@labeler("kozak_atg", complex=True) +def _kozak(ctx): + # Kozak: (A/G)xxATGG — mark the ATG start (match start + 3) + out = np.zeros(len(ctx.dna), bool) + for m in re.finditer(r"[AG]..ATGG", ctx.dna): + out[m.start() + 3] = True + return _dna_mask(ctx, out) + + +@labeler("splice_donor", complex=True) +def _sd(ctx): + # 5' donor consensus GT(A/G)AGT — mark the GT + return _dna_mask(ctx, _starts(ctx.dna, r"GT[AG]AGT")) + + +@labeler("splice_acceptor", complex=True) +def _sa(ctx): + # 3' acceptor: polypyrimidine tract then AG — mark the AG + out = np.zeros(len(ctx.dna), bool) + for m in re.finditer(r"[CT]{6}[ACGT]?AG", ctx.dna): + out[m.end() - 2 : m.end()] = True + return _dna_mask(ctx, out) + + +# --------------------------------------------------------------- sequence / norm level +@labeler("is_prok") +def _prok(ctx): + return np.full(ctx.T, ctx.kingdom == "prok", dtype=bool) + + +@labeler("is_sink_token", complex=True) +def _sink(ctx): + return ctx.hidden_norm > SINK_NORM_THRESHOLD + + +# --------------------------------------------- gene structure (real annotation, prok) +# These read a CDS annotation attached to the context by a gene caller (see +# predict_cds, prokaryotes only). They are no-ops when the annotation is absent. +@labeler("cds_coding", complex=True) +def _cds(ctx): + if ctx.cds_mask is None: + return np.zeros(ctx.T, bool) + return _dna_mask(ctx, ctx.cds_mask) + + +@labeler("cds_start", complex=True) +def _cds_start(ctx): + if ctx.gene_starts is None: + return np.zeros(ctx.T, bool) + return _dna_mask(ctx, ctx.gene_starts) + + +@labeler("cds_frame_1", complex=True) +def _cds_f1(ctx): + # codon position 1 within a REAL predicted CDS (not the frame-0-from-start proxy) + if ctx.cds_frame is None: + return np.zeros(ctx.T, bool) + return _dna_mask(ctx, ctx.cds_frame == 0) + + +@labeler("cds_frame_3", complex=True) +def _cds_f3(ctx): + if ctx.cds_frame is None: + return np.zeros(ctx.T, bool) + return _dna_mask(ctx, ctx.cds_frame == 2) + + +_GENE_FINDER = None + +# Standard genetic code (NCBI translation table 1) via Biopython; codon -> amino acid ('*' = stop). +_STD_CODE = CodonTable.unambiguous_dna_by_id[1] +CODON_TABLE = {**_STD_CODE.forward_table, **dict.fromkeys(_STD_CODE.stop_codons, "*")} +CODON_LIST = sorted(CODON_TABLE) # 64 codons +CODON_TO_IDX = {c: i for i, c in enumerate(CODON_LIST)} +AA_LIST = sorted(set(CODON_TABLE.values())) # 20 aa + '*' (stop) +AA_TO_IDX = {a: i for i, a in enumerate(AA_LIST)} + + +def _revcomp(s): + return str(Seq(s).reverse_complement()) + + +def predict_codons(dna: str): + """In-frame codon + amino-acid identity at strand-correct codon anchors (prok genes). + + Returns (codon_id[N], aa_id[N]) over forward DNA coordinates; the anchor is the + first translated base of each codon (low coord on +strand, high coord on -strand), + other positions are -1. codon_id in 0..63 (CODON_LIST), aa_id in 0..20 (AA_LIST). + """ + global _GENE_FINDER + n = len(dna) + codon_id = np.full(n, -1, dtype=np.int16) + aa_id = np.full(n, -1, dtype=np.int8) + if n < 60: + return codon_id, aa_id + if _GENE_FINDER is None: + import pyrodigal + + _GENE_FINDER = pyrodigal.GeneFinder(meta=True) + for g in _GENE_FINDER.find_genes(dna.encode("ascii", "replace")): + b, e = max(0, g.begin - 1), min(n, g.end) + sub = dna[b:e] + coding = sub if g.strand == 1 else _revcomp(sub) + for i in range(len(coding) // 3): + cod = coding[3 * i : 3 * i + 3] + j = CODON_TO_IDX.get(cod) + if j is None: + continue + p = b + 3 * i if g.strand == 1 else (e - 1 - 3 * i) + if 0 <= p < n: + codon_id[p] = j + aa_id[p] = AA_TO_IDX[CODON_TABLE[cod]] + return codon_id, aa_id + + +def predict_cds(dna: str): + """Prokaryotic gene calling via pyrodigal (meta mode) on a single DNA chunk. + + Returns (cds_mask, cds_frame, gene_starts) over forward DNA coordinates: + cds_mask[i] True if position i lies within any predicted CDS (either strand) + cds_frame[i] codon position 0/1/2 relative to that gene's start (strand-aware), else -1 + gene_starts[i] True at predicted translation starts + """ + global _GENE_FINDER + n = len(dna) + cds_mask = np.zeros(n, dtype=bool) + cds_frame = np.full(n, -1, dtype=np.int8) + gene_starts = np.zeros(n, dtype=bool) + if n < 60: + return cds_mask, cds_frame, gene_starts + if _GENE_FINDER is None: + import pyrodigal + + _GENE_FINDER = pyrodigal.GeneFinder(meta=True) + for g in _GENE_FINDER.find_genes(dna.encode("ascii", "replace")): + b, e = g.begin - 1, g.end # 0-based half-open, forward coords + b, e = max(0, b), min(n, e) + if e <= b: + continue + cds_mask[b:e] = True + idx = np.arange(b, e) + if g.strand == 1: + gene_starts[b] = True + cds_frame[b:e] = (idx - b) % 3 + else: # reverse strand: start codon sits at the (forward) end + gene_starts[e - 1] = True + cds_frame[b:e] = ((e - 1) - idx) % 3 + return cds_mask, cds_frame, gene_starts + + +# Default label set for the probe (order preserved in outputs). +DEFAULT_LABELS = list(LABELERS.keys()) diff --git a/interpretability/sparse_autoencoders/recipes/evo2/tests/test_annot_tracks.py b/interpretability/sparse_autoencoders/recipes/evo2/tests/test_annot_tracks.py new file mode 100644 index 0000000000..e927c9620d --- /dev/null +++ b/interpretability/sparse_autoencoders/recipes/evo2/tests/test_annot_tracks.py @@ -0,0 +1,76 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""CPU unit tests for the generic interval-track loader (no model / no torch-CUDA).""" + +import sys +from pathlib import Path + + +sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "scripts")) + +from annot_tracks import label_windows, load_track, read_fasta_dict + + +def test_read_fasta_dict_uses_first_token(tmp_path): + """seq_id is the first header token (so it matches BED/GFF chrom).""" + fa = tmp_path / "g.fa" + fa.write_text(">chr1 Homo sapiens\nACGT\nACGT\n>chr2\nTTTT\n") + assert read_fasta_dict(fa) == {"chr1": "ACGTACGT", "chr2": "TTTT"} + + +def test_load_bed_is_half_open(tmp_path): + """BED is 0-based half-open and used as-is.""" + bed = tmp_path / "t.bed" + bed.write_text("chr1\t2\t5\tsiteA\nchr1\t10\t12\tsiteB\n") + assert load_track(str(bed)) == {"chr1": [(2, 5), (10, 12)]} + + +def test_load_gff_converts_to_half_open_and_filters_type(tmp_path): + """GFF 1-based inclusive -> 0-based half-open; feature_type filters column 3.""" + gff = tmp_path / "t.gff3" + gff.write_text( + "# comment\n" + "chr1\tsrc\texon\t3\t5\t.\t+\t.\tID=e1\n" # 1-based [3,5] -> [2,5) + "chr1\tsrc\tCDS\t3\t5\t.\t+\t.\tID=c1\n" + ) + assert load_track(str(gff), feature_type="exon") == {"chr1": [(2, 5)]} + + +def test_label_windows_mask_and_instance_ids(tmp_path): + """Each interval is one instance; mask + instance id line up with the window positions.""" + seqs = {"chr1": "ACGT" * 25} # 100 bp (above the 60 bp window floor) + tracks = {"site": {"chr1": [(2, 5), (10, 12)]}} # two instances + windows, stats = label_windows(seqs, tracks, seq_len=100) + assert len(windows) == 1 + w = windows[0] + mask, inst = w["labels"]["site"], w["instances"]["site"] + assert list(mask.nonzero()[0]) == [2, 3, 4, 10, 11] + assert set(inst[mask].tolist()) == {0, 1} # two distinct instances + assert (inst[~mask] == -1).all() + assert stats["n_inst"]["site"] == 2 + + +def test_instance_id_stable_across_split_windows(): + """An interval spanning a window boundary keeps ONE global instance id (recall counts it once).""" + seqs = {"chr1": "A" * 200} + tracks = {"big": {"chr1": [(90, 110)]}} # straddles the 0-100 / 100-200 boundary + windows, stats = label_windows(seqs, tracks, seq_len=100) + ids = set() + for w in windows: + inst = w["instances"]["big"] + ids.update(int(x) for x in inst[inst >= 0]) + assert ids == {0} # same id in both windows + assert stats["n_inst"]["big"] == 1 diff --git a/interpretability/sparse_autoencoders/recipes/evo2/tests/test_labelers.py b/interpretability/sparse_autoencoders/recipes/evo2/tests/test_labelers.py new file mode 100644 index 0000000000..e99cf236b2 --- /dev/null +++ b/interpretability/sparse_autoencoders/recipes/evo2/tests/test_labelers.py @@ -0,0 +1,53 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""CPU tests for the per-token labelers (pure masks, no model).""" + +import sys +from pathlib import Path + +import numpy as np + + +sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "scripts")) + +from labelers import LABELERS, SeqContext + + +def _ctx(dna, tag_len=0): + text = "X" * tag_len + dna + return SeqContext(text=text, tag_len=tag_len, dna=dna, kingdom="prok", hidden_norm=np.zeros(tag_len + len(dna))) + + +def test_consensus_motifs_fire_at_match_positions(): + """The table-driven motifs mark the right positions (ATG/stop = start, TATA = span).""" + ctx = _ctx("ATGTAACGT") # ATG @0 ; TAA (stop) @3 + assert list(LABELERS["motif_ATG"](ctx).nonzero()[0]) == [0] + assert list(LABELERS["motif_stop"](ctx).nonzero()[0]) == [3] + assert list(LABELERS["motif_TATA"](_ctx("TATAAA")).nonzero()[0]) == [0, 1, 2, 3, 4, 5] # spans the match + + +def test_base_labelers_fire_per_nucleotide(): + """base_A/C/G/T each fire exactly on their nucleotide.""" + ctx = _ctx("ACGTAA") + assert list(LABELERS["base_A"](ctx).nonzero()[0]) == [0, 4, 5] + assert list(LABELERS["base_G"](ctx).nonzero()[0]) == [2] + + +def test_tag_prefix_is_unlabeled(): + """Sequence-derived labels are False over the leading phylo-tag tokens.""" + ctx = _ctx("ATG", tag_len=2) # tokens: [tag, tag, A, T, G] + m = LABELERS["motif_ATG"](ctx) + assert len(m) == 5 and not m[:2].any() and m[2] # ATG starts at DNA pos 0 -> token 2 From 6566f2a071af73a8056130e05a12f93b61f9c9fe Mon Sep 17 00:00:00 2001 From: Polina Binder Date: Tue, 23 Jun 2026 06:36:32 +0000 Subject: [PATCH 2/6] =?UTF-8?q?fix(evo2-sae):=20address=20#1630=20review?= =?UTF-8?q?=20=E2=80=94=20deps,=20single-chrom=20guard,=20label=20correctn?= =?UTF-8?q?ess,=20tests?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - declare pyrodigal (predict_cds/predict_codons import it lazily; only biopython was declared, so the gene-calling path would ImportError in a clean env) - euk_windows: reject a multi-record FASTA (parse_gff drops seqid + we concatenate, so a multi-chrom GFF would silently mislabel against a blob) and document the single-chromosome assumption - skip windows overlapping a *second* gene's span (central-gene approx mislabels a neighbor's exons as intergenic) and de-dup near-duplicate adjacent-exon windows so they don't eat the token budget - extract the CDS reverse-strand frame math into a unit-tested helper (_frame_and_start) - tests: euk_windows had none — add parse_gff/1-based->0-based labeling/single-chrom guard; cover the previously-untested labelers (gc/homopolymer/dinuc/kozak/splice) + the +/- strand frame anchor #1630 CPU tests: 16 passed. Co-Authored-By: Claude Opus 4.8 Signed-off-by: Polina Binder --- .../recipes/evo2/pyproject.toml | 1 + .../recipes/evo2/scripts/euk_windows.py | 36 +++++++- .../recipes/evo2/scripts/labelers.py | 24 ++++-- .../recipes/evo2/tests/test_euk_windows.py | 82 +++++++++++++++++++ .../recipes/evo2/tests/test_labelers.py | 38 +++++++++ 5 files changed, 170 insertions(+), 11 deletions(-) create mode 100644 interpretability/sparse_autoencoders/recipes/evo2/tests/test_euk_windows.py diff --git a/interpretability/sparse_autoencoders/recipes/evo2/pyproject.toml b/interpretability/sparse_autoencoders/recipes/evo2/pyproject.toml index ba794c5e74..047e3e69f0 100644 --- a/interpretability/sparse_autoencoders/recipes/evo2/pyproject.toml +++ b/interpretability/sparse_autoencoders/recipes/evo2/pyproject.toml @@ -14,6 +14,7 @@ dependencies = [ "numpy>=1.20", "pyarrow>=23.0.0", "biopython>=1.80", # genetic code / translation in labelers.py + "pyrodigal>=3.0", # prokaryotic gene calling in labelers.predict_cds / predict_codons ] # No package code lives here yet — the recipe is just an entry-point for diff --git a/interpretability/sparse_autoencoders/recipes/evo2/scripts/euk_windows.py b/interpretability/sparse_autoencoders/recipes/evo2/scripts/euk_windows.py index 0f381df84b..f656428b9e 100644 --- a/interpretability/sparse_autoencoders/recipes/evo2/scripts/euk_windows.py +++ b/interpretability/sparse_autoencoders/recipes/evo2/scripts/euk_windows.py @@ -28,6 +28,11 @@ `python euk_windows.py --fasta chr21.fa --gff chr21.gff3 --dry-run` prints coverage stats without building sequences. + +Single chromosome only: the FASTA is read as one concatenated sequence and ``parse_gff`` +drops the seqid column, so a multi-record FASTA + multi-chromosome GFF would apply +per-chromosome coordinates against a concatenated blob (silently wrong labels). ``build_windows`` +asserts a single FASTA record; run one chromosome at a time. """ from __future__ import annotations @@ -96,7 +101,12 @@ def representative_tx(gene): def _label_window(chrom, w0, w1, gm, N): - """Label a window [w0,w1) using one gene model's intervals (central-gene approx).""" + """Label a window [w0,w1) using one gene model's intervals (central-gene approx). + + Caveat: labels derive from a single gene model — every position outside this gene's span is + marked intergenic, so a *second* gene overlapping the window would have its exons mislabeled. + build_windows skips windows overlapping another gene's span to avoid injecting that bias. + """ L = w1 - w0 pos = np.arange(w0, w1) lab = {k: np.zeros(L, bool) for k in ("exon", "intron", "cds", "utr", "intergenic")} @@ -122,11 +132,20 @@ def _label_window(chrom, w0, w1, gm, N): def build_windows( # noqa: D103 fasta, gff, seq_len=1024, max_tokens=300_000, flank=300, seed=0, intergenic_frac=0.12, dry_run=False ): - seqs = [] + seqs, n_records = [], 0 with open(fasta) as fh: for line in fh: - if not line.startswith(">"): + if line.startswith(">"): + n_records += 1 + else: seqs.append(line.strip()) + if n_records != 1: + # parse_gff keeps no seqid, and we concatenate all records into one `chrom`; a multi-record + # FASTA + multi-chromosome GFF would silently apply per-chromosome coordinates to the blob. + raise ValueError( + f"euk_windows expects a single-chromosome FASTA matching the GFF, got {n_records} " + f"records. Run one chromosome at a time (e.g. chr21.fa + chr21.gff3)." + ) chrom = "".join(seqs).upper() N = len(chrom) genes = parse_gff(gff) @@ -173,7 +192,7 @@ def build_windows( # noqa: D103 # exon-centered windows sampled across ALL genes' exons (diverse + exon/intron balanced) exon_refs = [(gi, ei) for gi, gm in enumerate(gene_models) for ei in range(len(gm["exons"]))] rng.shuffle(exon_refs) - windows, tot = [], 0 + windows, tot, accepted = [], 0, [] budget_genic = int(max_tokens * (1 - intergenic_frac)) for gi, ei in exon_refs: if tot >= budget_genic: @@ -185,10 +204,19 @@ def build_windows( # noqa: D103 w1 = min(N, w0 + seq_len) if w1 - w0 < 60: continue + # _label_window labels from one gene model, so a second gene overlapping the window would + # have its exons mislabeled intergenic — skip those windows to keep eval labels trustworthy. + if any(j != gi and w0 < ge and w1 > gs - 1 for j, (gs, ge) in enumerate(gene_spans)): + continue + # adjacent exons of one gene center on nearly the same span; drop near-duplicate windows so + # heavily-overlapping (correlated) positions don't quietly eat the token budget. + if any(min(w1, aw1) - max(w0, aw0) > seq_len // 2 for aw0, aw1 in accepted): + continue win = _label_window(chrom, w0, w1, gm, N) if win["dna"].count("N") > 0.5 * len(win["dna"]): continue windows.append(win) + accepted.append((w0, w1)) tot += w1 - w0 # intergenic windows: random spots clear of any gene span (+flank) spans = sorted(gene_spans) diff --git a/interpretability/sparse_autoencoders/recipes/evo2/scripts/labelers.py b/interpretability/sparse_autoencoders/recipes/evo2/scripts/labelers.py index 3243742d75..04b4c0d53e 100644 --- a/interpretability/sparse_autoencoders/recipes/evo2/scripts/labelers.py +++ b/interpretability/sparse_autoencoders/recipes/evo2/scripts/labelers.py @@ -328,6 +328,20 @@ def predict_codons(dna: str): return codon_id, aa_id +def _frame_and_start(b: int, e: int, strand: int): + """Codon-frame array (0/1/2 along [b, e)) + the translation-start position for one CDS. + + Frame is relative to the strand-correct start: the start is the low coord on the + strand + (frame counts up from ``b``) and the high coord on the - strand (frame counts back from + ``e-1``). Pure arithmetic, split out from predict_cds so the reverse-strand anchor (an easy + off-by-one) is unit-tested directly. + """ + idx = np.arange(b, e) + if strand == 1: + return ((idx - b) % 3).astype(np.int8), b + return (((e - 1) - idx) % 3).astype(np.int8), e - 1 + + def predict_cds(dna: str): """Prokaryotic gene calling via pyrodigal (meta mode) on a single DNA chunk. @@ -353,13 +367,9 @@ def predict_cds(dna: str): if e <= b: continue cds_mask[b:e] = True - idx = np.arange(b, e) - if g.strand == 1: - gene_starts[b] = True - cds_frame[b:e] = (idx - b) % 3 - else: # reverse strand: start codon sits at the (forward) end - gene_starts[e - 1] = True - cds_frame[b:e] = ((e - 1) - idx) % 3 + frame, start = _frame_and_start(b, e, g.strand) + cds_frame[b:e] = frame + gene_starts[start] = True return cds_mask, cds_frame, gene_starts diff --git a/interpretability/sparse_autoencoders/recipes/evo2/tests/test_euk_windows.py b/interpretability/sparse_autoencoders/recipes/evo2/tests/test_euk_windows.py new file mode 100644 index 0000000000..5296fa21e5 --- /dev/null +++ b/interpretability/sparse_autoencoders/recipes/evo2/tests/test_euk_windows.py @@ -0,0 +1,82 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""CPU tests for euk_windows: GFF parsing, 1-based->0-based labeling, single-chromosome guard.""" + +import sys +from pathlib import Path + +import pytest + + +sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "scripts")) + +from euk_windows import _label_window, build_windows, parse_gff, representative_tx + + +_GFF = ( + "##gff-version 3\n" + "1\tt\tgene\t11\t40\t.\t+\t.\tID=gene:g1;biotype=protein_coding\n" + "1\tt\tmRNA\t11\t40\t.\t+\t.\tID=transcript:t1;Parent=gene:g1;biotype=protein_coding\n" + "1\tt\texon\t11\t20\t.\t+\t.\tParent=transcript:t1\n" + "1\tt\texon\t31\t40\t.\t+\t.\tParent=transcript:t1\n" + "1\tt\tCDS\t11\t20\t.\t+\t.\tParent=transcript:t1\n" + # a non-coding gene that must be filtered out + "1\tt\tgene\t100\t120\t.\t+\t.\tID=gene:g2;biotype=lncRNA\n" + "1\tt\tmRNA\t100\t120\t.\t+\t.\tID=transcript:t2;Parent=gene:g2;biotype=lncRNA\n" + "1\tt\texon\t100\t120\t.\t+\t.\tParent=transcript:t2\n" +) + + +def test_parse_gff_keeps_only_protein_coding(tmp_path): + """parse_gff returns only protein_coding genes, with sorted exon/CDS coords (1-based, inclusive).""" + gff = tmp_path / "g.gff3" + gff.write_text(_GFF) + genes = parse_gff(str(gff)) + assert set(genes) == {"g1"} # the lncRNA gene is filtered out + _, tx = representative_tx(genes["g1"]) + assert tx["exon"] == [(11, 20), (31, 40)] and tx["cds"] == [(11, 20)] + + +def test_label_window_converts_1based_inclusive_to_0based_halfopen(): + """A GFF interval s..e (1-based inclusive) labels 0-based positions s-1..e-1; intergenic/UTR + fall out correctly.""" + chrom = "A" * 50 + gm = { + "exons": [(11, 20)], + "introns": [(21, 25)], + "cds": [(11, 15)], + "gi": 0, + "exon_ids": [0], + "intron_ids": [0], + "span": (11, 25), + } + lab = _label_window(chrom, 0, 50, gm, 50)["labels"] + assert list(lab["exon"].nonzero()[0]) == list(range(10, 20)) # 11..20 -> 10..19 + assert list(lab["intron"].nonzero()[0]) == list(range(20, 25)) # 21..25 -> 20..24 + assert list(lab["cds"].nonzero()[0]) == list(range(10, 15)) # 11..15 -> 10..14 + assert list(lab["utr"].nonzero()[0]) == list(range(15, 20)) # exon minus cds + assert lab["intergenic"][:10].all() and lab["intergenic"][25:].all() # outside the gene span + assert not lab["intergenic"][10:25].any() + + +def test_build_windows_rejects_multi_record_fasta(tmp_path): + """A multi-record FASTA is rejected (coords would be applied against a concatenated blob).""" + fasta = tmp_path / "f.fa" + fasta.write_text(">chr1\nACGTACGT\n>chr2\nACGTACGT\n") + gff = tmp_path / "g.gff3" + gff.write_text("##gff-version 3\n") + with pytest.raises(ValueError, match="single-chromosome"): + build_windows(str(fasta), str(gff)) diff --git a/interpretability/sparse_autoencoders/recipes/evo2/tests/test_labelers.py b/interpretability/sparse_autoencoders/recipes/evo2/tests/test_labelers.py index e99cf236b2..54daaabf65 100644 --- a/interpretability/sparse_autoencoders/recipes/evo2/tests/test_labelers.py +++ b/interpretability/sparse_autoencoders/recipes/evo2/tests/test_labelers.py @@ -51,3 +51,41 @@ def test_tag_prefix_is_unlabeled(): ctx = _ctx("ATG", tag_len=2) # tokens: [tag, tag, A, T, G] m = LABELERS["motif_ATG"](ctx) assert len(m) == 5 and not m[:2].any() and m[2] # ATG starts at DNA pos 0 -> token 2 + + +def test_gc_window_labelers(): + """gc_high/low fire on GC-rich / AT-rich windows (mean GC over the ±10 window).""" + assert LABELERS["gc_high_window"](_ctx("G" * 30)).all() # all-GC -> high everywhere + assert not LABELERS["gc_high_window"](_ctx("A" * 30)).any() + assert LABELERS["gc_low_window"](_ctx("A" * 30)).all() # all-AT -> low everywhere + + +def test_homopolymer_window(): + """homopolymer_window marks runs of >=5 identical bases (and nothing shorter).""" + m = LABELERS["homopolymer_window"](_ctx("TTAAAAACC")) # AAAAA (5) at positions 2..6 + assert list(m.nonzero()[0]) == [2, 3, 4, 5, 6] + + +def test_dinuc_repeat_window(): + """dinuc_repeat_window marks alternating dinucleotide repeats (>=3 reps, span >=6).""" + assert LABELERS["dinuc_repeat_window"](_ctx("ATATATAT")).all() # (AT)x4 + assert not LABELERS["dinuc_repeat_window"](_ctx("ATAT")).any() # (AT)x2 -> too short + + +def test_consensus_offsets_kozak_and_splice(): + """The consensus labelers mark the biologically meaningful offset within the match.""" + assert list(LABELERS["kozak_atg"](_ctx("AAAATGG")).nonzero()[0]) == [3] # [AG]..ATGG -> the ATG + assert list(LABELERS["splice_donor"](_ctx("CCGTAAGTCC")).nonzero()[0]) == [2] # GT[AG]AGT -> the GT + assert list(LABELERS["splice_acceptor"](_ctx("TTTTTTAG")).nonzero()[0]) == [6, 7] # ...AG -> the AG + + +def test_frame_and_start_forward_and_reverse(): + """The CDS frame helper anchors frame 0 at the strand-correct start (the reverse-strand + off-by-one): start = low coord on +strand, high coord on -strand, counting back.""" + from labelers import _frame_and_start + + frame, start = _frame_and_start(10, 19, strand=1) # 9 bp = 3 codons, forward + assert start == 10 and list(frame) == [0, 1, 2, 0, 1, 2, 0, 1, 2] + + frame, start = _frame_and_start(10, 19, strand=-1) # reverse: start at the high coord + assert start == 18 and list(frame) == [2, 1, 0, 2, 1, 0, 2, 1, 0] From 926966b3a78ecbb616df7ea12b045a075b825b84 Mon Sep 17 00:00:00 2001 From: Polina Binder Date: Tue, 23 Jun 2026 17:34:17 +0000 Subject: [PATCH 3/6] ci(evo2-sae): add a CPU lane for the model-agnostic recipe tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit A new ubuntu-latest workflow installs sae + the recipe (CPU torch) and runs the recipe's model-agnostic tests (-m 'not slow') — the label producers (#1630), eval metrics, etc. — so they run cheaply on the probing-stack branches instead of waiting for #1622's megatron GPU lane (which would run them on an L4 after a full build). Registers the 'slow' marker on the recipe pyproject so the GPU tests are excluded without an unknown-marker warning. Validated: pytest tests/ -m 'not slow' -> 16 passed (CPU). Co-Authored-By: Claude Opus 4.8 Signed-off-by: Polina Binder --- .../workflows/unit-tests-evo2-recipe-cpu.yaml | 53 +++++++++++++++++++ .../recipes/evo2/pyproject.toml | 7 +++ 2 files changed, 60 insertions(+) create mode 100644 .github/workflows/unit-tests-evo2-recipe-cpu.yaml diff --git a/.github/workflows/unit-tests-evo2-recipe-cpu.yaml b/.github/workflows/unit-tests-evo2-recipe-cpu.yaml new file mode 100644 index 0000000000..3f3a3ff8d1 --- /dev/null +++ b/.github/workflows/unit-tests-evo2-recipe-cpu.yaml @@ -0,0 +1,53 @@ +name: "BioNeMo Evo2 SAE Recipe CI (CPU)" + +# CPU unit tests for the model-agnostic parts of the evo2 SAE recipe (DNA labelers, window +# builders, eval metrics) — no model and no GPU, just CPU torch + numpy + biopython on +# ubuntu-latest. The model-loading GPU tests (@pytest.mark.slow) run in the dedicated megatron +# lane (unit-tests-interpretability-recipes.yaml); here they're excluded via -m "not slow". + +on: + push: + branches: + - "pull-request/[0-9]+" + - "dependabot/**" + paths: + - "interpretability/sparse_autoencoders/recipes/evo2/**" + - "interpretability/sparse_autoencoders/sae/**" + - ".github/workflows/unit-tests-evo2-recipe-cpu.yaml" + merge_group: + types: [checks_requested] + schedule: + - cron: "0 9 * * *" # Runs at 9 AM UTC daily (2 AM MST) + +defaults: + run: + shell: bash -x -e -u -o pipefail {0} + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + +jobs: + recipe-cpu-tests: + runs-on: ubuntu-latest + name: "evo2-recipe-tests (cpu)" + steps: + - name: Checkout repository + uses: actions/checkout@v4 + with: + sparse-checkout: interpretability/sparse_autoencoders + sparse-checkout-cone-mode: false + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.12" + + - name: Install (CPU torch + sae + recipe) + working-directory: interpretability/sparse_autoencoders + run: | + pip install --extra-index-url https://download.pytorch.org/whl/cpu -e sae -e recipes/evo2 + + - name: Run model-agnostic recipe tests + working-directory: interpretability/sparse_autoencoders/recipes/evo2 + run: pytest -v tests/ -m "not slow" diff --git a/interpretability/sparse_autoencoders/recipes/evo2/pyproject.toml b/interpretability/sparse_autoencoders/recipes/evo2/pyproject.toml index 047e3e69f0..d7d988b20d 100644 --- a/interpretability/sparse_autoencoders/recipes/evo2/pyproject.toml +++ b/interpretability/sparse_autoencoders/recipes/evo2/pyproject.toml @@ -25,3 +25,10 @@ packages = [] [tool.uv.sources] sae = { workspace = true } + +# Register markers locally so the recipe is self-contained (the repo-root pytest.ini isn't in the +# CI sparse-checkout); lets the CPU lane select `-m "not slow"` without an unknown-marker warning. +[tool.pytest.ini_options] +markers = [ + "slow: GPU/integration tests that load a model (skip without CUDA + checkpoints)", +] From 8276e57d9dac5d713450a8bca29157623c0f34fa Mon Sep 17 00:00:00 2001 From: Polina Binder Date: Tue, 23 Jun 2026 06:19:08 +0000 Subject: [PATCH 4/6] evo2 SAE eval (2/2): probing harness + CLI, rebased onto migrated #1630 Re-lands #1636 on the post-#1633 layout, on top of rebased #1630: the harness/CLI (scripts/{evo2_buffer,probe,probe_loss_recovered}.py) that runs the model to build an ActivationBuffer (#1629) from #1630's labels and emits the probing metrics. Syntax-checked; the GPU extract->score smoke is a follow-up (no unit tests in this PR yet). Co-Authored-By: Claude Opus 4.8 Signed-off-by: Polina Binder --- .../recipes/evo2/scripts/evo2_buffer.py | 125 ++++++ .../recipes/evo2/scripts/probe.py | 417 ++++++++++++++++++ .../evo2/scripts/probe_loss_recovered.py | 152 +++++++ 3 files changed, 694 insertions(+) create mode 100644 interpretability/sparse_autoencoders/recipes/evo2/scripts/evo2_buffer.py create mode 100644 interpretability/sparse_autoencoders/recipes/evo2/scripts/probe.py create mode 100644 interpretability/sparse_autoencoders/recipes/evo2/scripts/probe_loss_recovered.py diff --git a/interpretability/sparse_autoencoders/recipes/evo2/scripts/evo2_buffer.py b/interpretability/sparse_autoencoders/recipes/evo2/scripts/evo2_buffer.py new file mode 100644 index 0000000000..8a0d68fbeb --- /dev/null +++ b/interpretability/sparse_autoencoders/recipes/evo2/scripts/evo2_buffer.py @@ -0,0 +1,125 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Evo2-specific bit: turn DNA sequences into a probing ActivationBuffer. + +The only model-touching code in the probing pipeline. Streams sequences through +the Evo2SAE engine (Evo2 -> layer-L residual -> SAE.encode), keeps the dense +residual twin, and computes per-token labels (+ instance IDs) from labelers.py. +All scoring is done elsewhere by the model-agnostic sae.eval.probing metrics. +""" + +from __future__ import annotations + +import random + +import labelers as L +import numpy as np +import torch +from evo2_sae.fasta import read_fasta # shared reader (also handles .gz); kingdom check uses the id prefix +from sae.eval.probing import ActivationBuffer + + +KINGDOM_TAGS = {"prok": "|d__Bacteria|", "euk": "|d__Eukaryota|"} + + +def sample_sequences(fasta, max_tokens, seq_len, kingdoms=("prok", "euk"), seed=0): # noqa: D103 + from evo2_sae.core import clean_dna + + kingdoms = list(kingdoms) + pools = {k: [] for k in kingdoms} + need = max_tokens // seq_len + 50 + for header, seq in read_fasta(fasta): + kg = "prok" if header.lower().startswith("prok") else "euk" + if kg not in pools: + continue + dna = clean_dna(seq)[:seq_len] + if len(dna) < 60: + continue + pools[kg].append(dna) + if all(len(pools[k]) >= need for k in kingdoms): + break + rng = random.Random(seed) + for k in kingdoms: + rng.shuffle(pools[k]) + out, tok, i = [], 0, 0 + maxlen = max((len(pools[k]) for k in kingdoms), default=0) + while tok < max_tokens and i < maxlen: + for k in kingdoms: + if i < len(pools[k]): + out.append((k, pools[k][i])) + tok += len(pools[k][i]) + len(KINGDOM_TAGS[k]) + i += 1 + rng.shuffle(out) + return out + + +@torch.no_grad() +def build_buffer(engine, seqs, label_names, *, subsample, auroc_device, annotate_cds=False, batch_size=8, log=print): + """Stream seqs through engine -> ActivationBuffer (codes + dense + labels [+ cds instances]).""" + F = engine.n_features + Hd = engine.sae.pre_bias.shape[0] + dev = engine.device + S = subsample + code_buf = torch.zeros(S, F, dtype=torch.float16, device=auroc_device) + dense_buf = torch.zeros(S, Hd, dtype=torch.float16, device=auroc_device) + lab_buf = torch.zeros(S, len(label_names), dtype=torch.bool, device=auroc_device) + filled = 0 + for start in range(0, len(seqs), batch_size): + if filled >= S: + break + batch = seqs[start : start + batch_size] + id_lists, metas = [], [] + for kg, dna in batch: + tag = KINGDOM_TAGS[kg] + tids = engine.tokenize(tag) + id_lists.append(tids + engine.tokenize(dna)) + metas.append((tag, len(tids), kg, dna)) + with engine._lock: + hiddens = engine._forward_hidden(id_lists) + for h, (tag, tlen, kg, dna) in zip(hiddens, metas): + if h.shape[0] == 0 or filled >= S: + continue + hd = h.to(dev) + codes = engine.sae.encode(hd) + norm = h.float().norm(dim=-1).cpu().numpy() + T = codes.shape[0] + cds_mask = cds_frame = gene_starts = None + if annotate_cds and kg == "prok": + cds_mask, cds_frame, gene_starts = L.predict_cds(dna) + ctx = L.SeqContext( + text=(tag + dna)[:T], + tag_len=tlen, + dna=dna, + kingdom=kg, + hidden_norm=norm[:T], + cds_mask=cds_mask, + cds_frame=cds_frame, + gene_starts=gene_starts, + ) + lab = np.stack([L.LABELERS[n](ctx)[:T] for n in label_names], axis=1) + take = min(T, S - filled) + code_buf[filled : filled + take] = codes[:take].to(torch.float16).to(auroc_device) + dense_buf[filled : filled + take] = hd[:take].to(torch.float16).to(auroc_device) + lab_buf[filled : filled + take] = torch.from_numpy(lab[:take]).to(auroc_device) + filled += take + if (start // batch_size) % 10 == 0: + log(f" {start + len(batch)}/{len(seqs)} seqs | buf {filled}/{S}") + return ActivationBuffer( + codes=code_buf[:filled].cpu().numpy(), + dense=dense_buf[:filled].cpu().numpy(), + labels=lab_buf[:filled].cpu().numpy(), + label_names=list(label_names), + ) diff --git a/interpretability/sparse_autoencoders/recipes/evo2/scripts/probe.py b/interpretability/sparse_autoencoders/recipes/evo2/scripts/probe.py new file mode 100644 index 0000000000..5c1b69aa95 --- /dev/null +++ b/interpretability/sparse_autoencoders/recipes/evo2/scripts/probe.py @@ -0,0 +1,417 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r"""Unified Evo2 SAE probing CLI. All scoring is sae.eval.probing (model-agnostic); +this driver only knows how to build/load Evo2 buffers and pick label sets. + + probe.py extract --out BUF [...] build an ActivationBuffer (needs the model) + probe.py auroc --acts BUF --labels .. per-feature AUROC table (prints) + probe.py annotate --acts BUF --out P assign each feature its best concept -> annotation parquet + probe.py linear --acts BUF --labels .. SAE-vs-dense single + multi (disentanglement/distributed) + probe.py codon-aa --acts CODON_BUF codon/AA decoders + family-disjoint, SAE vs dense + probe.py euk-f1 --fasta .. --gff .. RefSeq gene-structure domain-F1 (needs the model) + probe.py domain-eval --fasta .. --track .. user annotated dataset -> per-feature domain-F1 + AUROC vs + any BED/GFF tracks (RefSeq/Rfam/JASPAR/ENCODE) (needs the model) + (SAE fidelity / loss-recovered lives in the separate probe_loss_recovered.py script — see step 5 below) + +Example end-to-end flow (7B / layer 26; $CKPT = MBridge dir, $SAE = trained SAE .pt): + + # 1. Build the probing buffer once: SAE codes + dense twin + per-token labels (needs the model) + python probe.py extract --evo2-ckpt-dir $CKPT --sae-checkpoint $SAE --layer 26 \ + --fasta probe_set.fa --out buf.npz + + # 2. Score the buffer (no model): per-feature AUROC, then SAE-vs-dense linear probes + python probe.py auroc --acts buf.npz --labels motif_ATG,motif_stop,cds_coding,is_prok + python probe.py linear --acts buf.npz --labels cds_coding,is_prok + + # 3. Persist annotations (no model): each feature's best concept (incl. base_A/C/G/T) -> + # the feature-annotation parquet the engine/dashboard load via --feature-annotations + python probe.py annotate --acts buf.npz --out feature_annotations.parquet --min-auroc 0.85 + + # 4. User annotated dataset -> per-feature domain-F1 (prec/nt, recall/annotation) + AUROC, + # vs any BED/GFF tracks (RefSeq/Rfam/JASPAR/ENCODE) (needs the model) + python probe.py domain-eval --evo2-ckpt-dir $CKPT --sae-checkpoint $SAE --layer 26 \ + --fasta GRCh38_chr20.fa --track exon=refseq.gff3:exon --track cCRE=encode_ccre.bed + + # 5. SAE fidelity (loss recovered) — separate script, needs the model + python probe_loss_recovered.py --evo2-ckpt-dir $CKPT --sae-checkpoint $SAE --layer 26 --fasta probe_set.fa +""" # noqa: D205 + +from __future__ import annotations + +import argparse +import sys +from pathlib import Path + +import numpy as np +import torch + + +_HERE = Path(__file__).resolve().parent +sys.path.insert(0, str(_HERE)) +sys.path.insert(0, str(_HERE.parent)) +sys.path.insert(0, str(_HERE.parents[2] / "sae" / "src")) # sparse_autoencoders/sae/src + +import labelers as L # noqa: E402 +from sae.eval.probing import ( # noqa: E402 + ActivationBuffer, + auroc_all, + auroc_vec, + best_single_train_test, + decode_eval, + fit_softmax, + split_indices, + standardize, +) + + +def _z(X, tr): + # Standardize X by the train-split mean/std (reuses sae.eval.probing.standardize). + mu, sd = standardize(X, tr) + return (X - mu) / sd + + +def _load(a): + """Load the probing buffer + resolve requested label names (default: all labels in the buffer).""" + buf = ActivationBuffer.load(a.acts) + if getattr(a, "labels", None): + names = a.labels.split(",") + unknown = [t for t in names if t not in buf.name_idx] + if unknown: + raise SystemExit(f"unknown label(s) {unknown}; buffer has: {list(buf.label_names)}") + else: + names = list(buf.label_names) + return buf, names + + +# ───────────────────────────────────────── buffer-only subcommands (no model) +def cmd_auroc(a): # noqa: D103 + buf, names = _load(a) + dev = a.device + X = torch.from_numpy(buf.codes).to(dev).float() + Y = torch.stack([torch.from_numpy(buf.labels[:, buf.name_idx[n]]).to(dev) for n in names], 1) + au = auroc_all(X, Y).cpu().numpy() + print(f"{'label':18s} {'%pos':>6s} {'best AUROC':>10s} {'feature':>8s}") + for i, n in enumerate(names): + print( + f"{n:18s} {buf.labels[:, buf.name_idx[n]].mean():6.1%} {au[:, i].max():10.3f} {int(au[:, i].argmax()):8d}" + ) + + +def _eval_matrix(mat, buf, names, tr, te, dev, steps, wd): + X = torch.from_numpy(mat).to(dev).float() + Xz = _z(X, tr) + out = {} + from sae.eval.probing import fit_logreg + + for n in names: + ytr = torch.from_numpy(buf.labels[tr.numpy(), buf.name_idx[n]]).to(dev).float() + yte = torch.from_numpy(buf.labels[te.numpy(), buf.name_idx[n]]).to(dev) + if ytr.sum() in (0, len(ytr)) or yte.sum() in (0, len(yte)): + out[n] = (float("nan"), float("nan")) + continue + w, b = fit_logreg(Xz[tr], ytr, steps=steps, wd=wd) + out[n] = (best_single_train_test(Xz[tr], ytr, Xz[te], yte), auroc_vec((Xz[te] @ w + b).float(), yte)) + del X, Xz + torch.cuda.empty_cache() + return out + + +def cmd_linear(a): # noqa: D103 + buf, names = _load(a) + dev = a.device + tr, te = split_indices(buf.codes.shape[0], a.test_frac, a.seed) + sae = _eval_matrix(buf.codes, buf, names, tr, te, dev, a.steps, a.weight_decay) + den = _eval_matrix(buf.dense, buf, names, tr, te, dev, a.steps, a.weight_decay) if buf.dense is not None else None + h = f"{'label':18s} {'%pos':>6s} | {'SAE single':>10s} {'SAE multi':>9s}" + if den: + h += f" | {'dense single':>12s} {'dense multi':>11s} | {'Δ':>7s}" + print(h) + for n in names: + pos = buf.labels[:, buf.name_idx[n]].mean() + ss, sm = sae[n] + row = f"{n:18s} {pos:6.1%} | {ss:10.3f} {sm:9.3f}" + if den: + ds, dm = den[n] + row += f" | {ds:12.3f} {dm:11.3f} | {ss - ds:+7.3f}" + print(row) + + +def cmd_codon_aa(a): # noqa: D103 + z = np.load(a.acts) + dev = a.device + codon = torch.from_numpy(z["codon"].astype(np.int64)).to(dev) + aa = torch.from_numpy(z["aa"].astype(np.int64)).to(dev) + codon_np = z["codon"].astype(np.int64) + ncod, naa = len(L.CODON_LIST), len(L.AA_LIST) + held = {"L": ["TTA", "TTG"], "S": ["AGT", "AGC"], "R": ["AGA", "AGG"]} + hidx = [L.CODON_TO_IDX[c] for v in held.values() for c in v] + print(f"{'matrix':6s} {'codon mAUROC':>12s} {'AA mAUROC':>10s} | family-disjoint recall L/S/R (chance)") + for nm in ("sae", "dense"): + if nm not in z.files: + continue + X = torch.from_numpy(z[nm]).to(dev).float() + tr, te = split_indices(X.shape[0], a.test_frac, a.seed) + Xz = _z(X, tr) # standardize on the train split only (no test-set leakage) + _, ca, _ = decode_eval(Xz[tr], codon[tr], Xz[te], codon[te], ncod, steps=a.steps, wd=a.weight_decay) + _, aaa, _ = decode_eval(Xz[tr], aa[tr], Xz[te], aa[te], naa, steps=a.steps, wd=a.weight_decay) + trn = torch.from_numpy(np.nonzero(~np.isin(codon_np, hidx))[0]).to(dev) + W, b = fit_softmax(Xz[trn], aa[trn], naa, steps=a.steps, wd=a.weight_decay) + rec = [] + for A, cods in held.items(): + m = np.isin(codon_np, [L.CODON_TO_IDX[c] for c in cods]) + pred = (Xz[torch.from_numpy(np.nonzero(m)[0]).to(dev)] @ W + b).argmax(1).cpu().numpy() + rec.append( + f"{A}={float((pred == L.AA_TO_IDX[A]).mean()):.2f}({float((aa == L.AA_TO_IDX[A]).float().mean()):.2f})" + ) + del X, Xz + torch.cuda.empty_cache() + print(f"{nm:6s} {ca:12.3f} {aaa:10.3f} | {' '.join(rec)}") + + +def cmd_annotate(a): + """Buffer -> feature-annotation parquet: each feature's best concept by AUROC + activation stats. + + The persist step (uses sae.eval.probing.annotate_features). Writes a feature_metadata-style + parquet — {feature_id, label, auroc, activation_freq, max_activation} — the engine/dashboard + load via --feature-annotations. Concepts default to all labels in the buffer (incl. base_*). + """ + import pyarrow as pa + import pyarrow.parquet as pq + from sae.eval.probing import annotate_features + + buf, names = _load(a) + dev = a.device + X = torch.from_numpy(buf.codes).to(dev).float() + Y = torch.stack([torch.from_numpy(buf.labels[:, buf.name_idx[n]]).to(dev) for n in names], 1) + ann = annotate_features(X, Y, names, min_auroc=a.min_auroc) + cols = {"feature_id": [], "label": [], "auroc": [], "activation_freq": [], "max_activation": []} + for r in ann: + col = X[:, r["feature_id"]] + cols["feature_id"].append(r["feature_id"]) + cols["label"].append(r["label"]) + cols["auroc"].append(r["auroc"]) + cols["activation_freq"].append(round(float((col > 0).float().mean()), 6)) + cols["max_activation"].append(round(float(col.max()), 4)) + pq.write_table(pa.table(cols), a.out, compression="snappy") + print(f"[annotate] {len(ann)} features labeled (AUROC >= {a.min_auroc}) over {len(names)} concepts -> {a.out}") + + +# ───────────────────────────────────────── model subcommands (need Evo2) +def _encode_windows(eng, windows, tag_ids, lab_keys, inst_keys, tot, a): + """Stream tiled windows through the SAE -> (code_buf[filled,F], lab{k:bool}, inst{k:long}, fmax[F]). + + Shared by euk-f1 and domain-eval: encodes each window (skipping the phylo-tag prefix) and + fills per-concept label masks (lab_keys) + instance ids (inst_keys). Buffers are trimmed to + the number of positions actually filled. + """ + adev, tlen = a.auroc_device, len(tag_ids) + code_buf = torch.zeros(tot, eng.n_features, dtype=torch.float16, device=adev) + lab = {k: torch.zeros(tot, dtype=torch.bool, device=adev) for k in lab_keys} + inst = {k: torch.full((tot,), -1, dtype=torch.long, device=adev) for k in inst_keys} + filled = 0 + for s0 in range(0, len(windows), a.batch_size): + batch = windows[s0 : s0 + a.batch_size] + with eng._lock: + for h, w in zip(eng._forward_hidden([tag_ids + eng.tokenize(w["dna"]) for w in batch]), batch): + if h.shape[0] == 0: + continue + codes = eng.sae.encode(h.to(a.device)) + take = min(len(w["dna"]), codes.shape[0] - tlen, tot - filled) + if take <= 0: + continue + code_buf[filled : filled + take] = codes[tlen : tlen + take].to(torch.float16).to(adev) + for k in lab: + lab[k][filled : filled + take] = torch.from_numpy(w["labels"][k][:take]).to(adev) + for k in inst: + inst[k][filled : filled + take] = torch.from_numpy(w["instances"][k][:take].astype(np.int64)).to( + adev + ) + filled += take + code_buf = code_buf[:filled] + for d in (lab, inst): + for k in d: + d[k] = d[k][:filled] + fmax = code_buf.max(0).values.float() if filled else torch.zeros(eng.n_features, device=adev) + return code_buf, lab, inst, fmax + + +def cmd_euk(a): + """Eukaryotic exon/intron/CDS domain-adjusted F1 vs shuffle null (chr21 FASTA+GFF).""" + from euk_windows import build_windows + from evo2_sae.core import DEFAULT_ORGANISM_TAGS, Evo2SAE + from sae.eval.probing import domain_f1 + + eng = Evo2SAE(a.evo2_ckpt_dir, a.sae_checkpoint, a.layer, device=a.device).load() + windows, stats, tot, _ = build_windows(a.fasta, a.gff, a.seq_len, a.max_tokens, seed=a.seed) + print( + f"windows={len(windows)} tokens={tot} genes={stats['genes']} exons={stats['exons']} introns={stats['introns']}" + ) + tag_ids = eng.tokenize(DEFAULT_ORGANISM_TAGS.get(a.organism, "")) + code_buf, lab, inst, fmax = _encode_windows( + eng, windows, tag_ids, ("exon", "intron", "cds"), ("exon", "intron", "gene"), tot, a + ) + filled, adev = code_buf.shape[0], a.auroc_device + g = torch.Generator(device=adev).manual_seed(a.seed) + print(f"encoded {filled} positions\n{'concept':8s} {'domF1':>6s} {'null':>6s} {'ratio':>6s} {'%pos':>6s}") + for c, ic in {"exon": "exon", "intron": "intron", "cds": "gene"}.items(): + f1, _ = domain_f1(code_buf, fmax, lab[c], inst[ic]) + order = torch.randperm(filled, generator=g, device=adev) + f1n, _ = domain_f1(code_buf, fmax, lab[c][order], inst[ic][order]) + bf, nl = float(f1.max()), float(f1n.max()) + print(f"{c:8s} {bf:6.3f} {nl:6.3f} {bf / max(nl, 1e-9):6.2f} {float(lab[c].float().mean()):6.1%}") + + +def _parse_track_spec(spec): + """Parse a ``--track NAME=PATH[:GFF_FEATURE]`` spec -> (name, path, feature_type|None).""" + name, rest = spec.split("=", 1) + ftype = None + if ":" in rest: + head, tail = rest.rsplit(":", 1) + if "/" not in tail and "." not in tail: # a GFF feature type, not part of a path + rest, ftype = head, tail + return name, rest, ftype + + +def cmd_domain_eval(a): + """User-supplied annotated dataset -> per-feature domain-F1 (prec/nt, recall/annotation) + AUROC. + + Each ``--track NAME=PATH[:GFF_FEATURE]`` is one concept; its BED/GFF intervals are the + annotation instances (RefSeq/Rfam/JASPAR/ENCODE, or anything the user supplies). The SAE + annotates the windows, then per concept we report the best feature by instance-level + domain-F1 (precision-per-nt, recall-per-annotation) and — threshold-free — by AUROC. + """ + from annot_tracks import label_windows, load_track, read_fasta_dict + from evo2_sae.core import DEFAULT_ORGANISM_TAGS, Evo2SAE + from sae.eval.probing import auroc_all, domain_f1 + + tracks = {} + for spec in a.track: + name, path, ftype = _parse_track_spec(spec) + tracks[name] = load_track(path, feature_type=ftype) + seqs = read_fasta_dict(a.fasta) + windows, stats = label_windows(seqs, tracks, a.seq_len, max_tokens=a.max_tokens) + concepts = stats["concepts"] + + eng = Evo2SAE(a.evo2_ckpt_dir, a.sae_checkpoint, a.layer, device=a.device).load() + tag_ids = eng.tokenize(DEFAULT_ORGANISM_TAGS.get(a.organism, "")) + code_buf, lab, inst, fmax = _encode_windows(eng, windows, tag_ids, concepts, concepts, stats["tokens"], a) + au = auroc_all(code_buf.float().to(a.device), torch.stack([lab[c] for c in concepts], 1).to(a.device)).cpu() + print(f"encoded {code_buf.shape[0]} positions across {len(concepts)} concept(s)") + print( + f"{'concept':14s} {'%pos':>6s} {'#inst':>6s} | " + f"{'domF1':>6s} {'@thr':>5s} {'feat':>7s} | {'AUROC':>6s} {'feat':>7s}" + ) + for i, c in enumerate(concepts): + f1, thr = domain_f1(code_buf, fmax, lab[c], inst[c]) + bi, ai = int(f1.argmax()), int(au[:, i].argmax()) + print( + f"{c:14s} {float(lab[c].float().mean()):6.1%} {stats['n_inst'][c]:6d} | " + f"{float(f1[bi]):6.3f} {float(thr[bi]):5.2f} {bi:7d} | {float(au[ai, i]):6.3f} {ai:7d}" + ) + + +def cmd_extract(a): # noqa: D103 + from evo2_buffer import build_buffer, sample_sequences + from evo2_sae.core import Evo2SAE + + eng = Evo2SAE(a.evo2_ckpt_dir, a.sae_checkpoint, a.layer, device=a.device).load() + label_names = list(L.LABELERS.keys()) + kingdoms = [k for k in a.kingdoms.split(",") if k] + seqs = sample_sequences(a.fasta, a.max_tokens, a.seq_len, kingdoms=kingdoms, seed=a.seed) + print(f"probe set: {len(seqs)} seqs (kingdoms={kingdoms})") + buf = build_buffer( + eng, + seqs, + label_names, + subsample=a.subsample, + auroc_device=a.auroc_device, + annotate_cds=a.annotate_cds, + batch_size=a.batch_size, + log=print, + ) + buf.save(a.out) + print(f"saved buffer -> {a.out} ({buf.codes.shape[0]} x {buf.codes.shape[1]}, dense {buf.dense.shape[1]})") + + +def _add_model_args(p, *, required=(), max_tokens=160_000): + """Shared model + encoding args for the model-backed subcommands (extract/euk-f1/domain-eval).""" + for arg in ("--evo2-ckpt-dir", "--sae-checkpoint", "--fasta", *required): + p.add_argument(arg, required=True) + p.add_argument("--layer", type=int, required=True) + p.add_argument("--max-tokens", type=int, default=max_tokens) + p.add_argument("--seq-len", type=int, default=1024) + p.add_argument("--batch-size", type=int, default=8) + p.add_argument("--auroc-device", default=None, help="device for the AUROC matrix; defaults to --device") + + +def main(): # noqa: D103 + ap = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter) + sub = ap.add_subparsers(dest="cmd", required=True) + common = argparse.ArgumentParser(add_help=False) + common.add_argument("--device", default="cuda:0") + common.add_argument("--seed", type=int, default=0) + common.add_argument("--steps", type=int, default=400) + common.add_argument("--weight-decay", type=float, default=1e-2) + common.add_argument("--test-frac", type=float, default=0.4) + for name, fn, needs_labels in [ + ("auroc", cmd_auroc, True), + ("linear", cmd_linear, True), + ("codon-aa", cmd_codon_aa, False), + ]: + p = sub.add_parser(name, parents=[common]) + p.add_argument("--acts", required=True) + if needs_labels: + p.add_argument("--labels", required=True) + p.set_defaults(func=fn) + pan = sub.add_parser("annotate", parents=[common]) + pan.add_argument("--acts", required=True) + pan.add_argument("--out", required=True) + pan.add_argument( + "--labels", default=None, help="comma-separated concept subset; default = all labels in the buffer" + ) + pan.add_argument("--min-auroc", type=float, default=0.8) + pan.set_defaults(func=cmd_annotate) + pe = sub.add_parser("extract", parents=[common]) + _add_model_args(pe, required=("--out",), max_tokens=200_000) + pe.add_argument("--kingdoms", default="prok,euk") + pe.add_argument("--annotate-cds", action="store_true") + pe.add_argument("--subsample", type=int, default=50_000) + pe.set_defaults(func=cmd_extract) + pk = sub.add_parser("euk-f1", parents=[common]) + _add_model_args(pk, required=("--gff",)) + pk.add_argument("--organism", default="Human") + pk.set_defaults(func=cmd_euk) + pd = sub.add_parser("domain-eval", parents=[common]) + _add_model_args(pd) + pd.add_argument( + "--track", + action="append", + required=True, + metavar="NAME=PATH[:GFF_FEATURE]", + help="annotation track; BED or GFF intervals = instances of concept NAME. Repeatable " + "(e.g. --track exon=refseq.gff3:exon --track tfbs=jaspar.bed --track cCRE=encode.bed).", + ) + pd.add_argument("--organism", default="Human") + pd.set_defaults(func=cmd_domain_eval) + args = ap.parse_args() + if getattr(args, "auroc_device", None) is None: # default the AUROC matrix to the model device + args.auroc_device = getattr(args, "device", "cuda:0") + torch.set_grad_enabled(False) + args.func(args) + + +if __name__ == "__main__": + main() diff --git a/interpretability/sparse_autoencoders/recipes/evo2/scripts/probe_loss_recovered.py b/interpretability/sparse_autoencoders/recipes/evo2/scripts/probe_loss_recovered.py new file mode 100644 index 0000000000..723a1ec3f5 --- /dev/null +++ b/interpretability/sparse_autoencoders/recipes/evo2/scripts/probe_loss_recovered.py @@ -0,0 +1,152 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r"""Loss recovered (fidelity) for the Evo2 SAE — reuses sae.eval.loss_recovered (Jared Wilber). + + loss_recovered = 1 - (CE_sae - CE_clean) / (CE_zero - CE_clean) + +We just provide Evo2-specific callables to his generic evaluator: + - get_hiddens(batch): capture the layer-`L` residual via a forward hook + - compute_ce(batch, override): full-model next-token CE, optionally patching the + layer-`L` output with `override` (zero-ablation or SAE reconstruction) +The SAE reconstruction is DENORMALIZED per token (normalize_input) so it is in the +raw residual space the layer actually emits. +""" + +from __future__ import annotations + +import argparse +import sys +from pathlib import Path + +import torch +import torch.nn as nn +import torch.nn.functional as Fn + + +_HERE = Path(__file__).resolve().parent +sys.path.insert(0, str(_HERE)) +sys.path.insert(0, str(_HERE.parent)) + +from evo2_buffer import sample_sequences # noqa: E402 +from evo2_sae.core import Evo2SAE # noqa: E402 +from sae.eval.loss_recovered import evaluate_loss_recovered # noqa: E402 (Jared's code) + + +KINGDOM_TAGS = {"prok": "|d__Bacteria|", "euk": "|d__Eukaryota|"} + + +class SAEWrap(nn.Module): + """sae.forward(x[N,H]) -> (recon, codes) in RAW residual space (denormalized).""" + + def __init__(self, sae): # noqa: D107 + super().__init__() + self.sae = sae + + def forward(self, x): # noqa: D102 + s = self.sae + codes = s.encode(x) # encode normalizes internally if normalize_input + recon = s.decoder(codes) + s.pre_bias + if getattr(s, "normalize_input", False): + mu = x.mean(-1, keepdim=True) + std = x.std(-1, keepdim=True) + 1e-8 + recon = recon * std + mu + return recon, codes + + +class L26Hook: # noqa: D101 + def __init__(self): # noqa: D107 + self.mode = "off" # off | capture | replace + self.override = None + self.captured = None + + def __call__(self, module, inp, output): # noqa: D102 + hs = output[0] if isinstance(output, tuple) else output + if self.mode == "replace" and self.override is not None: + new = self.override.to(hs.dtype) + return (new, *output[1:]) if isinstance(output, tuple) else new + if self.mode == "capture": + self.captured = hs.detach() + return output + + +def main(): # noqa: D103 + ap = argparse.ArgumentParser() + ap.add_argument("--evo2-ckpt-dir", required=True) + ap.add_argument("--sae-checkpoint", required=True) + ap.add_argument("--layer", type=int, required=True) + ap.add_argument("--fasta", required=True) + ap.add_argument("--n-seqs", type=int, default=80) + ap.add_argument("--seq-len", type=int, default=1024) + ap.add_argument("--device", default="cuda:0") + ap.add_argument("--seed", type=int, default=0) + args = ap.parse_args() + torch.set_grad_enabled(False) + dev = args.device + + engine = Evo2SAE(args.evo2_ckpt_dir, args.sae_checkpoint, args.layer, device=dev).load() + from megatron.core.utils import unwrap_model + + gen = engine._ensure_gen_model() + layer = unwrap_model(gen).decoder.layers[args.layer] + hook = L26Hook() + layer.register_forward_hook(hook) + + pairs = sample_sequences( + args.fasta, args.n_seqs * args.seq_len, args.seq_len, kingdoms=["prok", "euk"], seed=args.seed + )[: args.n_seqs] + batches = [] + for kingdom, dna in pairs: + ids = engine.tokenize(KINGDOM_TAGS[kingdom] + dna) + if len(ids) > 4: + batches.append(torch.tensor([ids], dtype=torch.long, device=dev)) + if not batches: + raise SystemExit(f"no evaluable sequences (>4 tokens) from {args.fasta}; check the FASTA / --seq-len") + + def fwd(ids): + return gen(input_ids=ids, position_ids=None, attention_mask=None, labels=None, runtime_gather_output=True) + + def get_hiddens(batch): + hook.mode = "capture" + fwd(batch) + hook.mode = "off" + return hook.captured # [S, 1, H] + + def compute_ce(batch, override): + if override is None: + hook.mode = "off" + else: + hook.mode = "replace" + hook.override = override + logits = fwd(batch) + hook.mode = "off" + hook.override = None + lg = logits[0, :-1].float() # [S-1, V] + tgt = batch[0, 1:] + ce = Fn.cross_entropy(lg, tgt, reduction="sum") + return float(ce), int(tgt.numel()) + + with engine._lock: + res = evaluate_loss_recovered(SAEWrap(engine.sae), batches, get_hiddens, compute_ce, device=dev) + print("\n==== Evo2 7B layer-%d SAE — loss recovered ====" % args.layer) + print(res) + print( + f"loss_recovered = {res.loss_recovered:.3f} " + f"(CE clean={res.ce_original:.3f}, SAE={res.ce_sae:.3f}, zero={res.ce_zero:.3f}, n_tok={res.n_tokens})" + ) + + +if __name__ == "__main__": + main() From 7f960def342402df1d5179451f657e3bbc54aa18 Mon Sep 17 00:00:00 2001 From: root Date: Tue, 23 Jun 2026 18:53:58 +0000 Subject: [PATCH 5/6] test(evo2-sae): CPU integration tests for the probe harness + fix loss-recovered engine call Adds end-to-end CPU tests that cross the eval-layer seams the per-layer unit tests miss: buffer save/load -> probe.main() (auroc/annotate/linear), and annot_tracks.label_windows -> domain_f1. Guards the dense-twin round trip (the SAE-vs-dense `linear` comparison only renders if dense survives save->load) and verifies the stale-base harness still runs against #1630's current labelers. Also fixes probe_loss_recovered.py to call the real engine API (Evo2SAE._ensure_engine().model; the previous _ensure_gen_model() does not exist) and adds tests/conftest.py to centralize the scripts/ sys.path insertion. Co-Authored-By: Claude Opus 4.8 (1M context) Signed-off-by: Polina Binder --- .../evo2/scripts/probe_loss_recovered.py | 5 +- .../recipes/evo2/tests/conftest.py | 29 +++ .../evo2/tests/test_probe_integration.py | 174 ++++++++++++++++++ 3 files changed, 207 insertions(+), 1 deletion(-) create mode 100644 interpretability/sparse_autoencoders/recipes/evo2/tests/conftest.py create mode 100644 interpretability/sparse_autoencoders/recipes/evo2/tests/test_probe_integration.py diff --git a/interpretability/sparse_autoencoders/recipes/evo2/scripts/probe_loss_recovered.py b/interpretability/sparse_autoencoders/recipes/evo2/scripts/probe_loss_recovered.py index 723a1ec3f5..9563cabca0 100644 --- a/interpretability/sparse_autoencoders/recipes/evo2/scripts/probe_loss_recovered.py +++ b/interpretability/sparse_autoencoders/recipes/evo2/scripts/probe_loss_recovered.py @@ -99,7 +99,10 @@ def main(): # noqa: D103 engine = Evo2SAE(args.evo2_ckpt_dir, args.sae_checkpoint, args.layer, device=dev).load() from megatron.core.utils import unwrap_model - gen = engine._ensure_gen_model() + # The engine builds the generation model lazily inside its inference components + # (Evo2SAE._ensure_engine() -> gen_components; the megatron model is .model). We call it + # directly below for teacher-forced CE, and hook its layer-`args.layer` output. + gen = engine._ensure_engine().model layer = unwrap_model(gen).decoder.layers[args.layer] hook = L26Hook() layer.register_forward_hook(hook) diff --git a/interpretability/sparse_autoencoders/recipes/evo2/tests/conftest.py b/interpretability/sparse_autoencoders/recipes/evo2/tests/conftest.py new file mode 100644 index 0000000000..f4f15b619e --- /dev/null +++ b/interpretability/sparse_autoencoders/recipes/evo2/tests/conftest.py @@ -0,0 +1,29 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Shared test setup: put the recipe's ``scripts/`` on ``sys.path``. + +The eval modules (``labelers``, ``annot_tracks``, ``euk_windows``, ``probe`` …) live in +``scripts/`` and are run as standalone scripts, not an installed package, so tests import +them by name. Centralizing the path insertion here keeps it out of every test module. +""" + +import sys +from pathlib import Path + + +_SCRIPTS = Path(__file__).resolve().parent.parent / "scripts" +if str(_SCRIPTS) not in sys.path: + sys.path.insert(0, str(_SCRIPTS)) diff --git a/interpretability/sparse_autoencoders/recipes/evo2/tests/test_probe_integration.py b/interpretability/sparse_autoencoders/recipes/evo2/tests/test_probe_integration.py new file mode 100644 index 0000000000..2b9a23ecd2 --- /dev/null +++ b/interpretability/sparse_autoencoders/recipes/evo2/tests/test_probe_integration.py @@ -0,0 +1,174 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""End-to-end CPU integration tests for the probing harness (no model). + +These cover the *seams* between the eval layers that per-layer unit tests miss: + + * ``sae.eval.probing`` (#1629) <-> the probe CLI (#1636): a buffer is written by + ``ActivationBuffer.save``, reloaded, and scored through the real ``probe.main()`` + dispatch (``auroc`` / ``annotate`` / ``linear``). This is the path that silently + broke when ``save`` dropped the dense twin — the ``linear`` SAE-vs-dense comparison + only renders if ``dense`` survives the round trip. + * ``annot_tracks.label_windows`` (#1630) <-> ``domain_f1`` (#1629): interval tracks -> + per-token mask + instance ids -> instance-level F1, with no model in the loop. + +A feature is *planted* to track one concept so the metrics have a known right answer; +everything else is noise, so a green run means the whole pipeline carried the signal. + +``scripts/`` is on ``sys.path`` via ``conftest.py``. +""" + +import sys + +import numpy as np +import probe +import pyarrow.parquet as pq +from annot_tracks import label_windows +from sae.eval.probing import ActivationBuffer, domain_f1 + + +# Feature index deliberately planted to track the "planted" concept. +_PLANTED_FEATURE = 3 + + +def _planted_buffer(path, n=400, n_features=6, hidden=8, seed=0): + """Write a synthetic buffer where ``_PLANTED_FEATURE`` cleanly separates the "planted" label. + + Includes a dense twin + per-concept instance ids so the save/load round trip and the + SAE-vs-dense ``linear`` path are both exercised. Returns the in-memory buffer too. + """ + rng = np.random.default_rng(seed) + label_names = ["planted", "noise"] + labels = np.zeros((n, len(label_names)), dtype=bool) + labels[:, 0] = rng.random(n) < 0.4 # planted positives (~40%) + labels[:, 1] = rng.random(n) < 0.4 # uncorrelated noise concept + + codes = (rng.random((n, n_features)) * 0.1).astype(np.float16) # background noise + pos = labels[:, 0] + codes[pos, _PLANTED_FEATURE] = (rng.random(int(pos.sum())) * 5 + 5).astype(np.float16) # strong, positives only + + dense = (rng.random((n, hidden))).astype(np.float16) + # one instance per run of 5 positive tokens; -1 outside the concept + inst = np.full(n, -1, np.int32) + inst[pos] = (np.cumsum(pos)[pos] - 1) // 5 + instances = {"planted": inst} + + buf = ActivationBuffer(codes=codes, labels=labels, label_names=label_names, dense=dense, instances=instances) + buf.save(str(path)) + return buf + + +def _run_cli(monkeypatch, *argv): + """Drive the real probe.py CLI (arg-parse -> dispatch) the way a user would.""" + monkeypatch.setattr(sys, "argv", ["probe.py", *argv]) + probe.main() + + +# --------------------------------------------------------------- probing <-> CLI buffer seam +def test_buffer_roundtrip_preserves_dense_and_instances(tmp_path): + """save() -> load() must carry the dense twin + instance ids (regression: they were dropped).""" + buf = _planted_buffer(tmp_path / "buf.npz") + lo = ActivationBuffer.load(str(tmp_path / "buf.npz")) + assert np.array_equal(lo.codes, buf.codes) and np.array_equal(lo.labels, buf.labels) + assert lo.dense is not None and np.array_equal(lo.dense, buf.dense) + assert lo.instances is not None and np.array_equal(lo.instances["planted"], buf.instances["planted"]) + + +def test_annotate_cli_labels_planted_feature(tmp_path, monkeypatch, capsys): + """``annotate`` over a reloaded buffer writes a parquet that labels the planted feature.""" + _planted_buffer(tmp_path / "buf.npz") + out = tmp_path / "feature_annotations.parquet" + _run_cli( + monkeypatch, + "annotate", + "--acts", + str(tmp_path / "buf.npz"), + "--out", + str(out), + "--min-auroc", + "0.85", + "--device", + "cpu", + ) + tbl = pq.read_table(out).to_pydict() + assert set(tbl) == {"feature_id", "label", "auroc", "activation_freq", "max_activation"} + rows = {fid: (lab, au) for fid, lab, au in zip(tbl["feature_id"], tbl["label"], tbl["auroc"])} + assert _PLANTED_FEATURE in rows, "planted feature not annotated" + label, auroc = rows[_PLANTED_FEATURE] + assert label == "planted" and auroc >= 0.85 + + +def test_auroc_cli_recovers_planted_feature(tmp_path, monkeypatch, capsys): + """``auroc`` ranks the planted feature ~1.0 for the planted concept (parsed from the table).""" + _planted_buffer(tmp_path / "buf.npz") + _run_cli(monkeypatch, "auroc", "--acts", str(tmp_path / "buf.npz"), "--labels", "planted,noise", "--device", "cpu") + line = next(ln for ln in capsys.readouterr().out.splitlines() if ln.startswith("planted")) + _, _pct, best_auroc, feature = line.split() + assert float(best_auroc) >= 0.95 and int(feature) == _PLANTED_FEATURE + + +def test_linear_cli_emits_dense_comparison(tmp_path, monkeypatch, capsys): + """``linear`` prints the dense columns — proof the dense twin survived save->load into the probe.""" + _planted_buffer(tmp_path / "buf.npz") + _run_cli( + monkeypatch, + "linear", + "--acts", + str(tmp_path / "buf.npz"), + "--labels", + "planted", + "--device", + "cpu", + "--steps", + "100", + ) + out = capsys.readouterr().out + assert "dense single" in out and "Δ" in out # the SAE-vs-dense comparison only renders if dense loaded + planted = next(ln for ln in out.splitlines() if ln.startswith("planted")) + sae_single = float(planted.split("|")[1].split()[0]) + assert sae_single >= 0.95 # the planted feature separates the concept under the linear probe too + + +# --------------------------------------------------- labels (#1630) <-> domain_f1 (#1629) seam +def test_label_windows_feed_domain_f1(tmp_path): + """annot_tracks windows (mask + instance ids) drive instance-level domain_f1 end to end. + + A feature planted to fire exactly on the concept mask must beat a shuffled-label null. + """ + import torch + + seqs = {"chr1": "ACGT" * 300} # 1200 bp + tracks = {"site": {"chr1": [(20, 90), (300, 380), (700, 760)]}} # three instances + windows, stats = label_windows(seqs, tracks, seq_len=200) + assert windows and stats["n_inst"]["site"] == 3 + + mask = np.concatenate([w["labels"]["site"] for w in windows]) + inst = np.concatenate([w["instances"]["site"] for w in windows]) + n = mask.shape[0] + rng = np.random.default_rng(0) + codes = (rng.random((n, 4)) * 0.1).astype(np.float32) + codes[mask, 0] = 5.0 # feature 0 fires on the concept, nowhere else + + codes_t = torch.from_numpy(codes) + fmax = codes_t.max(0).values + mask_t = torch.from_numpy(mask) + inst_t = torch.from_numpy(inst.astype(np.int64)) + f1, _thr = domain_f1(codes_t, fmax, mask_t, inst_t) + + order = torch.randperm(n, generator=torch.Generator().manual_seed(0)) + f1_null, _ = domain_f1(codes_t, fmax, mask_t[order], inst_t[order]) + assert int(f1.argmax()) == 0 # the planted feature wins + assert float(f1.max()) > 2 * float(f1_null.max()) + 1e-6 # and clears the shuffled null From b145999e03c22825482333aec8fe2ea4f4b4f5a6 Mon Sep 17 00:00:00 2001 From: root Date: Tue, 23 Jun 2026 19:06:58 +0000 Subject: [PATCH 6/6] refactor(evo2-sae): dedup engine encode loop, drop dead codon-aa, reuse SAE forward MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Extract evo2_buffer.forward_codes(engine, id_lists) — the one place that touches the engine internals (locked GPU forward + SAE encode). build_buffer and probe._encode_windows both use it, so the #1622 engine-API coupling lives in a single spot, and the per-token label/buffer work moves out of the GPU lock. Add a CPU unit test (fake engine) for the helper's contract. - Hoist KINGDOM_TAGS to evo2_buffer (was duplicated in probe_loss_recovered). - Remove the `codon-aa` subcommand: it consumed a codon/aa npz no command produces (and was the only raw np.load); drop it and its now-unused decode_eval/fit_softmax imports until a producer exists. - SAEWrap delegates to the SAE's own forward() (top-k + normalize_input denormalization) instead of hand-rolling decoder(codes)+pre_bias and mean/std — the path the steering hook uses, so the loss-recovered recon can't drift from the SAE's actual (de)normalization. - Make evo2_buffer importable without the evo2_sae engine (lazy read_fasta), so the CPU tests exercise forward_codes and the harness imports cleanly. Co-Authored-By: Claude Opus 4.8 (1M context) Signed-off-by: Polina Binder --- .../recipes/evo2/scripts/evo2_buffer.py | 72 +++++++++++-------- .../recipes/evo2/scripts/probe.py | 66 ++++------------- .../evo2/scripts/probe_loss_recovered.py | 21 +++--- .../evo2/tests/test_probe_integration.py | 32 ++++++++- 4 files changed, 95 insertions(+), 96 deletions(-) diff --git a/interpretability/sparse_autoencoders/recipes/evo2/scripts/evo2_buffer.py b/interpretability/sparse_autoencoders/recipes/evo2/scripts/evo2_buffer.py index 8a0d68fbeb..e7ee24c650 100644 --- a/interpretability/sparse_autoencoders/recipes/evo2/scripts/evo2_buffer.py +++ b/interpretability/sparse_autoencoders/recipes/evo2/scripts/evo2_buffer.py @@ -28,15 +28,30 @@ import labelers as L import numpy as np import torch -from evo2_sae.fasta import read_fasta # shared reader (also handles .gz); kingdom check uses the id prefix from sae.eval.probing import ActivationBuffer KINGDOM_TAGS = {"prok": "|d__Bacteria|", "euk": "|d__Eukaryota|"} +@torch.no_grad() +def forward_codes(engine, id_lists): + """Run token-id lists through the engine -> ``[(hidden[S,H], sae_codes[S,F])]`` on engine.device. + + The single place the probing harness touches the engine internals — the GPU forward plus the + SAE encode, serialized by the engine lock. ``build_buffer`` (here) and ``probe._encode_windows`` + both go through this, so the coupling to the engine API lives in exactly one spot. + """ + dev = engine.device + with engine._lock: + return [(h, engine.sae.encode(h.to(dev))) for h in engine._forward_hidden(id_lists)] + + def sample_sequences(fasta, max_tokens, seq_len, kingdoms=("prok", "euk"), seed=0): # noqa: D103 + # Imported here (not at module top) so this module is usable without the evo2_sae engine + # installed — e.g. the CPU tests that exercise forward_codes / labels. from evo2_sae.core import clean_dna + from evo2_sae.fasta import read_fasta # shared reader; also handles .gz kingdoms = list(kingdoms) pools = {k: [] for k in kingdoms} @@ -87,34 +102,33 @@ def build_buffer(engine, seqs, label_names, *, subsample, auroc_device, annotate tids = engine.tokenize(tag) id_lists.append(tids + engine.tokenize(dna)) metas.append((tag, len(tids), kg, dna)) - with engine._lock: - hiddens = engine._forward_hidden(id_lists) - for h, (tag, tlen, kg, dna) in zip(hiddens, metas): - if h.shape[0] == 0 or filled >= S: - continue - hd = h.to(dev) - codes = engine.sae.encode(hd) - norm = h.float().norm(dim=-1).cpu().numpy() - T = codes.shape[0] - cds_mask = cds_frame = gene_starts = None - if annotate_cds and kg == "prok": - cds_mask, cds_frame, gene_starts = L.predict_cds(dna) - ctx = L.SeqContext( - text=(tag + dna)[:T], - tag_len=tlen, - dna=dna, - kingdom=kg, - hidden_norm=norm[:T], - cds_mask=cds_mask, - cds_frame=cds_frame, - gene_starts=gene_starts, - ) - lab = np.stack([L.LABELERS[n](ctx)[:T] for n in label_names], axis=1) - take = min(T, S - filled) - code_buf[filled : filled + take] = codes[:take].to(torch.float16).to(auroc_device) - dense_buf[filled : filled + take] = hd[:take].to(torch.float16).to(auroc_device) - lab_buf[filled : filled + take] = torch.from_numpy(lab[:take]).to(auroc_device) - filled += take + # GPU forward + SAE encode happen inside forward_codes (engine-locked); the per-token + # label computation + buffer fills below are CPU/copy work and need no lock. + for (h, codes), (tag, tlen, kg, dna) in zip(forward_codes(engine, id_lists), metas): + if h.shape[0] == 0 or filled >= S: + continue + hd = h.to(dev) + norm = h.float().norm(dim=-1).cpu().numpy() + T = codes.shape[0] + cds_mask = cds_frame = gene_starts = None + if annotate_cds and kg == "prok": + cds_mask, cds_frame, gene_starts = L.predict_cds(dna) + ctx = L.SeqContext( + text=(tag + dna)[:T], + tag_len=tlen, + dna=dna, + kingdom=kg, + hidden_norm=norm[:T], + cds_mask=cds_mask, + cds_frame=cds_frame, + gene_starts=gene_starts, + ) + lab = np.stack([L.LABELERS[n](ctx)[:T] for n in label_names], axis=1) + take = min(T, S - filled) + code_buf[filled : filled + take] = codes[:take].to(torch.float16).to(auroc_device) + dense_buf[filled : filled + take] = hd[:take].to(torch.float16).to(auroc_device) + lab_buf[filled : filled + take] = torch.from_numpy(lab[:take]).to(auroc_device) + filled += take if (start // batch_size) % 10 == 0: log(f" {start + len(batch)}/{len(seqs)} seqs | buf {filled}/{S}") return ActivationBuffer( diff --git a/interpretability/sparse_autoencoders/recipes/evo2/scripts/probe.py b/interpretability/sparse_autoencoders/recipes/evo2/scripts/probe.py index 5c1b69aa95..564c743d84 100644 --- a/interpretability/sparse_autoencoders/recipes/evo2/scripts/probe.py +++ b/interpretability/sparse_autoencoders/recipes/evo2/scripts/probe.py @@ -20,7 +20,6 @@ probe.py auroc --acts BUF --labels .. per-feature AUROC table (prints) probe.py annotate --acts BUF --out P assign each feature its best concept -> annotation parquet probe.py linear --acts BUF --labels .. SAE-vs-dense single + multi (disentanglement/distributed) - probe.py codon-aa --acts CODON_BUF codon/AA decoders + family-disjoint, SAE vs dense probe.py euk-f1 --fasta .. --gff .. RefSeq gene-structure domain-F1 (needs the model) probe.py domain-eval --fasta .. --track .. user annotated dataset -> per-feature domain-F1 + AUROC vs any BED/GFF tracks (RefSeq/Rfam/JASPAR/ENCODE) (needs the model) @@ -65,13 +64,12 @@ sys.path.insert(0, str(_HERE.parents[2] / "sae" / "src")) # sparse_autoencoders/sae/src import labelers as L # noqa: E402 +from evo2_buffer import forward_codes # noqa: E402 (engine-free import; only the call needs a model) from sae.eval.probing import ( # noqa: E402 ActivationBuffer, auroc_all, auroc_vec, best_single_train_test, - decode_eval, - fit_softmax, split_indices, standardize, ) @@ -149,38 +147,6 @@ def cmd_linear(a): # noqa: D103 print(row) -def cmd_codon_aa(a): # noqa: D103 - z = np.load(a.acts) - dev = a.device - codon = torch.from_numpy(z["codon"].astype(np.int64)).to(dev) - aa = torch.from_numpy(z["aa"].astype(np.int64)).to(dev) - codon_np = z["codon"].astype(np.int64) - ncod, naa = len(L.CODON_LIST), len(L.AA_LIST) - held = {"L": ["TTA", "TTG"], "S": ["AGT", "AGC"], "R": ["AGA", "AGG"]} - hidx = [L.CODON_TO_IDX[c] for v in held.values() for c in v] - print(f"{'matrix':6s} {'codon mAUROC':>12s} {'AA mAUROC':>10s} | family-disjoint recall L/S/R (chance)") - for nm in ("sae", "dense"): - if nm not in z.files: - continue - X = torch.from_numpy(z[nm]).to(dev).float() - tr, te = split_indices(X.shape[0], a.test_frac, a.seed) - Xz = _z(X, tr) # standardize on the train split only (no test-set leakage) - _, ca, _ = decode_eval(Xz[tr], codon[tr], Xz[te], codon[te], ncod, steps=a.steps, wd=a.weight_decay) - _, aaa, _ = decode_eval(Xz[tr], aa[tr], Xz[te], aa[te], naa, steps=a.steps, wd=a.weight_decay) - trn = torch.from_numpy(np.nonzero(~np.isin(codon_np, hidx))[0]).to(dev) - W, b = fit_softmax(Xz[trn], aa[trn], naa, steps=a.steps, wd=a.weight_decay) - rec = [] - for A, cods in held.items(): - m = np.isin(codon_np, [L.CODON_TO_IDX[c] for c in cods]) - pred = (Xz[torch.from_numpy(np.nonzero(m)[0]).to(dev)] @ W + b).argmax(1).cpu().numpy() - rec.append( - f"{A}={float((pred == L.AA_TO_IDX[A]).mean()):.2f}({float((aa == L.AA_TO_IDX[A]).float().mean()):.2f})" - ) - del X, Xz - torch.cuda.empty_cache() - print(f"{nm:6s} {ca:12.3f} {aaa:10.3f} | {' '.join(rec)}") - - def cmd_annotate(a): """Buffer -> feature-annotation parquet: each feature's best concept by AUROC + activation stats. @@ -224,22 +190,19 @@ def _encode_windows(eng, windows, tag_ids, lab_keys, inst_keys, tot, a): filled = 0 for s0 in range(0, len(windows), a.batch_size): batch = windows[s0 : s0 + a.batch_size] - with eng._lock: - for h, w in zip(eng._forward_hidden([tag_ids + eng.tokenize(w["dna"]) for w in batch]), batch): - if h.shape[0] == 0: - continue - codes = eng.sae.encode(h.to(a.device)) - take = min(len(w["dna"]), codes.shape[0] - tlen, tot - filled) - if take <= 0: - continue - code_buf[filled : filled + take] = codes[tlen : tlen + take].to(torch.float16).to(adev) - for k in lab: - lab[k][filled : filled + take] = torch.from_numpy(w["labels"][k][:take]).to(adev) - for k in inst: - inst[k][filled : filled + take] = torch.from_numpy(w["instances"][k][:take].astype(np.int64)).to( - adev - ) - filled += take + id_lists = [tag_ids + eng.tokenize(w["dna"]) for w in batch] + for (h, codes), w in zip(forward_codes(eng, id_lists), batch): + if h.shape[0] == 0: + continue + take = min(len(w["dna"]), codes.shape[0] - tlen, tot - filled) + if take <= 0: + continue + code_buf[filled : filled + take] = codes[tlen : tlen + take].to(torch.float16).to(adev) + for k in lab: + lab[k][filled : filled + take] = torch.from_numpy(w["labels"][k][:take]).to(adev) + for k in inst: + inst[k][filled : filled + take] = torch.from_numpy(w["instances"][k][:take].astype(np.int64)).to(adev) + filled += take code_buf = code_buf[:filled] for d in (lab, inst): for k in d: @@ -369,7 +332,6 @@ def main(): # noqa: D103 for name, fn, needs_labels in [ ("auroc", cmd_auroc, True), ("linear", cmd_linear, True), - ("codon-aa", cmd_codon_aa, False), ]: p = sub.add_parser(name, parents=[common]) p.add_argument("--acts", required=True) diff --git a/interpretability/sparse_autoencoders/recipes/evo2/scripts/probe_loss_recovered.py b/interpretability/sparse_autoencoders/recipes/evo2/scripts/probe_loss_recovered.py index 9563cabca0..f4b6cfd087 100644 --- a/interpretability/sparse_autoencoders/recipes/evo2/scripts/probe_loss_recovered.py +++ b/interpretability/sparse_autoencoders/recipes/evo2/scripts/probe_loss_recovered.py @@ -40,30 +40,25 @@ sys.path.insert(0, str(_HERE)) sys.path.insert(0, str(_HERE.parent)) -from evo2_buffer import sample_sequences # noqa: E402 +from evo2_buffer import KINGDOM_TAGS, sample_sequences # noqa: E402 from evo2_sae.core import Evo2SAE # noqa: E402 from sae.eval.loss_recovered import evaluate_loss_recovered # noqa: E402 (Jared's code) -KINGDOM_TAGS = {"prok": "|d__Bacteria|", "euk": "|d__Eukaryota|"} - - class SAEWrap(nn.Module): - """sae.forward(x[N,H]) -> (recon, codes) in RAW residual space (denormalized).""" + """Wrap the SAE so ``forward(x[N,H]) -> (recon, codes)`` with recon on the raw residual scale. + + Delegates to the SAE's own ``forward``, which applies top-k and denormalizes to the input + scale when ``normalize_input`` is set — the same code path the steering hook uses. Computing + the reconstruction here by hand risks drifting from the SAE's actual (de)normalization. + """ def __init__(self, sae): # noqa: D107 super().__init__() self.sae = sae def forward(self, x): # noqa: D102 - s = self.sae - codes = s.encode(x) # encode normalizes internally if normalize_input - recon = s.decoder(codes) + s.pre_bias - if getattr(s, "normalize_input", False): - mu = x.mean(-1, keepdim=True) - std = x.std(-1, keepdim=True) + 1e-8 - recon = recon * std + mu - return recon, codes + return self.sae(x) class L26Hook: # noqa: D101 diff --git a/interpretability/sparse_autoencoders/recipes/evo2/tests/test_probe_integration.py b/interpretability/sparse_autoencoders/recipes/evo2/tests/test_probe_integration.py index 2b9a23ecd2..50267b5eaa 100644 --- a/interpretability/sparse_autoencoders/recipes/evo2/tests/test_probe_integration.py +++ b/interpretability/sparse_autoencoders/recipes/evo2/tests/test_probe_integration.py @@ -31,12 +31,16 @@ ``scripts/`` is on ``sys.path`` via ``conftest.py``. """ +import contextlib import sys import numpy as np import probe import pyarrow.parquet as pq +import torch from annot_tracks import label_windows +from evo2_buffer import forward_codes +from sae.architectures import TopKSAE from sae.eval.probing import ActivationBuffer, domain_f1 @@ -142,14 +146,38 @@ def test_linear_cli_emits_dense_comparison(tmp_path, monkeypatch, capsys): assert sae_single >= 0.95 # the planted feature separates the concept under the linear probe too +# ----------------------------------------------- shared engine->codes helper (evo2_buffer) +class _FakeEngine: + """Minimal stand-in for the Evo2SAE engine: a real SAE + random hidden states, no model.""" + + def __init__(self, sae): + self.device = "cpu" + self.sae = sae + self._lock = contextlib.nullcontext() # no GPU to serialize in the CPU test + + def _forward_hidden(self, id_lists): + h = self.sae.pre_bias.shape[0] + return [torch.randn(len(ids), h) for ids in id_lists] + + +def test_forward_codes_pairs_hidden_with_sae_codes(): + """forward_codes returns (hidden, codes) per input, codes == the SAE's own encode of hidden.""" + sae = TopKSAE(input_dim=8, hidden_dim=16, top_k=4, normalize_input=False) + eng = _FakeEngine(sae) + id_lists = [[1, 2, 3], [4, 5]] + out = forward_codes(eng, id_lists) + assert len(out) == len(id_lists) + for (h, codes), ids in zip(out, id_lists): + assert h.shape == (len(ids), 8) and codes.shape == (len(ids), 16) + assert torch.allclose(codes, sae.encode(h)) + + # --------------------------------------------------- labels (#1630) <-> domain_f1 (#1629) seam def test_label_windows_feed_domain_f1(tmp_path): """annot_tracks windows (mask + instance ids) drive instance-level domain_f1 end to end. A feature planted to fire exactly on the concept mask must beat a shuffled-label null. """ - import torch - seqs = {"chr1": "ACGT" * 300} # 1200 bp tracks = {"site": {"chr1": [(20, 90), (300, 380), (700, 760)]}} # three instances windows, stats = label_windows(seqs, tracks, seq_len=200)