Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions config.default.toml
Original file line number Diff line number Diff line change
Expand Up @@ -199,3 +199,45 @@ column = "text"
dataset = "mlabonne/harmful_behaviors"
split = "test[:100]"
column = "text"

# The parameters used to suggest settings for abliteration. With the
# exception of direction_scope and max_weight, all values are between 0 and 1.
# min_weight is set relative to max_weight, and the other values are relative
# to the number of layers in the model.
# By default, the same ranges are used across all components. For parameters
# that are sampled per component, the ranges are specified for each component.
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Which parameters are not sampled per component?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

direction_scope and direction_fraction (which together determine direction_index).

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah yes, although this itself is something that the user might want to configure. Is that possible with your current implementation?

[parameters]

# The different refusal direction scopes that can be applied to each trial.
direction_scope = [
# Choose a refusal direction by interpolating between 2 layers and apply it globally.
"global",
# For each layer within range, apply the layer's own refusal direction to itself.
"per layer",
]

# For the global direction scope, the layer from which to choose the refusal direction,
# expressed as a fraction of the layer stack (0 = first layer, 1 = last layer).
direction_fraction = { low = 0.4, high = 0.9 }

# The maximum weight with which to apply the abliteration. Set log = true to
# sample from the log space, which will select lower values more frequently.
# Note that low must be greater than 0 when log = true.
[parameters.max_weight]
"attn.o_proj" = { low = 0.8, high = 1.5, log = false }
"mlp.down_proj" = { low = 0.8, high = 1.5, log = false }

# The position (layer) at which the maximum weight should be applied,
# expressed as a fraction of the layer stack (0 = first layer, 1 = last layer).
[parameters.max_weight_position_fraction]
"attn.o_proj" = { low = 0.6, high = 1.0 }
"mlp.down_proj" = { low = 0.6, high = 1.0 }

# The minimum weight as a fraction of the maximum weight.
[parameters.min_weight_relative]
"attn.o_proj" = { low = 0.0, high = 1.0 }
"mlp.down_proj" = { low = 0.0, high = 1.0 }

# The distance from max_weight_position, expressed as a fraction of the layer
# stack, across which the weight drops from max_weight to min_weight. Beyond
# this distance, the weight is set to 0.
[parameters.min_weight_distance_fraction]
"attn.o_proj" = { low = 0.0, high = 0.6 }
"mlp.down_proj" = { low = 0.0, high = 0.6 }
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,9 @@ heretic = "heretic.main:main"
requires = ["uv_build>=0.8.11,<0.9.0"]
build-backend = "uv_build"

[tool.ty.environment]
python-version = "3.10"
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doesn't ty read .python-version? I'm regularly seeing warnings about code that only works in Python >3.10.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, maybe it does! But .python-version is currently set to 3.12, so I guess that doesn't help.

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're right, best to leave this setting then.

I'd love to get rid of Python 3.10, but too many cloud providers still serve it as the default. I guess we'll have to wait for the major libraries like Transformers to move on before we can do the same.


[tool.uv]
exclude-newer = "7 days"

Expand Down
7 changes: 7 additions & 0 deletions src/heretic/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
TomlConfigSettingsSource,
)

from .parameters import ParameterSpecification

# !!!IMPORTANT!!!
#
# Any settings added to the classes defined in this module
Expand Down Expand Up @@ -493,6 +495,11 @@ class Settings(BaseSettings):
description="Dataset of prompts that tend to result in refusals (used for evaluating model performance).",
)

parameters: ParameterSpecification = Field(
default=ParameterSpecification(),
description="The parameter specifications, per parameter or per component within each parameter.",
)

