diff --git a/src/heretic/config.py b/src/heretic/config.py
index 77de3c59..b5bdb6bf 100644
--- a/src/heretic/config.py
+++ b/src/heretic/config.py
@@ -203,6 +203,18 @@ class Settings(BaseSettings):
             "instead of traditional directional ablation."
         ),
     )
+
+    use_ara_lora: bool = Field(
+        default=False,
+        description=(
+            "Use LoRA in ARA instead of full-weight editing. Makes it compatible with quantization and removes model reloads."
+        ),
+    )
+
+    ara_lora_rank: int = Field(
+        default=128,
+        description="If LoRA is used in ARA, this sets up its rank. Keep it high enough to simulate the 'arbitrary' effect.",
+    )
 
     use_piqa: bool = Field(
         default=False,
diff --git a/src/heretic/main.py b/src/heretic/main.py
index 0abd11d7..b633484b 100644
--- a/src/heretic/main.py
+++ b/src/heretic/main.py
@@ -622,7 +622,16 @@ def objective(trial: Trial) -> tuple[float, float]:
         print("* Parameters:")
         for name, value in get_trial_parameters(settings, trial).items():
             print(f" * {name} = [bold]{value}[/]")
-        if settings.use_ara:
+        if settings.use_ara_lora:
+            print("* Resetting model...")
+            model.reset_model()
+            print("* Abliterating (Arbitrary-Rank Ablation with LoRA)...")
+            model.ara_lora_abliterate(
+                good_module_io,
+                bad_module_io,
+                ARAParameters(**trial.user_attrs["ara_parameters"]),
+            )
+        elif settings.use_ara:
             print("* Reloading model...")
             model.reset_model()
             print("* Abliterating (Arbitrary-Rank Ablation)...")
@@ -810,7 +819,16 @@ def count_completed_trials() -> int:
             print("* Parameters:")
             for name, value in get_trial_parameters(settings, trial).items():
                 print(f" * {name} = [bold]{value}[/]")
-            if settings.use_ara:
+            if settings.use_ara_lora:
+                print("* Resetting model...")
+                model.reset_model()
+                print("* Abliterating (Arbitrary-Rank Ablation with LoRA)...")
+                model.ara_lora_abliterate(
+                    good_module_io,
+                    bad_module_io,
+                    ARAParameters(**trial.user_attrs["ara_parameters"]),
+                )
+            elif settings.use_ara:
                 print("* Reloading model...")
                 model.reset_model()
                 print("* Abliterating (Arbitrary-Rank Ablation)...")
diff --git a/src/heretic/model.py b/src/heretic/model.py
index 108e9dfd..53130389 100644
--- a/src/heretic/model.py
+++ b/src/heretic/model.py
@@ -161,7 +161,7 @@ def __init__(self, settings: Settings):
         if self.model is None:
             raise Exception("Failed to load model with all configured dtypes.")
 
-        if not settings.use_ara:
+        if not settings.use_ara or settings.use_ara_lora:
             self._apply_lora()
 
         # LoRA B matrices are initialized to zero by default in PEFT,
@@ -188,21 +188,24 @@ def _apply_lora(self):
         # because hybrid models like Qwen3.5 MoE have modules with different names
         # across layers (e.g. "o_proj" on attention layers, "out_proj" on linear attention layers).
         target_modules_set: set[str] = set()
+
+        module_id_to_full_name = {
+            id(module): module_name
+            for module_name, module in self.model.named_modules()
+        }
 
-        for layer_index, layer in enumerate(self.get_layers()):
-            module_id_to_leaf_name = {
-                id(module): module_name.split(".")[-1]
-                for module_name, module in layer.named_modules()
-            }
-
+        for layer_index in range(len(self.get_layers())):
             for modules in self.get_layer_modules(layer_index).values():
                 for module in modules:
-                    if id(module) in module_id_to_leaf_name:
-                        target_modules_set.add(module_id_to_leaf_name[id(module)])
+                    full_name = module_id_to_full_name.get(id(module))
+                    if full_name is not None:
+                        target_modules_set.add(full_name)
 
-        target_modules = list(target_modules_set)
+        target_modules = sorted(target_modules_set)
 
-        if self.settings.row_normalization != RowNormalization.FULL:
+        if self.settings.use_ara_lora:
+            lora_rank = self.settings.ara_lora_rank
+        elif self.settings.row_normalization != RowNormalization.FULL:
             # Rank 1 is sufficient for directional ablation without renormalization.
             lora_rank = 1
         else:
@@ -224,7 +227,10 @@ def _apply_lora(self):
         # so the result is a PeftModel rather than a PeftMixedModel.
         self.model = cast(PeftModel, get_peft_model(self.model, self.peft_config))
 
-        print(f"* LoRA adapters initialized (targets: {', '.join(target_modules)})")
+        display_targets = sorted({name.rsplit(".", 1)[-1] for name in target_modules})
+        print(
+            f"* LoRA adapters initialized (target types: {', '.join(display_targets)})"
+        )
 
     def _get_quantization_config(self, dtype: str) -> BitsAndBytesConfig | None:
         """
@@ -311,7 +317,7 @@ def reset_model(self):
         if (
             current_model == self.settings.model
             and not self.needs_reload
-            and not self.settings.use_ara
+            and (not self.settings.use_ara or self.settings.use_ara_lora)
         ):
             # Reset LoRA adapters to zero (identity transformation)
             for name, module in self.model.named_modules():
@@ -341,7 +347,7 @@ def reset_model(self):
             **extra_kwargs,
         )
 
-        if not self.settings.use_ara:
+        if not self.settings.use_ara or self.settings.use_ara_lora:
             self._apply_lora()
 
         self.needs_reload = False
@@ -668,6 +674,137 @@ def closure() -> Tensor:
             with torch.no_grad():
                 matrix.copy_(get_matrix())
 
+    def ara_lora_abliterate(
+        self,
+        good_module_io: ModuleIO,
+        bad_module_io: ModuleIO,
+        parameters: ARAParameters,
+    ):
+        """
+        Performs Arbitrary-Rank Ablation by optimizing the LoRA adapters in place.
+
+        For every module in the layer range given by `parameters`, the adapter
+        matrices A and B are fitted with L-BFGS so that the effective weight
+        keeps the recorded "good" outputs while steering the "bad" inputs.
+        The base weights (which may be quantized) are only read, never written.
+        """
+        for layer_index in range(
+            parameters.start_layer_index,
+            parameters.end_layer_index,
+        ):
+            for component, modules in self.get_layer_modules(layer_index).items():
+                for module_index, module in enumerate(modules):
+                    # Cast to Linear to access weights and LoRA adapters.
+                    module = cast(Linear, module)
+
+                    # Base weight handling and dequantization.
+                    # We need the base weight in float32 to compute the effective weight.
+                    base_weight = cast(Tensor, module.base_layer.weight)
+                    quant_state = getattr(base_weight, "quant_state", None)
+
+                    if quant_state is None:
+                        W_base = base_weight.to(torch.float32)
+                    else:
+                        # Maintain the original dequantization logic for bitsandbytes.
+                        W_base = cast(
+                            Tensor,
+                            bnb.functional.dequantize_4bit(
+                                base_weight.data,
+                                quant_state,
+                            ).to(torch.float32),
+                        )
+
+                    # Row normalization setup.
+                    # Pre-calculate the original row norms to preserve them.
+                    # This implements the RowNormalization.FULL logic.
+                    W_row_norms = LA.vector_norm(W_base, dim=1, keepdim=True).detach()
+
+                    # Adapter target identification.
+                    # We optimize the LoRA weights A and B.
+                    lora_A = cast(Tensor, module.lora_A["default"].weight)
+                    lora_B = cast(Tensor, module.lora_B["default"].weight)
+
+                    # PEFT multiplies the adapter output by this factor in the forward
+                    # pass, so the objective must apply it as well to match the model.
+                    lora_scaling = module.scaling["default"]
+
+                    # Data preparation.
+                    # Move I/O tensors to the device of the adapter weights.
+                    good_input, good_output = good_module_io[layer_index][component][module_index]
+                    bad_input, bad_output = bad_module_io[layer_index][component][module_index]
+
+                    good_input = good_input.float().to(lora_A.device)
+                    good_output = good_output.float().to(lora_A.device)
+                    bad_input = bad_input.float().to(lora_A.device)
+                    bad_output = bad_output.float().to(lora_A.device)
+
+                    # The objective function.
+                    def objective(A: Tensor, B: Tensor) -> Tensor:
+                        # Calculate effective weight: W_eff = W_base + scaling * (B @ A).
+                        W_eff = W_base + (B @ A) * lora_scaling
+
+                        # Apply Row Normalization (keep original norms).
+                        # NOTE(review): this normalization exists only inside the objective; the
+                        # adapter that PEFT applies at inference is the plain W_base + B @ A, so
+                        # with RowNormalization.FULL the deployed weights differ from W_eff here.
+                        if self.settings.row_normalization == RowNormalization.FULL:
+                            # Normalize to unit length, then scale by original norms.
+                            W_eff = F.normalize(W_eff, p=2, dim=1) * W_row_norms
+
+                        # Compute outputs using the effective weight.
+                        new_good_output = good_input @ W_eff.T
+                        new_bad_output = bad_input @ W_eff.T
+
+                        # The original ARA loss function.
+                        preserve_good_behavior = (
+                            (new_good_output - good_output) ** 2
+                        ).mean()
+
+                        steer_bad_behavior = (
+                            mean_distances_to_knn(
+                                new_bad_output,
+                                good_output,
+                                parameters.neighbor_count,
+                            ).mean()
+                            + parameters.overcorrect_relative_weight
+                            * -mean_distances_to_knn(
+                                new_bad_output,
+                                bad_output,
+                                parameters.neighbor_count,
+                            ).mean()
+                        )
+
+                        return (
+                            parameters.preserve_good_behavior_weight
+                            * preserve_good_behavior
+                            + parameters.steer_bad_behavior_weight * steer_bad_behavior
+                        )
+
+                    # Optimization loop.
+                    # We optimize A and B, not the base matrix.
+                    optimizer = LBFGS(
+                        [lora_A, lora_B],
+                        lr=1.0,
+                        max_iter=20,
+                        history_size=10,
+                        line_search_fn="strong_wolfe",
+                    )
+
+                    def closure() -> Tensor:
+                        optimizer.zero_grad()
+                        # Pass the actual tensors being optimized to the objective.
+                        loss = objective(lora_A, lora_B)
+                        loss.backward()
+                        return loss
+
+                    # Run optimization steps.
+                    for _ in range(5):
+                        optimizer.step(closure)
+
+                    # Drop the gradient buffers left behind by the last backward pass
+                    # so they are not kept alive for every optimized module.
+                    optimizer.zero_grad(set_to_none=True)
+
     def generate(
         self,
         prompts: list[Prompt],