From 3557958bef78f2a7c31f6dde0198bccd6c6e4769 Mon Sep 17 00:00:00 2001 From: Polina Binder Date: Tue, 23 Jun 2026 06:17:52 +0000 Subject: [PATCH 1/3] evo2 SAE eval (1/2): DNA label producers, rebased onto migrated #1629 Re-lands #1630 on the post-#1633 layout, on top of the rebased #1629: the DNA label producers (scripts/{labelers,annot_tracks,euk_windows}.py) that emit per-token concept labels (genes/exons/ motifs) to fill #1629's ActivationBuffer, + biopython dep (genetic code in labelers.py). Validated: tests/{test_labelers,test_annot_tracks}.py -> 8 passed (CPU). Co-Authored-By: Claude Opus 4.8 Signed-off-by: Polina Binder --- .../recipes/evo2/pyproject.toml | 1 + .../recipes/evo2/scripts/annot_tracks.py | 146 +++++++ .../recipes/evo2/scripts/euk_windows.py | 250 ++++++++++++ .../recipes/evo2/scripts/labelers.py | 367 ++++++++++++++++++ .../recipes/evo2/tests/test_annot_tracks.py | 76 ++++ .../recipes/evo2/tests/test_labelers.py | 53 +++ 6 files changed, 893 insertions(+) create mode 100644 interpretability/sparse_autoencoders/recipes/evo2/scripts/annot_tracks.py create mode 100644 interpretability/sparse_autoencoders/recipes/evo2/scripts/euk_windows.py create mode 100644 interpretability/sparse_autoencoders/recipes/evo2/scripts/labelers.py create mode 100644 interpretability/sparse_autoencoders/recipes/evo2/tests/test_annot_tracks.py create mode 100644 interpretability/sparse_autoencoders/recipes/evo2/tests/test_labelers.py diff --git a/interpretability/sparse_autoencoders/recipes/evo2/pyproject.toml b/interpretability/sparse_autoencoders/recipes/evo2/pyproject.toml index f7b23d8eff..ba794c5e74 100644 --- a/interpretability/sparse_autoencoders/recipes/evo2/pyproject.toml +++ b/interpretability/sparse_autoencoders/recipes/evo2/pyproject.toml @@ -13,6 +13,7 @@ dependencies = [ "torch>=2.0", "numpy>=1.20", "pyarrow>=23.0.0", + "biopython>=1.80", # genetic code / translation in labelers.py ] # No package code lives here yet — the recipe is just an entry-point for 
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-Apache2
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""Generic interval-track loader for the "user-supplied annotated dataset" eval.

The user hands in an annotated dataset: a FASTA of sequences + one or more annotation
tracks (BED or GFF) naming intervals — RefSeq genes/exons, Rfam ncRNA, JASPAR TFBS,
ENCODE cCREs, etc. Each interval is one annotation **instance**. This module tiles the
sequences into windows and produces, per concept, a per-token boolean mask + per-token
**global** instance IDs (stable across the windows an interval spans) — exactly the
inputs `sae.eval.probing.domain_f1` (recall-per-instance) and `auroc_all` (per-feature)
consume. No model here; the SAE-encode step lives in the probe CLI (`probe.py domain-eval`).

This is the generic sibling of `euk_windows.py` (which decomposes RefSeq gene models into
exon/intron/cds). Both feed the same shared scorers.
"""

from __future__ import annotations

import gzip
from collections import defaultdict


# File suffixes (after stripping a trailing ``.gz``) that select the GFF/GTF column layout.
_GFF_SUFFIXES = (".gff", ".gff3", ".gtf")


def _open(path):
    """Open a path for text reading, transparently handling a trailing ``.gz`` (any letter case)."""
    opener = gzip.open if str(path).lower().endswith(".gz") else open
    return opener(path, "rt")


def read_fasta_dict(path: str) -> dict[str, str]:
    """Read a (multi-record) FASTA into ``{seq_id: sequence}`` (``.gz`` transparent).

    ``seq_id`` is the first whitespace token of the header — matches the chrom/seqid of BED/GFF.
    A header with no token gets the synthetic id ``seq_<number of records stored so far>``.
    Sequence text appearing before the first header is ignored, and a repeated ``seq_id``
    replaces the earlier record.
    """
    seqs: dict[str, str] = {}
    name, parts = None, []
    with _open(path) as fh:
        for line in fh:
            line = line.rstrip()
            if line.startswith(">"):
                if name is not None:
                    seqs[name] = "".join(parts)
                tok = line[1:].split()
                name, parts = (tok[0] if tok else f"seq_{len(seqs)}"), []
            elif line and name is not None:
                parts.append(line)
    if name is not None:
        seqs[name] = "".join(parts)
    return seqs


def _intervals(path, fmt, feature_type=None):
    """Yield (seqid, start0, end0) from BED (0-based) or GFF/GTF (1-based -> 0-based half-open).

    Any ``fmt`` other than ``bed`` (case-insensitive) is read with the GFF column layout.
    GFF rows are optionally filtered to a single column-3 ``feature_type`` (e.g. ``exon``).
    Blank lines, ``#`` comments, ``track``/``browser`` lines and rows with too few
    tab-separated columns are skipped.

    Raises:
        ValueError: if a kept row has non-integer start/end columns (message names file and line).
    """
    is_bed = str(fmt).lower() == "bed"
    chrom_i, start_i, end_i, off = (0, 1, 2, 0) if is_bed else (0, 3, 4, 1)
    with _open(path) as fh:
        for lineno, raw in enumerate(fh, 1):
            line = raw.rstrip("\r\n")
            if not line.strip() or line[0] == "#" or line.startswith(("track", "browser")):
                continue
            f = line.split("\t")
            if len(f) <= end_i or (feature_type and not is_bed and f[2] != feature_type):
                continue
            try:
                start, end = int(f[start_i]), int(f[end_i])
            except ValueError as err:
                raise ValueError(
                    f"{path}:{lineno}: non-integer interval coordinates {f[start_i]!r}, {f[end_i]!r}"
                ) from err
            # GFF start is 1-based inclusive -> subtract 1; the inclusive end equals the half-open end.
            yield f[chrom_i], start - off, end


def load_track(path, feature_type=None, fmt=None):
    """Load one annotation track into ``{seqid: [(start0, end0), ...]}`` (0-based half-open, sorted).

    ``fmt`` (``bed``/``gff``) is inferred from the extension when not given: a trailing ``.gz``
    is stripped, then ``.gff``/``.gff3``/``.gtf`` (any letter case) selects GFF and everything
    else BED. ``feature_type`` filters GFF column 3. Every interval is one annotation instance;
    empty or inverted intervals (``end <= start``) are dropped.
    """
    if not fmt:
        name = str(path).lower()
        if name.endswith(".gz"):
            # Strip only the trailing compression suffix, never a ".gz" elsewhere in the path.
            name = name[: -len(".gz")]
        fmt = "gff" if name.endswith(_GFF_SUFFIXES) else "bed"
    by_seq = defaultdict(list)
    for chrom, s, e in _intervals(path, fmt, feature_type):
        if e > s:
            by_seq[chrom].append((s, e))
    return {k: sorted(v) for k, v in by_seq.items()}


def label_windows(
    seqs, tracks, seq_len=1024, stride=None, max_tokens=None, min_n_frac=0.5, min_len=60, cover_tail=False
):
    """Tile sequences into windows, labeling each position per concept (mask + global instance id).

    Args:
        seqs: ``{seqid: dna_str}``.
        tracks: ``{concept: {seqid: [(start0, end0), ...]}}`` (e.g. from `load_track`).
        seq_len: window length in bp.
        stride: step between windows (defaults to non-overlapping = seq_len).
        max_tokens: stop once this many positions are emitted (None or 0 = all). The window
            that crosses the limit is kept, so ``stats["tokens"]`` may exceed it.
        min_n_frac: skip windows whose ``N`` fraction exceeds this (despite the name, it is
            the maximum tolerated fraction).
        min_len: skip windows shorter than this many bp (only possible for a sequence shorter
            than ``seq_len`` or for the tail window).
        cover_tail: by default window starts stop at ``len(dna) - seq_len``, so bases after the
            last full window are never emitted even though their intervals are counted in
            ``n_inst``. Set True to append one extra, shorter window starting one ``stride``
            after the last full window (when that start lies inside the sequence).

    Returns:
        (windows, stats). Each window is ``{"dna": str, "labels": {concept: bool[L]},
        "instances": {concept: int32[L]}}``. Each interval gets one global id, stable across
        the windows it spans, so `domain_f1`'s recall-per-instance counts a split interval once.
        ``stats`` is ``{"tokens": int, "n_inst": {concept: int}, "concepts": [concept, ...]}``;
        ``n_inst`` counts every interval in ``tracks``, whether or not a window covers it.

    Raises:
        ValueError: if ``seq_len`` or ``stride`` is not positive.
    """
    import numpy as np

    if seq_len <= 0:
        raise ValueError(f"seq_len must be positive, got {seq_len}")
    stride = stride or seq_len
    if stride <= 0:
        raise ValueError(f"stride must be positive, got {stride}")

    concepts = list(tracks)
    # Assign a global instance id to every interval, per concept (ids restart at 0 for each concept).
    concept_iv: dict[str, dict[str, list[tuple[int, int, int]]]] = {}
    n_inst: dict[str, int] = {}
    for concept in concepts:
        next_id = 0
        per_seq: dict[str, list[tuple[int, int, int]]] = {}
        for seqid, ivs in tracks[concept].items():
            tagged = [(s, e, next_id + k) for k, (s, e) in enumerate(ivs)]
            per_seq[seqid] = tagged
            next_id += len(tagged)
        concept_iv[concept] = per_seq
        n_inst[concept] = next_id

    windows, tot = [], 0
    budget_hit = False
    for seqid, dna in seqs.items():
        dna = dna.upper()
        n = len(dna)
        # max(1, ...) keeps a single (short) window for sequences shorter than seq_len.
        starts = list(range(0, max(1, n - seq_len + 1), stride))
        tail_start = starts[-1] + stride
        if cover_tail and starts[-1] + seq_len < n and tail_start < n:
            starts.append(tail_start)
        for w0 in starts:
            w1 = min(n, w0 + seq_len)
            length = w1 - w0
            if length <= 0 or length < min_len:
                continue
            sub = dna[w0:w1]
            if sub.count("N") > min_n_frac * length:
                continue
            labels = {c: np.zeros(length, bool) for c in concepts}
            inst = {c: np.full(length, -1, np.int32) for c in concepts}
            for c in concepts:
                for s, e, gid in concept_iv[c].get(seqid, []):
                    if e <= w0 or s >= w1:
                        continue
                    lo, hi = max(s, w0) - w0, min(e, w1) - w0
                    labels[c][lo:hi] = True
                    # Where intervals of one concept overlap, the later interval's id wins.
                    inst[c][lo:hi] = gid
            windows.append({"dna": sub, "labels": labels, "instances": inst})
            tot += length
            if max_tokens and tot >= max_tokens:
                budget_hit = True
                break
        if budget_hit:
            break
    return windows, {"tokens": tot, "n_inst": n_inst, "concepts": concepts}
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-Apache2
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""Build instance-level exon/intron/CDS-labeled windows from a genome FASTA + GFF3.

Eukaryotic gene-structure annotation for SAE feature probing. Unlike the
sequence-derived labelers, these labels come from real gene models, and crucially
carry *instance IDs* (which exon / which intron / which gene each position belongs
to) so domain-adjusted F1 can compute recall PER ANNOTATION INSTANCE (a feature
"recalls" an exon if it fires anywhere inside it), not per position.

For each protein-coding gene we take a representative transcript (longest by total
exon length), center one window on each of its exons, and label every position:
    exon / intron / cds / utr / intergenic (+ per-position instance IDs for
    exon, intron, gene)
The remaining token budget is filled with purely intergenic windows sampled at least
``flank`` bp away from every gene span.

Input expectations (as read by the code below): the FASTA holds ONE chromosome, and the
GFF's seqid column is not consulted, so the GFF must describe that same chromosome.

`python euk_windows.py --fasta chr21.fa --gff chr21.gff3 --dry-run` prints
coverage stats without building sequences.
"""

