From 904edc575a018c3bb4a6ad96b2991a1f39e44c4c Mon Sep 17 00:00:00 2001 From: kabachuha Date: Mon, 25 May 2026 22:26:40 +0300 Subject: [PATCH 1/3] ARA, but it's LoRA --- src/heretic/config.py | 12 ++++ src/heretic/main.py | 22 ++++++- src/heretic/model.py | 146 ++++++++++++++++++++++++++++++++++++++---- 3 files changed, 164 insertions(+), 16 deletions(-) diff --git a/src/heretic/config.py b/src/heretic/config.py index 77de3c59..3f0e44db 100644 --- a/src/heretic/config.py +++ b/src/heretic/config.py @@ -203,6 +203,18 @@ class Settings(BaseSettings): "instead of traditional directional ablation." ), ) + + use_ara_lora: bool = Field( + default=False, + description=( + "Use LoRA in ARA instead of full-weight editing. Makes it compatible with quantization and removes model reloads." + ), + ) + + ara_lora_rank: int = Field( + default=128, + description="If LoRA is used in ARA, this sets up its rank. Keep it high enough to simulate the 'arbitrary' effect", + ) use_piqa: bool = Field( default=False, diff --git a/src/heretic/main.py b/src/heretic/main.py index 0abd11d7..ba8e4e0b 100644 --- a/src/heretic/main.py +++ b/src/heretic/main.py @@ -622,7 +622,16 @@ def objective(trial: Trial) -> tuple[float, float]: print("* Parameters:") for name, value in get_trial_parameters(settings, trial).items(): print(f" * {name} = [bold]{value}[/]") - if settings.use_ara: + if settings.use_ara_lora: + print("* Reloading model...") + model.reset_model() + print("* Abliterating (Arbitrary-Rank Ablation with LoRA)...") + model.ara_lora_abliterate( + good_module_io, + bad_module_io, + ARAParameters(**trial.user_attrs["ara_parameters"]), + ) + elif settings.use_ara: print("* Reloading model...") model.reset_model() print("* Abliterating (Arbitrary-Rank Ablation)...") @@ -810,7 +819,16 @@ def count_completed_trials() -> int: print("* Parameters:") for name, value in get_trial_parameters(settings, trial).items(): print(f" * {name} = [bold]{value}[/]") - if settings.use_ara: + if settings.use_ara_lora: + print("* Resetting model...") + model.reset_model() + print("* Abliterating (Arbitrary-Rank Ablation with LoRA)...") + model.ara_lora_abliterate( + good_module_io, + bad_module_io, + ARAParameters(**trial.user_attrs["ara_parameters"]), + ) + elif settings.use_ara: print("* Reloading model...") model.reset_model() print("* Abliterating (Arbitrary-Rank Ablation)...") diff --git a/src/heretic/model.py b/src/heretic/model.py index 108e9dfd..338a9abf 100644 --- a/src/heretic/model.py +++ b/src/heretic/model.py @@ -161,7 +161,7 @@ def __init__(self, settings: Settings): if self.model is None: raise Exception("Failed to load model with all configured dtypes.") - if not settings.use_ara: + if not settings.use_ara or settings.use_ara_lora: self._apply_lora() # LoRA B matrices are initialized to zero by default in PEFT, @@ -188,21 +188,24 @@ def _apply_lora(self): # because hybrid models like Qwen3.5 MoE have modules with different names # across layers (e.g. "o_proj" on attention layers, "out_proj" on linear attention layers). target_modules_set: set[str] = set() + + module_id_to_full_name = { + id(module): module_name + for module_name, module in self.model.named_modules() + } - for layer_index, layer in enumerate(self.get_layers()): - module_id_to_leaf_name = { - id(module): module_name.split(".")[-1] - for module_name, module in layer.named_modules() - } - + for layer_index in range(len(self.get_layers())): for modules in self.get_layer_modules(layer_index).values(): for module in modules: - if id(module) in module_id_to_leaf_name: - target_modules_set.add(module_id_to_leaf_name[id(module)]) + full_name = module_id_to_full_name.get(id(module)) + if full_name is not None: + target_modules_set.add(full_name) - target_modules = list(target_modules_set) + target_modules = sorted(target_modules_set) - if self.settings.row_normalization != RowNormalization.FULL: + if self.settings.use_ara_lora: + lora_rank = self.settings.ara_lora_rank + elif self.settings.row_normalization != RowNormalization.FULL: # Rank 1 is sufficient for directional ablation without renormalization. lora_rank = 1 else: @@ -224,7 +227,10 @@ def _apply_lora(self): # so the result is a PeftModel rather than a PeftMixedModel. self.model = cast(PeftModel, get_peft_model(self.model, self.peft_config)) - print(f"* LoRA adapters initialized (targets: {', '.join(target_modules)})") + display_targets = sorted({name.rsplit(".", 1)[-1] for name in target_modules}) + print( + f"* LoRA adapters initialized (target types: {', '.join(display_targets)})" + ) def _get_quantization_config(self, dtype: str) -> BitsAndBytesConfig | None: """ @@ -311,7 +317,7 @@ def reset_model(self): if ( current_model == self.settings.model and not self.needs_reload - and not self.settings.use_ara + and (not self.settings.use_ara or self.settings.use_ara_lora) ): # Reset LoRA adapters to zero (identity transformation) for name, module in self.model.named_modules(): @@ -341,7 +347,7 @@ def reset_model(self): **extra_kwargs, ) - if not self.settings.use_ara: + if not self.settings.use_ara or self.settings.use_ara_lora: self._apply_lora() self.needs_reload = False @@ -668,6 +674,118 @@ def closure() -> Tensor: with torch.no_grad(): matrix.copy_(get_matrix()) + def ara_lora_abliterate( + self, + good_module_io: ModuleIO, + bad_module_io: ModuleIO, + parameters: ARAParameters, + ): + for layer_index in range( + parameters.start_layer_index, + parameters.end_layer_index, + ): + for component, modules in self.get_layer_modules(layer_index).items(): + for module_index, module in enumerate(modules): + # Cast to Linear to access weights and LoRA adapters + module = cast(Linear, module) + + # --- 1. Base Weight Handling & Dequantization --- + # We need the base weight in float32 to compute the effective weight + base_weight = cast(Tensor, module.base_layer.weight) + quant_state = getattr(base_weight, "quant_state", None) + + if quant_state is None: + W_base = base_weight.to(torch.float32) + else: + # Maintain the original dequantization logic for bitsandbytes + W_base = cast( + Tensor, + bnb.functional.dequantize_4bit( + base_weight.data, + quant_state + ).to(torch.float32), + ) + + # --- 2. Row Normalization Setup --- + # Pre-calculate the original row norms to preserve them + # This implements the RowNormalization.FULL logic + W_row_norms = LA.vector_norm(W_base, dim=1, keepdim=True) + + # --- 3. Adapter Target Identification --- + # We optimize the LoRA weights A and B + lora_A = cast(Tensor, module.lora_A["default"].weight) + lora_B = cast(Tensor, module.lora_B["default"].weight) + + # --- 4. Data Preparation --- + # Move I/O tensors to the device of the adapter weights + good_input, good_output = good_module_io[layer_index][component][module_index] + bad_input, bad_output = bad_module_io[layer_index][component][module_index] + + good_input = good_input.float().to(lora_A.device) + good_output = good_output.float().to(lora_A.device) + bad_input = bad_input.float().to(lora_A.device) + bad_output = bad_output.float().to(lora_A.device) + + # --- 5. The Objective Function --- + def objective(A: Tensor, B: Tensor) -> Tensor: + # Calculate effective weight: W_eff = W_base + B @ A + W_eff = W_base + (B @ A) + + # Apply Row Normalization (keep original norms) + if self.settings.row_normalization == RowNormalization.FULL: + # Normalize to unit length, then scale by original norms + W_eff = F.normalize(W_eff, p=2, dim=1) * W_row_norms + + # Compute outputs using the effective weight + new_good_output = good_input @ W_eff.T + new_bad_output = bad_input @ W_eff.T + + # The original ARA loss function + preserve_good_behavior = ( + (new_good_output - good_output) ** 2 + ).mean() + + steer_bad_behavior = ( + mean_distances_to_knn( + new_bad_output, + good_output, + parameters.neighbor_count, + ).mean() + + parameters.overcorrect_relative_weight + * -mean_distances_to_knn( + new_bad_output, + bad_output, + parameters.neighbor_count, + ).mean() + ) + + return ( + parameters.preserve_good_behavior_weight + * preserve_good_behavior + + parameters.steer_bad_behavior_weight * steer_bad_behavior + ) + + # --- 6. Optimization Loop --- + # We optimize A and B, not the base matrix + optimizer = LBFGS( + [lora_A, lora_B], + lr=1.0, + max_iter=20, + history_size=10, + line_search_fn="strong_wolfe", + ) + + def closure(): + optimizer.zero_grad() + # Pass the actual tensors being optimized to the objective + loss = objective(lora_A, lora_B) + loss.backward() + return loss + + # Run optimization steps + for step in range(5): + optimizer.step(closure) + def generate( self, prompts: list[Prompt], From 5567b3d7e88a5cc96ffca54effe010081f5e02bd Mon Sep 17 00:00:00 2001 From: kabachuha Date: Tue, 26 May 2026 21:05:27 +0300 Subject: [PATCH 2/3] ARA, but it's LoRA: address Gemini's review --- src/heretic/config.py | 2 +- src/heretic/main.py | 2 +- src/heretic/model.py | 56 +++++++++++++++++++++++++------------------ 3 files changed, 35 insertions(+), 25 deletions(-) diff --git a/src/heretic/config.py b/src/heretic/config.py index 3f0e44db..b5bdb6bf 100644 --- a/src/heretic/config.py +++ b/src/heretic/config.py @@ -213,7 +213,7 @@ class Settings(BaseSettings): ara_lora_rank: int = Field( default=128, - description="If LoRA is used in ARA, this sets up its rank. Keep it high enough to simulate the 'arbitrary' effect", + description="If LoRA is used in ARA, this sets up its rank. Keep it high enough to simulate the 'arbitrary' effect.", ) use_piqa: bool = Field( diff --git a/src/heretic/main.py b/src/heretic/main.py index ba8e4e0b..b633484b 100644 --- a/src/heretic/main.py +++ b/src/heretic/main.py @@ -623,7 +623,7 @@ def objective(trial: Trial) -> tuple[float, float]: for name, value in get_trial_parameters(settings, trial).items(): print(f" * {name} = [bold]{value}[/]") if settings.use_ara_lora: - print("* Reloading model...") + print("* Resetting model...") model.reset_model() print("* Abliterating (Arbitrary-Rank Ablation with LoRA)...") model.ara_lora_abliterate( diff --git a/src/heretic/model.py b/src/heretic/model.py index 338a9abf..041b7248 100644 --- a/src/heretic/model.py +++ b/src/heretic/model.py @@ -686,18 +686,18 @@ def ara_lora_abliterate( ): for component, modules in self.get_layer_modules(layer_index).items(): for module_index, module in enumerate(modules): - # Cast to Linear to access weights and LoRA adapters + # Cast to Linear to access weights and LoRA adapters. module = cast(Linear, module) - # --- 1. Base Weight Handling & Dequantization --- - # We need the base weight in float32 to compute the effective weight + # Base weight handling and dequantization. + # We need the base weight in float32 to compute the effective weight. base_weight = cast(Tensor, module.base_layer.weight) quant_state = getattr(base_weight, "quant_state", None) if quant_state is None: W_base = base_weight.to(torch.float32) else: - # Maintain the original dequantization logic for bitsandbytes + # Maintain the original dequantization logic for bitsandbytes. W_base = cast( Tensor, bnb.functional.dequantize_4bit( @@ -706,18 +706,23 @@ def ara_lora_abliterate( ).to(torch.float32), ) - # --- 2. Row Normalization Setup --- - # Pre-calculate the original row norms to preserve them - # This implements the RowNormalization.FULL logic - W_row_norms = LA.vector_norm(W_base, dim=1, keepdim=True) + # Row normalization setup. + # Pre-calculate the original row norms to preserve them. + # This implements the RowNormalization.FULL logic. + W_row_norms = LA.vector_norm(W_base, dim=1, keepdim=True).detach() - # --- 3. Adapter Target Identification --- - # We optimize the LoRA weights A and B + # Adapter target identification. + # We optimize the LoRA weights A and B. lora_A = cast(Tensor, module.lora_A["default"].weight) lora_B = cast(Tensor, module.lora_B["default"].weight) - # --- 4. Data Preparation --- - # Move I/O tensors to the device of the adapter weights + # Create float32 copies for stable optimization. + # LBFGS is numerically unstable when optimizing float16/bfloat16 parameters directly. + A_opt = lora_A.detach().float().clone().requires_grad_(True) + B_opt = lora_B.detach().float().clone().requires_grad_(True) + + # Data preparation. + # Move I/O tensors to the device of the adapter weights. good_input, good_output = good_module_io[layer_index][component][module_index] bad_input, bad_output = bad_module_io[layer_index][component][module_index] @@ -726,21 +731,21 @@ def ara_lora_abliterate( bad_input = bad_input.float().to(lora_A.device) bad_output = bad_output.float().to(lora_A.device) - # --- 5. The Objective Function --- + # The objective function. def objective(A: Tensor, B: Tensor) -> Tensor: - # Calculate effective weight: W_eff = W_base + B @ A + # Calculate effective weight: W_eff = W_base + B @ A. W_eff = W_base + (B @ A) # Apply Row Normalization (keep original norms) if self.settings.row_normalization == RowNormalization.FULL: - # Normalize to unit length, then scale by original norms + # Normalize to unit length, then scale by original norms. W_eff = F.normalize(W_eff, p=2, dim=1) * W_row_norms - # Compute outputs using the effective weight + # Compute outputs using the effective weight. new_good_output = good_input @ W_eff.T new_bad_output = bad_input @ W_eff.T - # The original ARA loss function + # The original ARA loss function. preserve_good_behavior = ( (new_good_output - good_output) ** 2 ).mean() @@ -765,10 +770,10 @@ def objective(A: Tensor, B: Tensor) -> Tensor: + parameters.steer_bad_behavior_weight * steer_bad_behavior ) - # --- 6. Optimization Loop --- - # We optimize A and B, not the base matrix + # Optimization loop. + # We optimize A and B, not the base matrix. optimizer = LBFGS( - [lora_A, lora_B], + [A_opt, B_opt], lr=1.0, max_iter=20, history_size=10, @@ -777,15 +782,20 @@ def objective(A: Tensor, B: Tensor) -> Tensor: def closure(): optimizer.zero_grad() - # Pass the actual tensors being optimized to the objective - loss = objective(lora_A, lora_B) + # Pass the actual tensors being optimized to the objective. + loss = objective(A_opt, B_opt) loss.backward() return loss - # Run optimization steps + # Run optimization steps. for step in range(5): optimizer.step(closure) + # Copy the optimized weights back to the original parameters. + with torch.no_grad(): + lora_A.copy_(A_opt.to(lora_A.dtype)) + lora_B.copy_(B_opt.to(lora_B.dtype)) + def generate( self, prompts: list[Prompt], From bf41e81c608a0876deec5311549e608916d9ee51 Mon Sep 17 00:00:00 2001 From: kabachuha Date: Wed, 27 May 2026 12:10:31 +0300 Subject: [PATCH 3/3] ARA, but it's LoRA: Gemini is stupid --- src/heretic/model.py | 16 +++------------- 1 file changed, 3 insertions(+), 13 deletions(-) diff --git a/src/heretic/model.py b/src/heretic/model.py index 041b7248..53130389 100644 --- a/src/heretic/model.py +++ b/src/heretic/model.py @@ -716,11 +716,6 @@ def ara_lora_abliterate( lora_A = cast(Tensor, module.lora_A["default"].weight) lora_B = cast(Tensor, module.lora_B["default"].weight) - # Create float32 copies for stable optimization. - # LBFGS is numerically unstable when optimizing float16/bfloat16 parameters directly. - A_opt = lora_A.detach().float().clone().requires_grad_(True) - B_opt = lora_B.detach().float().clone().requires_grad_(True) - # Data preparation. # Move I/O tensors to the device of the adapter weights. good_input, good_output = good_module_io[layer_index][component][module_index] @@ -736,7 +731,7 @@ def objective(A: Tensor, B: Tensor) -> Tensor: # Calculate effective weight: W_eff = W_base + B @ A. W_eff = W_base + (B @ A) - # Apply Row Normalization (keep original norms) + # Apply Row Normalization (keep original norms). if self.settings.row_normalization == RowNormalization.FULL: # Normalize to unit length, then scale by original norms. W_eff = F.normalize(W_eff, p=2, dim=1) * W_row_norms @@ -773,7 +768,7 @@ def objective(A: Tensor, B: Tensor) -> Tensor: # Optimization loop. # We optimize A and B, not the base matrix. optimizer = LBFGS( - [A_opt, B_opt], + [lora_A, lora_B], lr=1.0, max_iter=20, history_size=10, @@ -783,7 +778,7 @@ def objective(A: Tensor, B: Tensor) -> Tensor: def closure(): optimizer.zero_grad() # Pass the actual tensors being optimized to the objective. - loss = objective(A_opt, B_opt) + loss = objective(lora_A, lora_B) loss.backward() return loss @@ -791,11 +786,6 @@ def closure(): for step in range(5): optimizer.step(closure) - # Copy the optimized weights back to the original parameters. - with torch.no_grad(): - lora_A.copy_(A_opt.to(lora_A.dtype)) - lora_B.copy_(B_opt.to(lora_B.dtype)) - def generate( self, prompts: list[Prompt],