@classmethod
def settings_customise_sources(
cls,
Expand Down
97 changes: 45 additions & 52 deletions src/heretic/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,15 @@ def _is_help_invocation() -> bool:
from optuna.trial import TrialState
from pydantic import ValidationError
from questionary import Choice, Style
from rich.markup import escape
from rich.table import Table
from rich.traceback import install

from .analyzer import Analyzer
from .config import QuantizationMethod
from .evaluator import Evaluator
from .model import AbliterationParameters, Model, get_model_class
from .parameters import Parameters, ParamKind
from .reproduce import collect_reproducibles
from .system import empty_cache, get_accelerator_info
from .utils import (
Expand Down Expand Up @@ -196,7 +198,19 @@ def run():
print(f"[red]Configuration contains [bold]{error.error_count()}[/] errors:[/]")

for error in error.errors():
print(f"[bold]{error['loc'][0]}[/]: [yellow]{error['msg']}[/]")
full_loc = str(error["loc"][0])
for loc in error["loc"][1:]:
if loc in ParamKind:
continue # Skip param kinds added to loc by discriminators.
if not isinstance(loc, str):
full_loc += f"[{loc}]"
elif loc == "[key]":
full_loc += " (key error)"
elif "." in loc:
full_loc += f'["{loc}"]'
else:
full_loc += f".{loc}"
print(f"[bold]{escape(full_loc)}[/]: [yellow]{error['msg']}[/]")

print()
print(
Expand Down Expand Up @@ -322,6 +336,7 @@ def run():
return

model = Model(settings)
params = Parameters(settings.parameters)
print()
print_memory_usage()

Expand Down Expand Up @@ -490,69 +505,47 @@ def objective(trial: Trial) -> tuple[float, float]:
trial_index += 1
trial.set_user_attr("index", trial_index)

direction_scope = trial.suggest_categorical(
"direction_scope",
[
"global",
"per layer",
],
)
direction_scope_param = params.direction_scope.get()
direction_scope = direction_scope_param.suggest(trial)

last_layer_index = len(model.get_layers()) - 1

# Discrimination between "harmful" and "harmless" inputs is usually strongest
# in layers slightly past the midpoint of the layer stack. See the original
# abliteration paper (https://arxiv.org/abs/2406.11717) for a deeper analysis.
#
# Note that we always sample this parameter even though we only need it for
# the "global" direction scope. The reason is that multivariate TPE doesn't
# work with conditional or variable-range parameters.
direction_index = trial.suggest_float(
"direction_index",
0.4 * last_layer_index,
0.9 * last_layer_index,
)

if direction_scope == "per layer":
direction_index = None
# Note that we always sample this parameter when the "global" direction
# scope is included in the choices, even though we only need it for the
# "global" direction scope itself. The reason is that multivariate TPE
# doesn't work with conditional or variable-range parameters.
if "global" in direction_scope_param.choices:
direction_fraction = params.direction_fraction.suggest(trial)
else:
direction_fraction = None

parameters = {}
parameters: dict[str, AbliterationParameters] = {}

for component in model.get_abliterable_components():
# The parameter ranges are based on experiments with various models
# and much wider ranges. They are not set in stone and might have to be
# adjusted for future models.
max_weight = trial.suggest_float(
f"{component}.max_weight",
0.8,
1.5,
)
max_weight_position = trial.suggest_float(
f"{component}.max_weight_position",
0.6 * last_layer_index,
1.0 * last_layer_index,
)
# For sampling purposes, min_weight is expressed as a fraction of max_weight,
# again because multivariate TPE doesn't support variable-range parameters.
# The value is transformed into the actual min_weight value below.
min_weight = trial.suggest_float(
f"{component}.min_weight",
0.0,
1.0,
max_weight = params.max_weight.suggest(trial, component)

max_weight_position_fraction = params.max_weight_position_fraction.suggest(
trial, component
)
min_weight_distance = trial.suggest_float(
f"{component}.min_weight_distance",
1.0,
0.6 * last_layer_index,

min_weight_relative = params.min_weight_relative.suggest(trial, component)

min_weight_distance_fraction = params.min_weight_distance_fraction.suggest(
trial, component
)

parameters[component] = AbliterationParameters(
parameters[component.value] = AbliterationParameters(
max_weight=max_weight,
max_weight_position=max_weight_position,
min_weight=(min_weight * max_weight),
min_weight_distance=min_weight_distance,
max_weight_position=max_weight_position_fraction * last_layer_index,
min_weight=min_weight_relative * max_weight,
min_weight_distance=min_weight_distance_fraction * last_layer_index,
)

if direction_fraction is None or direction_scope != "global":
direction_index = None
else:
direction_index = direction_fraction * last_layer_index

trial.set_user_attr("direction_index", direction_index)
trial.set_user_attr("parameters", {k: asdict(v) for k, v in parameters.items()})

Expand Down
31 changes: 16 additions & 15 deletions src/heretic/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
)

from .config import QuantizationMethod, RowNormalization, Settings
from .parameters import ModelComponent
from .system import empty_cache
from .utils import Prompt, batchify, print

Expand Down Expand Up @@ -155,7 +156,7 @@ def __init__(self, settings: Settings):

print(f"* Transformer model with [bold]{len(self.get_layers())}[/] layers")

all_components = {}
all_components: dict[ModelComponent, int] = {}
for layer_index in range(len(self.get_layers())):
for component, modules in self.get_layer_modules(layer_index).items():
if component not in all_components:
Expand All @@ -164,7 +165,7 @@ def __init__(self, settings: Settings):

print("* Abliterable components:")
for component, count in all_components.items():
print(f" * [bold]{component}[/]: [bold]{count}[/] modules total")
print(f" * [bold]{component.value}[/]: [bold]{count}[/] modules total")

def _apply_lora(self):
# Guard against calling this method at the wrong time.
Expand Down Expand Up @@ -349,12 +350,12 @@ def get_layers(self) -> ModuleList:
# Text-only models.
return model.model.layers

def get_layer_modules(self, layer_index: int) -> dict[str, list[Module]]:
def get_layer_modules(self, layer_index: int) -> dict[ModelComponent, list[Module]]:
layer = self.get_layers()[layer_index]

modules = {}
modules: dict[ModelComponent, list[Module]] = {}

def try_add(component: str, module: Any):
def try_add(component: ModelComponent, module: Any):
# Only add if it's a proper nn.Module (PEFT can wrap these with LoRA)
if isinstance(module, Module):
if component not in modules:
Expand All @@ -363,49 +364,49 @@ def try_add(component: str, module: Any):
else:
# Assert for unexpected types (catches architecture changes)
assert not isinstance(module, Tensor), (
f"Unexpected Tensor in {component} - expected nn.Module"
f"Unexpected Tensor in {component.value} - expected nn.Module"
)

# Standard self-attention out-projection (most models).
with suppress(Exception):
try_add("attn.o_proj", layer.self_attn.o_proj) # ty:ignore[possibly-missing-attribute]
try_add(ModelComponent.ATTN_O_PROJ, layer.self_attn.o_proj) # ty:ignore[possibly-missing-attribute]

# Qwen3.5 MoE hybrid layers use GatedDeltaNet (linear attention) instead of
# standard self-attention, so self_attn.o_proj doesn't exist on those layers.
with suppress(Exception):
try_add("attn.o_proj", layer.linear_attn.out_proj) # ty:ignore[possibly-missing-attribute]
try_add(ModelComponent.ATTN_O_PROJ, layer.linear_attn.out_proj) # ty:ignore[possibly-missing-attribute]

# Most dense models.
with suppress(Exception):
try_add("mlp.down_proj", layer.mlp.down_proj) # ty:ignore[possibly-missing-attribute]
try_add(ModelComponent.MLP_DOWN_PROJ, layer.mlp.down_proj) # ty:ignore[possibly-missing-attribute]

# Some MoE models (e.g. Qwen3).
with suppress(Exception):
for expert in layer.mlp.experts: # ty:ignore[possibly-missing-attribute, not-iterable]
try_add("mlp.down_proj", expert.down_proj) # ty:ignore[possibly-missing-attribute]
try_add(ModelComponent.MLP_DOWN_PROJ, expert.down_proj) # ty:ignore[possibly-missing-attribute]

# Phi-3.5-MoE (and possibly others).
with suppress(Exception):
for expert in layer.block_sparse_moe.experts: # ty:ignore[possibly-missing-attribute, not-iterable]
try_add("mlp.down_proj", expert.w2) # ty:ignore[possibly-missing-attribute]
try_add(ModelComponent.MLP_DOWN_PROJ, expert.w2) # ty:ignore[possibly-missing-attribute]

# Granite MoE Hybrid - attention layers with shared_mlp.
with suppress(Exception):
try_add("mlp.down_proj", layer.shared_mlp.output_linear) # ty:ignore[possibly-missing-attribute]
try_add(ModelComponent.MLP_DOWN_PROJ, layer.shared_mlp.output_linear) # ty:ignore[possibly-missing-attribute]

# Granite MoE Hybrid - MoE layers with experts.
with suppress(Exception):
for expert in layer.moe.experts: # ty:ignore[possibly-missing-attribute, not-iterable]
try_add("mlp.down_proj", expert.output_linear) # ty:ignore[possibly-missing-attribute]
try_add(ModelComponent.MLP_DOWN_PROJ, expert.output_linear) # ty:ignore[possibly-missing-attribute]

# We need at least one module across all components for abliteration to work.
total_modules = sum(len(mods) for mods in modules.values())
assert total_modules > 0, "No abliterable modules found in layer"

return modules

def get_abliterable_components(self) -> list[str]:
components: set[str] = set()
def get_abliterable_components(self) -> list[ModelComponent]:
components: set[ModelComponent] = set()

# Scan all layers because hybrid models (e.g. Qwen3.5 MoE) have different
# components on different layers (some have self_attn, others linear_attn).
Expand Down
Loading