from __future__ import annotations

import argparse
import random
from collections import defaultdict

import numpy as np


def _attrs(s):
    """Parse a GFF3 column-9 string (``k=v;k=v``) into a dict; fields without ``=`` are dropped."""
    return dict(kv.split("=", 1) for kv in s.strip().split(";") if "=" in kv)


def parse_gff(gff_path):
    """Return {gene_id: {strand, tx: {tx_id: {'exon': [(s,e)], 'cds': [(s,e)]}}}} (protein_coding).

    Coordinates are kept as written in the GFF (1-based inclusive) and sorted per transcript.
    ``gene:`` / ``transcript:`` prefixes are stripped from ``ID`` and ``Parent`` values. A
    transcript is kept only if both it and its gene carry ``biotype=protein_coding`` and it has
    at least one exon. Comment lines and rows with fewer than 9 tab-separated columns are skipped;
    the seqid column is ignored.
    """
    gene_strand, gene_biotype = {}, {}
    tx_gene, tx_biotype = {}, {}
    tx_exon = defaultdict(list)
    tx_cds = defaultdict(list)
    # exon / CDS rows are stored the same way, just into different per-transcript tables
    per_tx_tables = {"exon": tx_exon, "CDS": tx_cds}
    with open(gff_path) as fh:
        for line in fh:
            if line.startswith("#"):
                continue
            f = line.rstrip("\n").split("\t")
            if len(f) < 9:
                continue
            typ, s, e, strand, attr = f[2], int(f[3]), int(f[4]), f[6], f[8]
            a = _attrs(attr)
            if typ == "gene":
                gid = a.get("ID", "").replace("gene:", "")
                gene_strand[gid] = strand
                gene_biotype[gid] = a.get("biotype", "")
            elif typ in ("mRNA", "transcript"):
                tid = a.get("ID", "").replace("transcript:", "")
                tx_gene[tid] = a.get("Parent", "").replace("gene:", "")
                tx_biotype[tid] = a.get("biotype", "")
            elif typ in per_tx_tables:
                # A feature shared by several transcripts lists them comma-separated in Parent.
                for tid in a.get("Parent", "").replace("transcript:", "").split(","):
                    if tid:
                        per_tx_tables[typ][tid].append((s, e))
    genes = {}
    for tid, gid in tx_gene.items():
        if gene_biotype.get(gid) != "protein_coding" or tx_biotype.get(tid) != "protein_coding":
            continue
        if not tx_exon.get(tid):
            continue
        genes.setdefault(gid, {"strand": gene_strand.get(gid, "+"), "tx": {}})
        genes[gid]["tx"][tid] = {"exon": sorted(tx_exon[tid]), "cds": sorted(tx_cds.get(tid, []))}
    return genes


def representative_tx(gene):
    """Longest transcript by total exon length; returns ``(tx_id, tx_dict)`` (first one wins ties)."""
    best, best_len = None, -1
    for tid, t in gene["tx"].items():
        ln = sum(e - s + 1 for s, e in t["exon"])
        if ln > best_len:
            best, best_len = tid, ln
    return best, gene["tx"][best]


def _label_window(chrom, w0, w1, gm, N):
    """Label a window [w0,w1) using one gene model's intervals (central-gene approx).

    Only ``gm`` is consulted: positions outside its span are marked ``intergenic`` even if
    another gene lies there. ``gm`` intervals are 1-based inclusive, window positions 0-based,
    hence the ``s - 1`` below. ``N`` (chromosome length) is accepted for call compatibility
    and is not used.
    """
    L = w1 - w0
    pos = np.arange(w0, w1)
    lab = {k: np.zeros(L, bool) for k in ("exon", "intron", "cds", "utr", "intergenic")}
    inst = {k: np.full(L, -1, np.int32) for k in ("exon", "intron", "gene")}
    g_start, g_end = gm["span"]
    in_tx = (pos >= g_start - 1) & (pos < g_end)
    lab["intergenic"][~in_tx] = True
    inst["gene"][in_tx] = gm["gi"]
    for kind in ("exon", "intron"):
        for (s, e), iid in zip(gm[kind + "s"], gm[kind + "_ids"]):
            m = (pos >= s - 1) & (pos < e)
            lab[kind][m] = True
            inst[kind][m] = iid
    for s, e in gm["cds"]:
        lab["cds"][(pos >= s - 1) & (pos < e)] = True
    # UTR = exonic but not coding
    lab["utr"] = lab["exon"] & ~lab["cds"]
    return {"dna": chrom[w0:w1], "labels": lab, "instances": inst}


