Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions brainscore_vision/models/clip_vitb32_marrenj/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from brainscore_vision import model_registry
from brainscore_vision.model_helpers.brain_transformation import ModelCommitment
from .model import get_model, LAYERS, BEHAVIORAL_READOUT_LAYER

# No region_layer_map — brain-score's search picks the best visual.transformer
# block per region. behavioral_readout_layer is pinned to ln_post (the
# pre-projection 768-d CLS feature, matching our own alignment metric).
model_registry['clip_vitb32_marrenj'] = lambda: ModelCommitment(
identifier='clip_vitb32_marrenj',
activations_model=get_model(),
layers=LAYERS,
behavioral_readout_layer=BEHAVIORAL_READOUT_LAYER,
)
145 changes: 145 additions & 0 deletions brainscore_vision/models/clip_vitb32_marrenj/clip_arch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
"""Self-contained vision tower for our DeCLIP-trained CLIP-ViT-B/32, bundled
inside the brain-score submission plugin so the CI sandbox doesn't need to clone
our research repo. Trimmed from src/model.py to JUST the VisionTransformer +
direct dependencies (Attention, TransformerBlock, LayerNorm, QuickGELU). The
text encoder and full CLIP wrapper are dropped — brain-score only needs visual
feature extraction.

State dict loading: our Lightning checkpoint has keys like
`model.visual.conv1.weight`, `model.text.*`, `model.logit_scale`. The plugin's
model.py strips the `model.visual.` prefix and filters to visual-only keys
before calling `VisionTransformer.load_state_dict`.
"""
import math
from typing import Callable, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class QuickGELU(nn.Module):
def forward(self, x):
return x * torch.sigmoid(1.702 * x)


class LayerNorm(nn.LayerNorm):
"""Standard nn.LayerNorm — kept under its original name so state_dict keys match."""
def forward(self, x):
orig_type = x.dtype
out = super().forward(x)
return out.to(orig_type)


class Attention(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
super().__init__()
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)

def forward(self, x):
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0)
if hasattr(F, 'scaled_dot_product_attention'):
x = F.scaled_dot_product_attention(q, k, v, attn_mask=None,
dropout_p=self.attn_drop.p if self.training else 0.0,
is_causal=False)
x = x.transpose(1, 2).reshape(B, N, C)
else:
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x


class TransformerBlock(nn.Module):
def __init__(self, dim, num_heads, mlp_ratio=4.0, qkv_bias=False, drop=0.,
attn_drop=0., ls_init_value: Optional[float] = None,
act_layer: Callable = nn.GELU, norm_layer: Callable = LayerNorm):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias,
attn_drop=attn_drop, proj_drop=drop)
self.ls_init_value = ls_init_value
if ls_init_value is not None:
self.gamma_1 = nn.Parameter(ls_init_value * torch.ones(dim))
self.gamma_2 = nn.Parameter(ls_init_value * torch.ones(dim))
else:
self.gamma_1 = None
self.gamma_2 = None
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = nn.Sequential(
nn.Linear(dim, mlp_hidden_dim),
act_layer(),
nn.Dropout(drop),
nn.Linear(mlp_hidden_dim, dim),
nn.Dropout(drop),
)

def forward(self, x):
if self.gamma_1 is None:
x = x + self.attn(self.norm1(x))
x = x + self.mlp(self.norm2(x))
else:
x = x + self.gamma_1 * self.attn(self.norm1(x))
x = x + self.gamma_2 * self.mlp(self.norm2(x))
return x


