Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions src/heretic/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,18 @@ class Settings(BaseSettings):
"instead of traditional directional ablation."
),
)

use_ara_lora: bool = Field(
default=False,
description=(
"Use LoRA in ARA instead of full-weight editing. Makes it compatible with quantization and removes model reloads."
),
)

ara_lora_rank: int = Field(
default=128,
description="If LoRA is used in ARA, this sets up its rank. Keep it high enough to simulate the 'arbitrary' effect",
)
Comment thread
kabachuha marked this conversation as resolved.

use_piqa: bool = Field(
default=False,
Expand Down
22 changes: 20 additions & 2 deletions src/heretic/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -622,7 +622,16 @@ def objective(trial: Trial) -> tuple[float, float]:
print("* Parameters:")
for name, value in get_trial_parameters(settings, trial).items():
print(f" * {name} = [bold]{value}[/]")
if settings.use_ara:
if settings.use_ara_lora:
print("* Reloading model...")
model.reset_model()
print("* Abliterating (Arbitrary-Rank Ablation with LoRA)...")
Comment thread
kabachuha marked this conversation as resolved.
model.ara_lora_abliterate(
good_module_io,
bad_module_io,
ARAParameters(**trial.user_attrs["ara_parameters"]),
)
elif settings.use_ara:
print("* Reloading model...")
model.reset_model()
print("* Abliterating (Arbitrary-Rank Ablation)...")
Expand Down Expand Up @@ -810,7 +819,16 @@ def count_completed_trials() -> int:
print("* Parameters:")
for name, value in get_trial_parameters(settings, trial).items():
print(f" * {name} = [bold]{value}[/]")
if settings.use_ara:
if settings.use_ara_lora:
print("* Resetting model...")
model.reset_model()
print("* Abliterating (Arbitrary-Rank Ablation with LoRA)...")
model.ara_lora_abliterate(
good_module_io,
bad_module_io,
ARAParameters(**trial.user_attrs["ara_parameters"]),
)
elif settings.use_ara:
print("* Reloading model...")
model.reset_model()
print("* Abliterating (Arbitrary-Rank Ablation)...")
Expand Down
146 changes: 132 additions & 14 deletions src/heretic/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def __init__(self, settings: Settings):
if self.model is None:
raise Exception("Failed to load model with all configured dtypes.")

if not settings.use_ara:
if not settings.use_ara or settings.use_ara_lora:
self._apply_lora()

# LoRA B matrices are initialized to zero by default in PEFT,
Expand All @@ -188,21 +188,24 @@ def _apply_lora(self):
# because hybrid models like Qwen3.5 MoE have modules with different names
# across layers (e.g. "o_proj" on attention layers, "out_proj" on linear attention layers).
target_modules_set: set[str] = set()

module_id_to_full_name = {
id(module): module_name
for module_name, module in self.model.named_modules()
}

for layer_index, layer in enumerate(self.get_layers()):
module_id_to_leaf_name = {
id(module): module_name.split(".")[-1]
for module_name, module in layer.named_modules()
}

for layer_index in range(len(self.get_layers())):
for modules in self.get_layer_modules(layer_index).values():
for module in modules:
if id(module) in module_id_to_leaf_name:
target_modules_set.add(module_id_to_leaf_name[id(module)])
full_name = module_id_to_full_name.get(id(module))
if full_name is not None:
target_modules_set.add(full_name)

target_modules = list(target_modules_set)
target_modules = sorted(target_modules_set)

if self.settings.row_normalization != RowNormalization.FULL:
if self.settings.use_ara_lora:
lora_rank = self.settings.ara_lora_rank
elif self.settings.row_normalization != RowNormalization.FULL:
# Rank 1 is sufficient for directional ablation without renormalization.
lora_rank = 1
else:
Expand All @@ -224,7 +227,10 @@ def _apply_lora(self):
# so the result is a PeftModel rather than a PeftMixedModel.
self.model = cast(PeftModel, get_peft_model(self.model, self.peft_config))

print(f"* LoRA adapters initialized (targets: {', '.join(target_modules)})")
display_targets = sorted({name.rsplit(".", 1)[-1] for name in target_modules})
print(
f"* LoRA adapters initialized (target types: {', '.join(display_targets)})"
)

def _get_quantization_config(self, dtype: str) -> BitsAndBytesConfig | None:
"""
Expand Down Expand Up @@ -311,7 +317,7 @@ def reset_model(self):
if (
current_model == self.settings.model
and not self.needs_reload
and not self.settings.use_ara
and (not self.settings.use_ara or self.settings.use_ara_lora)
):
# Reset LoRA adapters to zero (identity transformation)
for name, module in self.model.named_modules():
Expand Down Expand Up @@ -341,7 +347,7 @@ def reset_model(self):
**extra_kwargs,
)

if not self.settings.use_ara:
if not self.settings.use_ara or self.settings.use_ara_lora:
self._apply_lora()

self.needs_reload = False
Expand Down Expand Up @@ -668,6 +674,118 @@ def closure() -> Tensor:
with torch.no_grad():
matrix.copy_(get_matrix())

def ara_lora_abliterate(
    self,
    good_module_io: ModuleIO,
    bad_module_io: ModuleIO,
    parameters: ARAParameters,
):
    """
    Run Arbitrary-Rank Ablation by optimizing LoRA adapter weights in place.

    For every module in layers ``[start_layer_index, end_layer_index)`` the
    "default" adapter's ``lora_A``/``lora_B`` weights are optimized with L-BFGS
    so that the effective weight ``W_base + scaling * (B @ A)``:

    * reproduces the recorded outputs for the "good" inputs (mean squared
      error term), and
    * moves the outputs for the "bad" inputs according to the
      ``mean_distances_to_knn`` steering term (towards ``good_output``, and,
      weighted by ``overcorrect_relative_weight``, away from ``bad_output``).

    The base weight is never modified; quantized (bitsandbytes 4-bit) base
    weights are dequantized only to evaluate the objective.

    Args:
        good_module_io: Recorded (input, output) pairs, indexed as
            ``[layer_index][component][module_index]``, for the behavior
            to preserve.
        bad_module_io: Same structure, for the behavior to steer.
        parameters: Layer range, neighbor count and loss weights.
    """
    normalize_rows = self.settings.row_normalization == RowNormalization.FULL

    for layer_index in range(
        parameters.start_layer_index,
        parameters.end_layer_index,
    ):
        for component, modules in self.get_layer_modules(layer_index).items():
            for module_index, module in enumerate(modules):
                # Cast to the PEFT LoRA Linear wrapper to access
                # the base layer and the adapter weights.
                module = cast(Linear, module)

                # The adapter weights are the only tensors being optimized.
                lora_A = cast(Tensor, module.lora_A["default"].weight)
                lora_B = cast(Tensor, module.lora_B["default"].weight)
                device = lora_A.device

                # PEFT multiplies the adapter product by this factor in its
                # forward pass, so the objective must include it to describe
                # the weight the model will actually use.
                scaling = module.scaling["default"]

                # The objective needs the base weight in float32;
                # quantized weights have to be dequantized first.
                base_weight = cast(Tensor, module.base_layer.weight)
                quant_state = getattr(base_weight, "quant_state", None)

                if quant_state is None:
                    W_base = base_weight.to(device=device, dtype=torch.float32)
                else:
                    W_base = cast(
                        Tensor,
                        bnb.functional.dequantize_4bit(
                            base_weight.data,
                            quant_state,
                        ),
                    ).to(device=device, dtype=torch.float32)

                # Original row norms are only needed when they must be
                # preserved (RowNormalization.FULL).
                # NOTE(review): the adapter itself adds scaling * (B @ A)
                # without re-normalizing rows, so in this mode the objective
                # may describe a different matrix than the one used at
                # inference - confirm this is intended.
                W_row_norms = (
                    LA.vector_norm(W_base, dim=1, keepdim=True)
                    if normalize_rows
                    else None
                )

                # Keep all objective operands in float32 on the adapter's device.
                good_input, good_output = good_module_io[layer_index][component][
                    module_index
                ]
                bad_input, bad_output = bad_module_io[layer_index][component][
                    module_index
                ]

                good_input = good_input.float().to(device)
                good_output = good_output.float().to(device)
                bad_input = bad_input.float().to(device)
                bad_output = bad_output.float().to(device)

                def objective(A: Tensor, B: Tensor) -> Tensor:
                    # Effective weight as seen by the adapted layer.
                    W_eff = W_base + scaling * (B @ A)

                    if W_row_norms is not None:
                        # Normalize to unit length, then restore original norms.
                        W_eff = F.normalize(W_eff, p=2, dim=1) * W_row_norms

                    new_good_output = good_input @ W_eff.T
                    new_bad_output = bad_input @ W_eff.T

                    preserve_good_behavior = (
                        (new_good_output - good_output) ** 2
                    ).mean()

                    steer_bad_behavior = (
                        mean_distances_to_knn(
                            new_bad_output,
                            good_output,
                            parameters.neighbor_count,
                        ).mean()
                        + parameters.overcorrect_relative_weight
                        * -mean_distances_to_knn(
                            new_bad_output,
                            bad_output,
                            parameters.neighbor_count,
                        ).mean()
                    )

                    return (
                        parameters.preserve_good_behavior_weight
                        * preserve_good_behavior
                        + parameters.steer_bad_behavior_weight * steer_bad_behavior
                    )

                optimizer = LBFGS(
                    [lora_A, lora_B],
                    lr=1.0,
                    max_iter=20,
                    history_size=10,
                    line_search_fn="strong_wolfe",
                )

                def closure() -> Tensor:
                    optimizer.zero_grad()
                    loss = objective(lora_A, lora_B)
                    loss.backward()
                    return loss

                # Each step() call runs up to max_iter L-BFGS iterations.
                for _ in range(5):
                    optimizer.step(closure)

Comment thread
kabachuha marked this conversation as resolved.
def generate(
self,
prompts: list[Prompt],
Expand Down
Loading