def build_windows(
    fasta, gff, seq_len=1024, max_tokens=300_000, flank=300, seed=0, intergenic_frac=0.12, dry_run=False
):
    """Build labeled windows for one chromosome.

    Args:
        fasta: path to a single-record FASTA (the chromosome).
        gff: path to the GFF3 with gene / mRNA|transcript / exon / CDS rows for that chromosome.
        seq_len: window length in bp (windows clipped by the chromosome end are shorter).
        max_tokens: total position budget across genic + intergenic windows.
        flank: minimum distance (bp) an intergenic window keeps from every gene span.
        seed: seed for exon shuffling and intergenic sampling.
        intergenic_frac: share of ``max_tokens`` reserved for intergenic windows.
        dry_run: only compute annotation stats; no windows are built.

    Returns:
        ``(windows, stats, tot, N)``: the window dicts (see `_label_window`), annotation counts
        (genes/exons/introns and exon/intron/cds bp), the number of positions emitted, and the
        chromosome length. With ``dry_run`` the windows list is empty and ``tot`` is 0.

    Raises:
        ValueError: if the FASTA holds more than one record.
    """
    n_records, chunks = 0, []
    with open(fasta) as fh:
        for line in fh:
            if line.startswith(">"):
                n_records += 1
            else:
                chunks.append(line.strip())
    if n_records > 1:
        # Concatenating records would shift every later record's coordinates away from the GFF's.
        raise ValueError(
            f"{fasta}: found {n_records} FASTA records; build_windows needs a single-chromosome FASTA "
            "because GFF coordinates are applied to one sequence"
        )
    chrom = "".join(chunks).upper()
    N = len(chrom)
    genes = parse_gff(gff)

    exon_id, intron_id, gene_id = {}, {}, {}
    stats = defaultdict(int)
    gene_models, gene_spans = [], []
    for gid, gene in genes.items():
        tid, tx = representative_tx(gene)
        exons, cds = tx["exon"], tx["cds"]
        if not exons:
            continue
        g_start, g_end = exons[0][0], exons[-1][1]
        # introns = non-empty gaps between consecutive (sorted) exons
        introns = [
            (exons[i][1] + 1, exons[i + 1][0] - 1)
            for i in range(len(exons) - 1)
            if exons[i + 1][0] - 1 >= exons[i][1] + 1
        ]
        gi = gene_id.setdefault(gid, len(gene_id))
        eids = [exon_id.setdefault((tid, i), len(exon_id)) for i in range(len(exons))]
        iids = [intron_id.setdefault((tid, i), len(intron_id)) for i in range(len(introns))]
        gene_models.append(
            {
                "exons": exons,
                "introns": introns,
                "cds": cds,
                "gi": gi,
                "exon_ids": eids,
                "intron_ids": iids,
                "span": (g_start, g_end),
            }
        )
        gene_spans.append((g_start, g_end))
        stats["genes"] += 1
        stats["exons"] += len(exons)
        stats["introns"] += len(introns)
        stats["exon_bp"] += sum(e - s + 1 for s, e in exons)
        stats["intron_bp"] += sum(e - s + 1 for s, e in introns)
        stats["cds_bp"] += sum(e - s + 1 for s, e in cds)
    if dry_run:
        return [], dict(stats), 0, N

    rng = random.Random(seed)
    # exon-centered windows sampled across ALL genes' exons (diverse + exon/intron balanced)
    exon_refs = [(gi, ei) for gi, gm in enumerate(gene_models) for ei in range(len(gm["exons"]))]
    rng.shuffle(exon_refs)
    windows, tot = [], 0
    budget_genic = int(max_tokens * (1 - intergenic_frac))
    for gi, ei in exon_refs:
        if tot >= budget_genic:
            break
        gm = gene_models[gi]
        s, e = gm["exons"][ei]
        center = (s - 1 + e) // 2
        w0 = max(0, center - seq_len // 2)
        w1 = min(N, w0 + seq_len)
        if w1 - w0 < 60:
            continue
        win = _label_window(chrom, w0, w1, gm, N)
        if win["dna"].count("N") > 0.5 * len(win["dna"]):
            continue
        windows.append(win)
        tot += w1 - w0
    # intergenic windows: random spots clear of any gene span (+flank)
    spans = sorted(gene_spans)
    tries = 0
    # A chromosome shorter than one window has no valid start (randint would raise on an empty range).
    can_sample = N >= seq_len
    while can_sample and tot < max_tokens and tries < 20000:
        tries += 1
        w0 = rng.randint(0, N - seq_len)
        w1 = w0 + seq_len
        if any(not (w1 < gs - flank or w0 > ge + flank) for gs, ge in spans):
            continue
        dna = chrom[w0:w1]
        if dna.count("N") > 0.5 * seq_len:
            continue
        lab = {k: np.zeros(seq_len, bool) for k in ("exon", "intron", "cds", "utr", "intergenic")}
        lab["intergenic"][:] = True
        inst = {k: np.full(seq_len, -1, np.int32) for k in ("exon", "intron", "gene")}
        windows.append({"dna": dna, "labels": lab, "instances": inst})
        tot += seq_len
    return windows, dict(stats), tot, N


def main():
    """CLI entry point: build (or dry-run) the windows and print coverage statistics."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--fasta", required=True)
    ap.add_argument("--gff", required=True)
    ap.add_argument("--seq-len", type=int, default=1024)
    ap.add_argument("--max-tokens", type=int, default=300_000)
    ap.add_argument("--flank", type=int, default=300)
    ap.add_argument("--seed", type=int, default=0)
    ap.add_argument("--intergenic-frac", type=float, default=0.12)
    ap.add_argument("--dry-run", action="store_true")
    args = ap.parse_args()
    windows, stats, tot, N = build_windows(
        args.fasta,
        args.gff,
        args.seq_len,
        args.max_tokens,
        args.flank,
        seed=args.seed,
        intergenic_frac=args.intergenic_frac,
        dry_run=args.dry_run,
    )
    print(f"chromosome length: {N:,} bp")
    print(f"protein-coding genes used: {stats.get('genes', 0):,}")
    print(f"exons: {stats.get('exons', 0):,} introns: {stats.get('introns', 0):,}")
    if args.dry_run:
        print(
            f"exon bp: {stats.get('exon_bp', 0):,} intron bp: {stats.get('intron_bp', 0):,} cds bp: {stats.get('cds_bp', 0):,}"
        )
        return
    print(f"windows built: {len(windows):,} total tokens: {tot:,}")
    # coverage over built windows
    cov = defaultdict(int)
    ninst = {k: set() for k in ("exon", "intron", "gene")}
    for w in windows:
        for k, m in w["labels"].items():
            cov[k] += int(m.sum())
        for k in ninst:
            ids = w["instances"][k]
            ninst[k].update(int(x) for x in np.unique(ids) if x >= 0)
    print("per-position coverage (of built windows):")
    for k in ("exon", "intron", "cds", "utr", "intergenic"):
        print(f" {k:11s} {cov[k]:>9,} ({100 * cov[k] / max(1, tot):5.1f}%)")
    print(f"instances: exons={len(ninst['exon']):,} introns={len(ninst['intron']):,} genes={len(ninst['gene']):,}")


if __name__ == "__main__":
    main()
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-Apache2
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Extensible per-token biological labelers for SAE feature probing.

Each labeler maps a `SeqContext` (one tokenized sequence) to a per-token boolean
mask of length `T`. The per-feature AUROC probe (`probe_features.py`) asks, for
every label and every SAE feature, how well the feature's activation separates
positive from negative tokens.

Adding a feature is just writing a function and decorating it:

    @labeler("my_concept")
    def _my(ctx):
        return some_bool_array_len_T

`complex=True` flags labelers that are proxies or need real external annotation
(e.g. true gene models) and should be refined later — they're the natural home
for the "more complicated features" we want to add at the end.

Conventions
-----------
* Tokens 0..tag_len-1 are the phylogenetic-tag prefix; sequence-derived motif /
  positional labels are False there (use `_dna_mask`). Sequence-level labels
  (`is_prok`) and norm-based labels (`is_sink_token`) may mark tag tokens.
* Byte-level Evo2 tokenization is 1 char = 1 token, so token i in the DNA region
  corresponds to base `ctx.dna[i - tag_len]`.
* Motif labelers mark EVERY occurrence of their pattern, including occurrences
  that overlap each other.
"""

from __future__ import annotations

import re
from dataclasses import dataclass
from typing import Callable, Optional

import numpy as np
from Bio.Data import CodonTable
from Bio.Seq import Seq


# name -> fn(ctx) -> np.ndarray[bool] of length T
LABELERS: dict[str, Callable[[SeqContext], np.ndarray]] = {}
# labelers that are proxies / need real annotations (documented, refine later)
COMPLEX_LABELERS: set[str] = set()

# Sink-token norm threshold (residual L2). Set by the driver from the data
# (Evo2 7B layer-26 sinks sit ~1638 vs a ~21 median, so this cleanly isolates them).
SINK_NORM_THRESHOLD: float = 100.0

# Unambiguous nucleotides; runs of anything else (e.g. N) are not treated as repeats.
_BASES = "ACGT"


def labeler(name: str, complex: bool = False):
    """Register a per-token labeler under `name` (and in `COMPLEX_LABELERS` when ``complex``)."""

    def deco(fn):
        LABELERS[name] = fn
        if complex:
            COMPLEX_LABELERS.add(name)
        return fn

    return deco


@dataclass
class SeqContext:
    """Everything a labeler needs about one tokenized sequence."""

    text: str  # tag + dna (1 char == 1 token)
    tag_len: int  # number of leading phylo-tag tokens
    dna: str  # the DNA region (uppercase ACGTN), len == T - tag_len
    kingdom: str  # 'prok' | 'euk'
    hidden_norm: np.ndarray  # [T] residual L2 norm per token
    # Gene-structure annotation over DNA positions (filled by a gene caller; None if absent).
    cds_mask: Optional[np.ndarray] = None  # bool[len(dna)] — within a predicted CDS (either strand)
    cds_frame: Optional[np.ndarray] = None  # int8[len(dna)] — codon position 0/1/2 within CDS, -1 if not
    gene_starts: Optional[np.ndarray] = None  # bool[len(dna)] — predicted translation start positions

    @property
    def T(self) -> int:
        """Total token count: tag prefix plus DNA region."""
        return self.tag_len + len(self.dna)


def _dna_mask(ctx: SeqContext, dna_bool: np.ndarray) -> np.ndarray:
    """Lift a per-DNA-position bool array to a per-token mask (tag tokens False)."""
    out = np.zeros(ctx.T, dtype=bool)
    out[ctx.tag_len : ctx.tag_len + len(dna_bool)] = dna_bool
    return out


def _bytes(dna: str) -> np.ndarray:
    """View ``dna`` as a uint8 array of ASCII codes (non-ASCII characters become ``?``)."""
    return np.frombuffer(dna.encode("ascii", "replace"), dtype=np.uint8)


# --------------------------------------------------------------------- positional
@labeler("first_100bp")
def _first(ctx):
    """First 100 DNA positions (all of them for shorter sequences)."""
    d = np.zeros(len(ctx.dna), bool)
    d[:100] = True
    return _dna_mask(ctx, d)


@labeler("last_100bp")
def _last(ctx):
    """Last 100 DNA positions (all of them for shorter sequences)."""
    d = np.zeros(len(ctx.dna), bool)
    if len(d):
        d[-100:] = True
    return _dna_mask(ctx, d)


# --------------------------------------------------------------------- single base
# Evo2 is nucleotide-level (1 token = 1 base), so the most atomic feature is a single-base
# detector. One labeler per nucleotide — fires at every position whose base equals it —
# so probing can surface features that monosemantically track a single base (e.g. base_G).
def _register_base_labelers():
    for base in _BASES:
        # b=base binds the current letter; a plain closure would see only the last one
        labeler(f"base_{base}")(lambda ctx, b=base: _dna_mask(ctx, _bytes(ctx.dna) == ord(b)))


_register_base_labelers()


# --------------------------------------------------------------------- composition
def _gc_window(dna: str, radius: int = 10) -> np.ndarray:
    """GC fraction in the centered window ``[i - radius, i + radius]``, clipped at the sequence ends."""
    arr = _bytes(dna)
    gc = ((arr == ord("G")) | (arr == ord("C"))).astype(np.float64)
    # prefix sums make every window sum a single subtraction
    csum = np.concatenate([[0.0], np.cumsum(gc)])
    n = len(gc)
    idx = np.arange(n)
    lo = np.maximum(0, idx - radius)
    hi = np.minimum(n, idx + radius + 1)
    return (csum[hi] - csum[lo]) / np.maximum(1, hi - lo)


@labeler("gc_high_window")
def _gch(ctx):
    """Positions whose local GC fraction is >= 0.60."""
    return _dna_mask(ctx, _gc_window(ctx.dna) >= 0.60)


@labeler("gc_low_window")
def _gcl(ctx):
    """Positions whose local GC fraction is <= 0.30."""
    return _dna_mask(ctx, _gc_window(ctx.dna) <= 0.30)


@labeler("homopolymer_window")
def _homo(ctx, k: int = 5):
    """Runs of >= ``k`` identical A/C/G/T bases (runs of N or other symbols are not labeled)."""
    d, n = ctx.dna, len(ctx.dna)
    out = np.zeros(n, bool)
    i = 0
    while i < n:
        j = i
        while j + 1 < n and d[j + 1] == d[i]:
            j += 1
        if j - i + 1 >= k and d[i] in _BASES:
            out[i : j + 1] = True
        i = j + 1
    return _dna_mask(ctx, out)


@labeler("dinuc_repeat_window")
def _dinuc(ctx, min_reps: int = 3):
    """Period-2 runs of two different A/C/G/T bases spanning >= ``2 * min_reps`` positions."""
    d, n = ctx.dna, len(ctx.dna)
    out = np.zeros(n, bool)
    i = 0
    while i < n - 1:
        if d[i] != d[i + 1] and d[i] in _BASES and d[i + 1] in _BASES:
            # extend while the sequence keeps repeating with period 2
            j = i
            while j + 2 < n and d[j + 2] == d[j]:
                j += 1
            span = j + 2 - i
            if span >= 2 * min_reps:
                out[i : j + 2] = True
                i = max(j + 1, i + 1)
            else:
                i += 1
        else:
            i += 1
    return _dna_mask(ctx, out)


# --------------------------------------------------------------------- motifs
def _iter_matches(dna: str, pattern: str):
    """Yield ``(start, end)`` of every occurrence of ``pattern`` in ``dna``, overlapping ones included.

    Plain ``re.finditer`` resumes after the previous match and so skips occurrences that
    overlap it; wrapping the pattern in a zero-width lookahead tries every start position.
    """
    for m in re.finditer(f"(?=({pattern}))", dna):
        yield m.start(1), m.end(1)


def _starts(dna: str, pattern: str) -> np.ndarray:
    """Mark the first position of every match of ``pattern``."""
    out = np.zeros(len(dna), bool)
    for start, _end in _iter_matches(dna, pattern):
        out[start] = True
    return out


def _spans(dna: str, pattern: str) -> np.ndarray:
    """Mark every position covered by a match of ``pattern``."""
    out = np.zeros(len(dna), bool)
    for start, end in _iter_matches(dna, pattern):
        out[start:end] = True
    return out


# Consensus motifs: (name, matcher, regex) — `_starts` marks the match start, `_spans` the whole match.
_MOTIFS = [
    ("motif_ATG", _starts, r"ATG"),
    ("motif_stop", _starts, r"TAA|TAG|TGA"),
    ("motif_TATA", _spans, r"TATA[AT]A"),
    ("motif_RBS_SD", _spans, r"AGGAGG"),  # Shine-Dalgarno ribosome-binding site
]
for _name, _match, _pat in _MOTIFS:
    labeler(_name)(lambda ctx, m=_match, p=_pat: _dna_mask(ctx, m(ctx.dna, p)))


# --------------------------------------------------- complex / consensus (refine later)
@labeler("kozak_atg", complex=True)
def _kozak(ctx):
    # Kozak: (A/G)xxATGG — mark the ATG start (match start + 3)
    out = np.zeros(len(ctx.dna), bool)
    for start, _end in _iter_matches(ctx.dna, r"[AG]..ATGG"):
        out[start + 3] = True
    return _dna_mask(ctx, out)


@labeler("splice_donor", complex=True)
def _sd(ctx):
    # 5' donor consensus GT(A/G)AGT — mark the GT
    return _dna_mask(ctx, _starts(ctx.dna, r"GT[AG]AGT"))


@labeler("splice_acceptor", complex=True)
def _sa(ctx):
    # 3' acceptor: polypyrimidine tract then AG — mark the AG
    out = np.zeros(len(ctx.dna), bool)
    for _start, end in _iter_matches(ctx.dna, r"[CT]{6}[ACGT]?AG"):
        out[end - 2 : end] = True
    return _dna_mask(ctx, out)


# --------------------------------------------------------------- sequence / norm level
@labeler("is_prok")
def _prok(ctx):
    """Sequence-level label: every token (tag included) is True iff ``ctx.kingdom == 'prok'``."""
    return np.full(ctx.T, ctx.kingdom == "prok", dtype=bool)


@labeler("is_sink_token", complex=True)
def _sink(ctx):
    """Tokens whose residual norm exceeds `SINK_NORM_THRESHOLD` (read at call time)."""
    return ctx.hidden_norm > SINK_NORM_THRESHOLD


# --------------------------------------------- gene structure (real annotation, prok)
# These read a CDS annotation attached to the context by a gene caller (see
# predict_cds, prokaryotes only). They are no-ops when the annotation is absent.
@labeler("cds_coding", complex=True)
def _cds(ctx):
    if ctx.cds_mask is None:
        return np.zeros(ctx.T, bool)
    return _dna_mask(ctx, ctx.cds_mask)


@labeler("cds_start", complex=True)
def _cds_start(ctx):
    if ctx.gene_starts is None:
        return np.zeros(ctx.T, bool)
    return _dna_mask(ctx, ctx.gene_starts)


@labeler("cds_frame_1", complex=True)
def _cds_f1(ctx):
    # codon position 1 within a REAL predicted CDS (not the frame-0-from-start proxy)
    if ctx.cds_frame is None:
        return np.zeros(ctx.T, bool)
    return _dna_mask(ctx, ctx.cds_frame == 0)


@labeler("cds_frame_3", complex=True)
def _cds_f3(ctx):
    # codon position 3 (cds_frame stores 0/1/2, so the third position is 2)
    if ctx.cds_frame is None:
        return np.zeros(ctx.T, bool)
    return _dna_mask(ctx, ctx.cds_frame == 2)


# Lazily created pyrodigal gene finder, shared by predict_codons / predict_cds.
_GENE_FINDER = None


def _gene_finder():
    """Return the shared ``pyrodigal.GeneFinder(meta=True)``, importing pyrodigal on first use."""
    global _GENE_FINDER
    if _GENE_FINDER is None:
        import pyrodigal

        _GENE_FINDER = pyrodigal.GeneFinder(meta=True)
    return _GENE_FINDER


# Standard genetic code (NCBI translation table 1) via Biopython; codon -> amino acid ('*' = stop).
_STD_CODE = CodonTable.unambiguous_dna_by_id[1]
CODON_TABLE = {**_STD_CODE.forward_table, **dict.fromkeys(_STD_CODE.stop_codons, "*")}
CODON_LIST = sorted(CODON_TABLE)  # 64 codons
CODON_TO_IDX = {c: i for i, c in enumerate(CODON_LIST)}
AA_LIST = sorted(set(CODON_TABLE.values()))  # 20 aa + '*' (stop)
AA_TO_IDX = {a: i for i, a in enumerate(AA_LIST)}


def _revcomp(s):
    """Reverse complement of ``s`` (via Biopython)."""
    return str(Seq(s).reverse_complement())


def predict_codons(dna: str):
    """In-frame codon + amino-acid identity at strand-correct codon anchors (prok genes).

    Returns (codon_id[N], aa_id[N]) over forward DNA coordinates; the anchor is the
    first translated base of each codon (low coord on +strand, high coord on -strand),
    other positions are -1. codon_id in 0..63 (CODON_LIST), aa_id in 0..20 (AA_LIST).
    Codons containing anything but A/C/G/T are left at -1, and sequences shorter than
    60 bp are returned unannotated. Where predicted genes overlap, later genes overwrite
    earlier anchors.
    """
    n = len(dna)
    codon_id = np.full(n, -1, dtype=np.int16)
    aa_id = np.full(n, -1, dtype=np.int8)
    if n < 60:
        return codon_id, aa_id
    for g in _gene_finder().find_genes(dna.encode("ascii", "replace")):
        # begin/end are converted as 1-based inclusive -> 0-based half-open, clipped to the chunk
        b, e = max(0, g.begin - 1), min(n, g.end)
        if e <= b:
            continue
        sub = dna[b:e]
        forward = g.strand == 1
        coding = sub if forward else _revcomp(sub)
        for i in range(len(coding) // 3):
            cod = coding[3 * i : 3 * i + 3]
            j = CODON_TO_IDX.get(cod)
            if j is None:
                continue
            p = b + 3 * i if forward else (e - 1 - 3 * i)
            if 0 <= p < n:
                codon_id[p] = j
                aa_id[p] = AA_TO_IDX[CODON_TABLE[cod]]
    return codon_id, aa_id


def predict_cds(dna: str):
    """Prokaryotic gene calling via pyrodigal (meta mode) on a single DNA chunk.

    Returns (cds_mask, cds_frame, gene_starts) over forward DNA coordinates:
      cds_mask[i]    True if position i lies within any predicted CDS (either strand)
      cds_frame[i]   codon position 0/1/2 relative to that gene's start (strand-aware), else -1
      gene_starts[i] True at predicted translation starts
    Sequences shorter than 60 bp are returned unannotated. Where predicted genes overlap,
    the later gene's frame overwrites the earlier one.
    """
    n = len(dna)
    cds_mask = np.zeros(n, dtype=bool)
    cds_frame = np.full(n, -1, dtype=np.int8)
    gene_starts = np.zeros(n, dtype=bool)
    if n < 60:
        return cds_mask, cds_frame, gene_starts
    for g in _gene_finder().find_genes(dna.encode("ascii", "replace")):
        b, e = g.begin - 1, g.end  # 0-based half-open, forward coords
        b, e = max(0, b), min(n, e)
        if e <= b:
            continue
        cds_mask[b:e] = True
        idx = np.arange(b, e)
        if g.strand == 1:
            gene_starts[b] = True
            cds_frame[b:e] = (idx - b) % 3
        else:  # reverse strand: start codon sits at the (forward) end
            gene_starts[e - 1] = True
            cds_frame[b:e] = ((e - 1) - idx) % 3
    return cds_mask, cds_frame, gene_starts


# Default label set for the probe (order preserved in outputs).
DEFAULT_LABELS = list(LABELERS.keys())
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-Apache2
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""CPU unit tests for the generic interval-track loader (no model / no torch-CUDA)."""

import gzip
import sys
from pathlib import Path


# The recipe has no installable package; make scripts/ importable for the tests.
sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "scripts"))

from annot_tracks import label_windows, load_track, read_fasta_dict


def test_read_fasta_dict_uses_first_token(tmp_path):
    """seq_id is the first header token (so it matches BED/GFF chrom)."""
    fa = tmp_path / "g.fa"
    fa.write_text(">chr1 Homo sapiens\nACGT\nACGT\n>chr2\nTTTT\n")
    assert read_fasta_dict(fa) == {"chr1": "ACGTACGT", "chr2": "TTTT"}


def test_read_fasta_dict_reads_gzip(tmp_path):
    """A ``.gz`` FASTA is decompressed transparently."""
    fa = tmp_path / "g.fa.gz"
    with gzip.open(fa, "wt") as fh:
        fh.write(">chr1\nACGT\n")
    assert read_fasta_dict(fa) == {"chr1": "ACGT"}


def test_load_bed_is_half_open(tmp_path):
    """BED is 0-based half-open and used as-is."""
    bed = tmp_path / "t.bed"
    bed.write_text("chr1\t2\t5\tsiteA\nchr1\t10\t12\tsiteB\n")
    assert load_track(str(bed)) == {"chr1": [(2, 5), (10, 12)]}


def test_load_gff_converts_to_half_open_and_filters_type(tmp_path):
    """GFF 1-based inclusive -> 0-based half-open; feature_type filters column 3."""
    gff = tmp_path / "t.gff3"
    gff.write_text(
        "# comment\n"
        "chr1\tsrc\texon\t3\t5\t.\t+\t.\tID=e1\n"  # 1-based [3,5] -> [2,5)
        "chr1\tsrc\tCDS\t3\t5\t.\t+\t.\tID=c1\n"
    )
    assert load_track(str(gff), feature_type="exon") == {"chr1": [(2, 5)]}


def test_label_windows_mask_and_instance_ids():
    """Each interval is one instance; mask + instance id line up with the window positions."""
    seqs = {"chr1": "ACGT" * 25}  # 100 bp (above the 60 bp window floor)
    tracks = {"site": {"chr1": [(2, 5), (10, 12)]}}  # two instances
    windows, stats = label_windows(seqs, tracks, seq_len=100)
    assert len(windows) == 1
    w = windows[0]
    mask, inst = w["labels"]["site"], w["instances"]["site"]
    assert list(mask.nonzero()[0]) == [2, 3, 4, 10, 11]
    assert set(inst[mask].tolist()) == {0, 1}  # two distinct instances
    assert (inst[~mask] == -1).all()
    assert stats["n_inst"]["site"] == 2


def test_instance_id_stable_across_split_windows():
    """An interval spanning a window boundary keeps ONE global instance id (recall counts it once)."""
    seqs = {"chr1": "A" * 200}
    tracks = {"big": {"chr1": [(90, 110)]}}  # straddles the 0-100 / 100-200 boundary
    windows, stats = label_windows(seqs, tracks, seq_len=100)
    ids = set()
    for w in windows:
        inst = w["instances"]["big"]
        ids.update(int(x) for x in inst[inst >= 0])
    assert ids == {0}  # same id in both windows
    assert stats["n_inst"]["big"] == 1


def test_label_windows_stops_at_max_tokens():
    """Tiling stops after the window that reaches ``max_tokens`` (that window is kept)."""
    windows, stats = label_windows({"chr1": "A" * 300}, {"t": {}}, seq_len=100, max_tokens=150)
    assert len(windows) == 2
    assert stats["tokens"] == 200


def test_label_windows_skips_n_rich_windows():
    """A window that is mostly ``N`` is dropped; the following clean window is kept."""
    windows, _ = label_windows({"chr1": "N" * 100 + "A" * 100}, {"t": {}}, seq_len=100)
    assert [w["dna"] for w in windows] == ["A" * 100]
+ +"""CPU tests for the per-token labelers (pure masks, no model).""" + +import sys +from pathlib import Path + +import numpy as np + + +sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "scripts")) + +from labelers import LABELERS, SeqContext + + +def _ctx(dna, tag_len=0): + text = "X" * tag_len + dna + return SeqContext(text=text, tag_len=tag_len, dna=dna, kingdom="prok", hidden_norm=np.zeros(tag_len + len(dna))) + + +def test_consensus_motifs_fire_at_match_positions(): + """The table-driven motifs mark the right positions (ATG/stop = start, TATA = span).""" + ctx = _ctx("ATGTAACGT") # ATG @0 ; TAA (stop) @3 + assert list(LABELERS["motif_ATG"](ctx).nonzero()[0]) == [0] + assert list(LABELERS["motif_stop"](ctx).nonzero()[0]) == [3] + assert list(LABELERS["motif_TATA"](_ctx("TATAAA")).nonzero()[0]) == [0, 1, 2, 3, 4, 5] # spans the match + + +def test_base_labelers_fire_per_nucleotide(): + """base_A/C/G/T each fire exactly on their nucleotide.""" + ctx = _ctx("ACGTAA") + assert list(LABELERS["base_A"](ctx).nonzero()[0]) == [0, 4, 5] + assert list(LABELERS["base_G"](ctx).nonzero()[0]) == [2] + + +def test_tag_prefix_is_unlabeled(): + """Sequence-derived labels are False over the leading phylo-tag tokens.""" + ctx = _ctx("ATG", tag_len=2) # tokens: [tag, tag, A, T, G] + m = LABELERS["motif_ATG"](ctx) + assert len(m) == 5 and not m[:2].any() and m[2] # ATG starts at DNA pos 0 -> token 2 From 6566f2a071af73a8056130e05a12f93b61f9c9fe Mon Sep 17 00:00:00 2001 From: Polina Binder Date: Tue, 23 Jun 2026 06:36:32 +0000 Subject: [PATCH 2/3] =?UTF-8?q?fix(evo2-sae):=20address=20#1630=20review?= =?UTF-8?q?=20=E2=80=94=20deps,=20single-chrom=20guard,=20label=20correctn?= =?UTF-8?q?ess,=20tests?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - declare pyrodigal (predict_cds/predict_codons import it lazily; only biopython was declared, so the gene-calling path would ImportError in a clean env) - euk_windows: reject a 
multi-record FASTA (parse_gff drops seqid + we concatenate, so a multi-chrom GFF would silently mislabel against a blob) and document the single-chromosome assumption - skip windows overlapping a *second* gene's span (central-gene approx mislabels a neighbor's exons as intergenic) and de-dup near-duplicate adjacent-exon windows so they don't eat the token budget - extract the CDS reverse-strand frame math into a unit-tested helper (_frame_and_start) - tests: euk_windows had none — add parse_gff/1-based->0-based labeling/single-chrom guard; cover the previously-untested labelers (gc/homopolymer/dinuc/kozak/splice) + the +/- strand frame anchor #1630 CPU tests: 16 passed. Co-Authored-By: Claude Opus 4.8 Signed-off-by: Polina Binder --- .../recipes/evo2/pyproject.toml | 1 + .../recipes/evo2/scripts/euk_windows.py | 36 +++++++- .../recipes/evo2/scripts/labelers.py | 24 ++++-- .../recipes/evo2/tests/test_euk_windows.py | 82 +++++++++++++++++++ .../recipes/evo2/tests/test_labelers.py | 38 +++++++++ 5 files changed, 170 insertions(+), 11 deletions(-) create mode 100644 interpretability/sparse_autoencoders/recipes/evo2/tests/test_euk_windows.py diff --git a/interpretability/sparse_autoencoders/recipes/evo2/pyproject.toml b/interpretability/sparse_autoencoders/recipes/evo2/pyproject.toml index ba794c5e74..047e3e69f0 100644 --- a/interpretability/sparse_autoencoders/recipes/evo2/pyproject.toml +++ b/interpretability/sparse_autoencoders/recipes/evo2/pyproject.toml @@ -14,6 +14,7 @@ dependencies = [ "numpy>=1.20", "pyarrow>=23.0.0", "biopython>=1.80", # genetic code / translation in labelers.py + "pyrodigal>=3.0", # prokaryotic gene calling in labelers.predict_cds / predict_codons ] # No package code lives here yet — the recipe is just an entry-point for diff --git a/interpretability/sparse_autoencoders/recipes/evo2/scripts/euk_windows.py b/interpretability/sparse_autoencoders/recipes/evo2/scripts/euk_windows.py index 0f381df84b..f656428b9e 100644 --- 
a/interpretability/sparse_autoencoders/recipes/evo2/scripts/euk_windows.py +++ b/interpretability/sparse_autoencoders/recipes/evo2/scripts/euk_windows.py @@ -28,6 +28,11 @@ `python euk_windows.py --fasta chr21.fa --gff chr21.gff3 --dry-run` prints coverage stats without building sequences. + +Single chromosome only: the FASTA is read as one concatenated sequence and ``parse_gff`` +drops the seqid column, so a multi-record FASTA + multi-chromosome GFF would apply +per-chromosome coordinates against a concatenated blob (silently wrong labels). ``build_windows`` +asserts a single FASTA record; run one chromosome at a time. """ from __future__ import annotations @@ -96,7 +101,12 @@ def representative_tx(gene): def _label_window(chrom, w0, w1, gm, N): - """Label a window [w0,w1) using one gene model's intervals (central-gene approx).""" + """Label a window [w0,w1) using one gene model's intervals (central-gene approx). + + Caveat: labels derive from a single gene model — every position outside this gene's span is + marked intergenic, so a *second* gene overlapping the window would have its exons mislabeled. + build_windows skips windows overlapping another gene's span to avoid injecting that bias. + """ L = w1 - w0 pos = np.arange(w0, w1) lab = {k: np.zeros(L, bool) for k in ("exon", "intron", "cds", "utr", "intergenic")} @@ -122,11 +132,20 @@ def _label_window(chrom, w0, w1, gm, N): def build_windows( # noqa: D103 fasta, gff, seq_len=1024, max_tokens=300_000, flank=300, seed=0, intergenic_frac=0.12, dry_run=False ): - seqs = [] + seqs, n_records = [], 0 with open(fasta) as fh: for line in fh: - if not line.startswith(">"): + if line.startswith(">"): + n_records += 1 + else: seqs.append(line.strip()) + if n_records != 1: + # parse_gff keeps no seqid, and we concatenate all records into one `chrom`; a multi-record + # FASTA + multi-chromosome GFF would silently apply per-chromosome coordinates to the blob. 
+ raise ValueError( + f"euk_windows expects a single-chromosome FASTA matching the GFF, got {n_records} " + f"records. Run one chromosome at a time (e.g. chr21.fa + chr21.gff3)." + ) chrom = "".join(seqs).upper() N = len(chrom) genes = parse_gff(gff) @@ -173,7 +192,7 @@ def build_windows( # noqa: D103 # exon-centered windows sampled across ALL genes' exons (diverse + exon/intron balanced) exon_refs = [(gi, ei) for gi, gm in enumerate(gene_models) for ei in range(len(gm["exons"]))] rng.shuffle(exon_refs) - windows, tot = [], 0 + windows, tot, accepted = [], 0, [] budget_genic = int(max_tokens * (1 - intergenic_frac)) for gi, ei in exon_refs: if tot >= budget_genic: @@ -185,10 +204,19 @@ def build_windows( # noqa: D103 w1 = min(N, w0 + seq_len) if w1 - w0 < 60: continue + # _label_window labels from one gene model, so a second gene overlapping the window would + # have its exons mislabeled intergenic — skip those windows to keep eval labels trustworthy. + if any(j != gi and w0 < ge and w1 > gs - 1 for j, (gs, ge) in enumerate(gene_spans)): + continue + # adjacent exons of one gene center on nearly the same span; drop near-duplicate windows so + # heavily-overlapping (correlated) positions don't quietly eat the token budget. 
+ if any(min(w1, aw1) - max(w0, aw0) > seq_len // 2 for aw0, aw1 in accepted): + continue win = _label_window(chrom, w0, w1, gm, N) if win["dna"].count("N") > 0.5 * len(win["dna"]): continue windows.append(win) + accepted.append((w0, w1)) tot += w1 - w0 # intergenic windows: random spots clear of any gene span (+flank) spans = sorted(gene_spans) diff --git a/interpretability/sparse_autoencoders/recipes/evo2/scripts/labelers.py b/interpretability/sparse_autoencoders/recipes/evo2/scripts/labelers.py index 3243742d75..04b4c0d53e 100644 --- a/interpretability/sparse_autoencoders/recipes/evo2/scripts/labelers.py +++ b/interpretability/sparse_autoencoders/recipes/evo2/scripts/labelers.py @@ -328,6 +328,20 @@ def predict_codons(dna: str): return codon_id, aa_id +def _frame_and_start(b: int, e: int, strand: int): + """Codon-frame array (0/1/2 along [b, e)) + the translation-start position for one CDS. + + Frame is relative to the strand-correct start: the start is the low coord on the + strand + (frame counts up from ``b``) and the high coord on the - strand (frame counts back from + ``e-1``). Pure arithmetic, split out from predict_cds so the reverse-strand anchor (an easy + off-by-one) is unit-tested directly. + """ + idx = np.arange(b, e) + if strand == 1: + return ((idx - b) % 3).astype(np.int8), b + return (((e - 1) - idx) % 3).astype(np.int8), e - 1 + + def predict_cds(dna: str): """Prokaryotic gene calling via pyrodigal (meta mode) on a single DNA chunk. 
@@ -353,13 +367,9 @@ def predict_cds(dna: str): if e <= b: continue cds_mask[b:e] = True - idx = np.arange(b, e) - if g.strand == 1: - gene_starts[b] = True - cds_frame[b:e] = (idx - b) % 3 - else: # reverse strand: start codon sits at the (forward) end - gene_starts[e - 1] = True - cds_frame[b:e] = ((e - 1) - idx) % 3 + frame, start = _frame_and_start(b, e, g.strand) + cds_frame[b:e] = frame + gene_starts[start] = True return cds_mask, cds_frame, gene_starts diff --git a/interpretability/sparse_autoencoders/recipes/evo2/tests/test_euk_windows.py b/interpretability/sparse_autoencoders/recipes/evo2/tests/test_euk_windows.py new file mode 100644 index 0000000000..5296fa21e5 --- /dev/null +++ b/interpretability/sparse_autoencoders/recipes/evo2/tests/test_euk_windows.py @@ -0,0 +1,82 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""CPU tests for euk_windows: GFF parsing, 1-based->0-based labeling, single-chromosome guard.""" + +import sys +from pathlib import Path + +import pytest + + +sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "scripts")) + +from euk_windows import _label_window, build_windows, parse_gff, representative_tx + + +_GFF = ( + "##gff-version 3\n" + "1\tt\tgene\t11\t40\t.\t+\t.\tID=gene:g1;biotype=protein_coding\n" + "1\tt\tmRNA\t11\t40\t.\t+\t.\tID=transcript:t1;Parent=gene:g1;biotype=protein_coding\n" + "1\tt\texon\t11\t20\t.\t+\t.\tParent=transcript:t1\n" + "1\tt\texon\t31\t40\t.\t+\t.\tParent=transcript:t1\n" + "1\tt\tCDS\t11\t20\t.\t+\t.\tParent=transcript:t1\n" + # a non-coding gene that must be filtered out + "1\tt\tgene\t100\t120\t.\t+\t.\tID=gene:g2;biotype=lncRNA\n" + "1\tt\tmRNA\t100\t120\t.\t+\t.\tID=transcript:t2;Parent=gene:g2;biotype=lncRNA\n" + "1\tt\texon\t100\t120\t.\t+\t.\tParent=transcript:t2\n" +) + + +def test_parse_gff_keeps_only_protein_coding(tmp_path): + """parse_gff returns only protein_coding genes, with sorted exon/CDS coords (1-based, inclusive).""" + gff = tmp_path / "g.gff3" + gff.write_text(_GFF) + genes = parse_gff(str(gff)) + assert set(genes) == {"g1"} # the lncRNA gene is filtered out + _, tx = representative_tx(genes["g1"]) + assert tx["exon"] == [(11, 20), (31, 40)] and tx["cds"] == [(11, 20)] + + +def test_label_window_converts_1based_inclusive_to_0based_halfopen(): + """A GFF interval s..e (1-based inclusive) labels 0-based positions s-1..e-1; intergenic/UTR + fall out correctly.""" + chrom = "A" * 50 + gm = { + "exons": [(11, 20)], + "introns": [(21, 25)], + "cds": [(11, 15)], + "gi": 0, + "exon_ids": [0], + "intron_ids": [0], + "span": (11, 25), + } + lab = _label_window(chrom, 0, 50, gm, 50)["labels"] + assert list(lab["exon"].nonzero()[0]) == list(range(10, 20)) # 11..20 -> 10..19 + assert list(lab["intron"].nonzero()[0]) == list(range(20, 25)) # 21..25 -> 20..24 + assert list(lab["cds"].nonzero()[0]) == 
list(range(10, 15)) # 11..15 -> 10..14 + assert list(lab["utr"].nonzero()[0]) == list(range(15, 20)) # exon minus cds + assert lab["intergenic"][:10].all() and lab["intergenic"][25:].all() # outside the gene span + assert not lab["intergenic"][10:25].any() + + +def test_build_windows_rejects_multi_record_fasta(tmp_path): + """A multi-record FASTA is rejected (coords would be applied against a concatenated blob).""" + fasta = tmp_path / "f.fa" + fasta.write_text(">chr1\nACGTACGT\n>chr2\nACGTACGT\n") + gff = tmp_path / "g.gff3" + gff.write_text("##gff-version 3\n") + with pytest.raises(ValueError, match="single-chromosome"): + build_windows(str(fasta), str(gff)) diff --git a/interpretability/sparse_autoencoders/recipes/evo2/tests/test_labelers.py b/interpretability/sparse_autoencoders/recipes/evo2/tests/test_labelers.py index e99cf236b2..54daaabf65 100644 --- a/interpretability/sparse_autoencoders/recipes/evo2/tests/test_labelers.py +++ b/interpretability/sparse_autoencoders/recipes/evo2/tests/test_labelers.py @@ -51,3 +51,41 @@ def test_tag_prefix_is_unlabeled(): ctx = _ctx("ATG", tag_len=2) # tokens: [tag, tag, A, T, G] m = LABELERS["motif_ATG"](ctx) assert len(m) == 5 and not m[:2].any() and m[2] # ATG starts at DNA pos 0 -> token 2 + + +def test_gc_window_labelers(): + """gc_high/low fire on GC-rich / AT-rich windows (mean GC over the ±10 window).""" + assert LABELERS["gc_high_window"](_ctx("G" * 30)).all() # all-GC -> high everywhere + assert not LABELERS["gc_high_window"](_ctx("A" * 30)).any() + assert LABELERS["gc_low_window"](_ctx("A" * 30)).all() # all-AT -> low everywhere + + +def test_homopolymer_window(): + """homopolymer_window marks runs of >=5 identical bases (and nothing shorter).""" + m = LABELERS["homopolymer_window"](_ctx("TTAAAAACC")) # AAAAA (5) at positions 2..6 + assert list(m.nonzero()[0]) == [2, 3, 4, 5, 6] + + +def test_dinuc_repeat_window(): + """dinuc_repeat_window marks alternating dinucleotide repeats (>=3 reps, span >=6).""" + assert 
LABELERS["dinuc_repeat_window"](_ctx("ATATATAT")).all() # (AT)x4 + assert not LABELERS["dinuc_repeat_window"](_ctx("ATAT")).any() # (AT)x2 -> too short + + +def test_consensus_offsets_kozak_and_splice(): + """The consensus labelers mark the biologically meaningful offset within the match.""" + assert list(LABELERS["kozak_atg"](_ctx("AAAATGG")).nonzero()[0]) == [3] # [AG]..ATGG -> the ATG + assert list(LABELERS["splice_donor"](_ctx("CCGTAAGTCC")).nonzero()[0]) == [2] # GT[AG]AGT -> the GT + assert list(LABELERS["splice_acceptor"](_ctx("TTTTTTAG")).nonzero()[0]) == [6, 7] # ...AG -> the AG + + +def test_frame_and_start_forward_and_reverse(): + """The CDS frame helper anchors frame 0 at the strand-correct start (the reverse-strand + off-by-one): start = low coord on +strand, high coord on -strand, counting back.""" + from labelers import _frame_and_start + + frame, start = _frame_and_start(10, 19, strand=1) # 9 bp = 3 codons, forward + assert start == 10 and list(frame) == [0, 1, 2, 0, 1, 2, 0, 1, 2] + + frame, start = _frame_and_start(10, 19, strand=-1) # reverse: start at the high coord + assert start == 18 and list(frame) == [2, 1, 0, 2, 1, 0, 2, 1, 0] From 926966b3a78ecbb616df7ea12b045a075b825b84 Mon Sep 17 00:00:00 2001 From: Polina Binder Date: Tue, 23 Jun 2026 17:34:17 +0000 Subject: [PATCH 3/3] ci(evo2-sae): add a CPU lane for the model-agnostic recipe tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit A new ubuntu-latest workflow installs sae + the recipe (CPU torch) and runs the recipe's model-agnostic tests (-m 'not slow') — the label producers (#1630), eval metrics, etc. — so they run cheaply on the probing-stack branches instead of waiting for #1622's megatron GPU lane (which would run them on an L4 after a full build). Registers the 'slow' marker on the recipe pyproject so the GPU tests are excluded without an unknown-marker warning. Validated: pytest tests/ -m 'not slow' -> 16 passed (CPU). 
Co-Authored-By: Claude Opus 4.8 Signed-off-by: Polina Binder --- .../workflows/unit-tests-evo2-recipe-cpu.yaml | 53 +++++++++++++++++++ .../recipes/evo2/pyproject.toml | 7 +++ 2 files changed, 60 insertions(+) create mode 100644 .github/workflows/unit-tests-evo2-recipe-cpu.yaml diff --git a/.github/workflows/unit-tests-evo2-recipe-cpu.yaml b/.github/workflows/unit-tests-evo2-recipe-cpu.yaml new file mode 100644 index 0000000000..3f3a3ff8d1 --- /dev/null +++ b/.github/workflows/unit-tests-evo2-recipe-cpu.yaml @@ -0,0 +1,53 @@ +name: "BioNeMo Evo2 SAE Recipe CI (CPU)" + +# CPU unit tests for the model-agnostic parts of the evo2 SAE recipe (DNA labelers, window +# builders, eval metrics) — no model and no GPU, just CPU torch + numpy + biopython on +# ubuntu-latest. The model-loading GPU tests (@pytest.mark.slow) run in the dedicated megatron +# lane (unit-tests-interpretability-recipes.yaml); here they're excluded via -m "not slow". + +on: + push: + branches: + - "pull-request/[0-9]+" + - "dependabot/**" + paths: + - "interpretability/sparse_autoencoders/recipes/evo2/**" + - "interpretability/sparse_autoencoders/sae/**" + - ".github/workflows/unit-tests-evo2-recipe-cpu.yaml" + merge_group: + types: [checks_requested] + schedule: + - cron: "0 9 * * *" # Runs at 9 AM UTC daily (2 AM MST) + +defaults: + run: + shell: bash -x -e -u -o pipefail {0} + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + +jobs: + recipe-cpu-tests: + runs-on: ubuntu-latest + name: "evo2-recipe-tests (cpu)" + steps: + - name: Checkout repository + uses: actions/checkout@v4 + with: + sparse-checkout: interpretability/sparse_autoencoders + sparse-checkout-cone-mode: false + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.12" + + - name: Install (CPU torch + sae + recipe) + working-directory: interpretability/sparse_autoencoders + run: | + pip install --extra-index-url 
https://download.pytorch.org/whl/cpu -e sae -e recipes/evo2 + + - name: Run model-agnostic recipe tests + working-directory: interpretability/sparse_autoencoders/recipes/evo2 + run: pytest -v tests/ -m "not slow" diff --git a/interpretability/sparse_autoencoders/recipes/evo2/pyproject.toml b/interpretability/sparse_autoencoders/recipes/evo2/pyproject.toml index 047e3e69f0..d7d988b20d 100644 --- a/interpretability/sparse_autoencoders/recipes/evo2/pyproject.toml +++ b/interpretability/sparse_autoencoders/recipes/evo2/pyproject.toml @@ -25,3 +25,10 @@ packages = [] [tool.uv.sources] sae = { workspace = true } + +# Register markers locally so the recipe is self-contained (the repo-root pytest.ini isn't in the +# CI sparse-checkout); lets the CPU lane select `-m "not slow"` without an unknown-marker warning. +[tool.pytest.ini_options] +markers = [ + "slow: GPU/integration tests that load a model (skip without CUDA + checkpoints)", +]