class VisionTransformer(nn.Module):
"""ViT-B/32 visual encoder for CLIP-style models. Architectural defaults are
hardcoded to match VIT_B_32_CONFIG from our codebase (image_size=224,
patch_size=32, width=768, layers=12, heads=12, mlp_ratio=4.0,
output_dim=512, no layer scale)."""
def __init__(self, image_size=224, patch_size=32, width=768, layers=12,
heads=12, mlp_ratio=4.0, output_dim=512,
ls_init_value: Optional[float] = None,
act_layer: Callable = nn.GELU,
norm_layer: Callable = LayerNorm):
super().__init__()
self.image_size = image_size
self.patch_size = patch_size
self.width = width
self.output_dim = output_dim
self.grid_size = (image_size // patch_size, image_size // patch_size)
self.num_patches = self.grid_size[0] * self.grid_size[1]
self.conv1 = nn.Conv2d(3, width, kernel_size=patch_size, stride=patch_size, bias=False)
self.class_embedding = nn.Parameter(torch.randn(width))
self.positional_embedding = nn.Parameter(torch.randn(self.num_patches + 1, width) * 0.01)
self.ln_pre = norm_layer(width)
self.transformer = nn.Sequential(*[
TransformerBlock(dim=width, num_heads=heads, mlp_ratio=mlp_ratio,
qkv_bias=True, ls_init_value=ls_init_value,
act_layer=act_layer, norm_layer=norm_layer)
for _ in range(layers)
])
self.ln_post = norm_layer(width)
self.proj = nn.Parameter(torch.randn(width, output_dim) * (1 / width ** 0.5))

def forward(self, x):
x = self.conv1(x)
x = x.reshape(x.shape[0], x.shape[1], -1).permute(0, 2, 1)
x = torch.cat([
self.class_embedding.to(x.dtype) + torch.zeros(
x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device,
),
x,
], dim=1)
x = x + self.positional_embedding.to(x.dtype)
x = self.ln_pre(x)
for block in self.transformer:
x = block(x)
x = x[:, 0]
x = self.ln_post(x)
x = x @ self.proj
return x
116 changes: 116 additions & 0 deletions brainscore_vision/models/clip_vitb32_marrenj/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
"""CLIP-ViT-B/32 Vanderbilt — final checkpoint of our DeCLIP/YFCC15M-trained
contrastive baseline (epoch 31, 32-epoch budget). The model is our own custom
class (not OpenCLIP); we bundle a trimmed visual-only version under
`clip_arch.py` so the brain-score sandbox doesn't need our research repo.

Brain-Score coverage NOTE: behavioral benchmarks that decode from a 1000-class
ImageNet logits layer (e.g. Geirhos2021-*) will not work for this model because
contrastive CLIP doesn't have a native ImageNet classifier. To run those, we'd
need a zero-shot CLIP classifier head (cosine sim between visual features and
text embeddings of the 1000 ImageNet class names). For this first submission
we leave that out — neural V1/V2/V4/IT + Rajalingham2018-i2n behavioral all
work directly off the visual encoder's features.

We use CLIP-style preprocessing (Resize(224) → CenterCrop(224) → CLIP mean/std)
to match how the model was trained.
"""
import functools
import numpy as np
import torch
from PIL import Image
from torchvision import transforms as T
from huggingface_hub import hf_hub_download
from brainscore_vision.model_helpers.activations.pytorch import PytorchWrapper

from .clip_arch import VisionTransformer

# === EDIT BEFORE SUBMITTING ===
HF_REPO_ID = "marrenj/temporal-dynamics-baselines"
HF_FILENAME = "clip_vitb32_baseline_ep031.ckpt"

# OpenAI CLIP normalization — what the visual encoder was trained with.
CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
CLIP_STD = (0.26862954, 0.26130258, 0.27577711)

# Candidate visual-encoder layers for brain-score's region commitment search.
# `ln_post` is the pre-projection 768-d CLS feature we already use as the
# behavioral readout in our own alignment metric.
LAYERS = [
"transformer.0",
"transformer.3",
"transformer.6",
"transformer.9",
"transformer.11",
"ln_post",
]
BEHAVIORAL_READOUT_LAYER = "ln_post"


BIBTEX = """@misc{marrenj_temporal_dynamics_2026,
title={Temporal Dynamics of Human Behavioral Alignment in ImageNet-trained Models},
author={Wallace Lab},
year={2026},
note={CLIP-ViT-B/32 (custom DeCLIP-trained), YFCC15M, 32 epochs},
}"""


def _clip_preprocessing(image_filepaths):
"""CLIP-style preprocessing pipeline that matches our model's training
preprocessing (src/dataset.py eval branch). Returns a (B, C, 224, 224)
numpy stack — brain-score's PytorchWrapper expects this exact shape."""
val_transform = T.Compose([
T.Resize(224, interpolation=T.InterpolationMode.BICUBIC),
T.CenterCrop(224),
T.ToTensor(),
T.Normalize(CLIP_MEAN, CLIP_STD),
])
out = []
for p in image_filepaths:
img = Image.open(p).convert("RGB")
out.append(val_transform(img).numpy())
return np.stack(out)


def get_model():
weights_path = hf_hub_download(repo_id=HF_REPO_ID, filename=HF_FILENAME)
raw = torch.load(weights_path, map_location="cpu", weights_only=True)
state_dict = raw.get("state_dict", raw)
# Lightning prefix: model.visual.<...> | model.text.* | model.logit_scale
# We need: <...> for just the visual encoder.
visual_sd = {
k[len("model.visual."):]: v
for k, v in state_dict.items()
if k.startswith("model.visual.")
}
if not visual_sd:
# In case the ckpt isn't Lightning-wrapped: try `visual.<...>` direct.
visual_sd = {
k[len("visual."):]: v for k, v in state_dict.items()
if k.startswith("visual.")
}
if not visual_sd:
raise RuntimeError(
f"could not find any visual-encoder keys in checkpoint {HF_FILENAME}. "
"Expected keys prefixed with 'model.visual.' or 'visual.'."
)

model = VisionTransformer() # uses ViT-B/32 defaults
missing, unexpected = model.load_state_dict(visual_sd, strict=False)
if missing or unexpected:
print(f" [clip_vitb32_marrenj] state_dict load: "
f"missing={len(missing)}, unexpected={len(unexpected)}")
if missing[:3]: print(f" sample missing: {missing[:3]}")
if unexpected[:3]: print(f" sample unexpected: {unexpected[:3]}")
model.eval()

wrapper = PytorchWrapper(
identifier="clip_vitb32_marrenj",
model=model,
preprocessing=_clip_preprocessing,
)
wrapper.image_size = 224
return wrapper


def get_bibtex(model_identifier):
return BIBTEX
5 changes: 5 additions & 0 deletions brainscore_vision/models/clip_vitb32_marrenj/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
torch==2.2.1
torchvision==0.17.1
huggingface_hub>=0.25
numpy
pillow
9 changes: 9 additions & 0 deletions brainscore_vision/models/clip_vitb32_marrenj/test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
"""Minimum sanity test."""
import pytest
import brainscore_vision


@pytest.mark.private_access
def test_has_identifier():
model = brainscore_vision.load_model('clip_vitb32_marrenj')
assert model.identifier == 'clip_vitb32_marrenj'
Loading