From 4df78d733033701a2fee1037271cb6eb3bdc767e Mon Sep 17 00:00:00 2001 From: Philipp Emanuel Weidmann Date: Tue, 12 May 2026 12:11:48 +0530 Subject: [PATCH 1/4] feat: load reproduction information --- src/heretic/config.py | 9 +++++++++ src/heretic/main.py | 13 +++++++++++-- src/heretic/reproduce.py | 19 +++++++++++++++++++ 3 files changed, 39 insertions(+), 2 deletions(-) diff --git a/src/heretic/config.py b/src/heretic/config.py index c2879f32..e12821bf 100644 --- a/src/heretic/config.py +++ b/src/heretic/config.py @@ -113,6 +113,15 @@ class Settings(BaseSettings): exclude=True, ) + reproduce: str | None = Field( + default=None, + description=( + "If this path or URL to a reproduce.json file is set, load reproduction information " + "from that file, and attempt to reproduce the abliterated model it originated from." + ), + exclude=True, + ) + dtypes: list[str] = Field( default=[ # In practice, "auto" almost always means bfloat16. diff --git a/src/heretic/main.py b/src/heretic/main.py index 48eece75..431a81af 100644 --- a/src/heretic/main.py +++ b/src/heretic/main.py @@ -65,7 +65,7 @@ def _is_help_invocation() -> bool: from .config import QuantizationMethod from .evaluator import Evaluator from .model import AbliterationParameters, Model, get_model_class -from .reproduce import collect_reproducibles +from .reproduce import collect_reproducibles, load_reproduction_information from .system import empty_cache, get_accelerator_info from .utils import ( format_duration, @@ -175,6 +175,7 @@ def run(): len(sys.argv) > 1 # Heretic is being invoked in standard (model processing) mode. and "--collect-reproducibles" not in sys.argv + and "--reproduce" not in sys.argv # No model has been explicitly provided. and "--model" not in sys.argv # The last argument is a parameter value rather than a flag (such as "--help"). @@ -185,7 +186,9 @@ def run(): # Work around the "model" argument being required # when Heretic is invoked in a non-processing mode. - if "--collect-reproducibles" in sys.argv and "--model" not in sys.argv: + if ( + "--collect-reproducibles" in sys.argv or "--reproduce" in sys.argv + ) and "--model" not in sys.argv: sys.argv.extend(["--model", ""]) try: @@ -208,6 +211,12 @@ def run(): collect_reproducibles(settings.collect_reproducibles) return + if settings.reproduce is not None: + print(f"Loading reproduction information from [bold]{settings.reproduce}[/]...") + reproduction_information = load_reproduction_information(settings.reproduce) + print(reproduction_information) + return + if settings.seed is None: settings.seed = random.randint(0, 2**32 - 1) diff --git a/src/heretic/reproduce.py b/src/heretic/reproduce.py index 52c0f87d..7717cee6 100644 --- a/src/heretic/reproduce.py +++ b/src/heretic/reproduce.py @@ -1,8 +1,11 @@ # SPDX-License-Identifier: AGPL-3.0-or-later # Copyright (C) 2025-2026 Philipp Emanuel Weidmann + contributors +import json import shutil from pathlib import Path +from typing import Any +from urllib.request import urlopen from huggingface_hub import HfApi, hf_hub_download from huggingface_hub.utils import disable_progress_bars, enable_progress_bars @@ -81,3 +84,19 @@ def collect_reproducibles(path: str): print(f"Found: [bold]{found}[/] files") print(f"Downloaded: [bold]{downloaded}[/] files") print(f"Already stored: [bold]{found - downloaded}[/] files") + + +def load_reproduction_information(path: str) -> dict[str, Any]: + if path.lower().startswith(("http://", "https://")): + # The path is a URL on the web. + + # Obtain raw download URL. + path = path.replace("/blob/", "/raw/") # Hugging Face, GitHub + path = path.replace("/src/branch/", "/raw/branch/") # Codeberg + + json_str = urlopen(path).read().decode("utf-8") + else: + # The path is (assumed to be) a local file system path. + json_str = Path(path).read_text(encoding="utf-8") + + return json.loads(json_str) From 223d0f88bf33620524dd14612c1e75f5396b7c37 Mon Sep 17 00:00:00 2001 From: Philipp Emanuel Weidmann Date: Thu, 21 May 2026 19:34:08 +0530 Subject: [PATCH 2/4] feat: check reproduction environment against original environment --- src/heretic/main.py | 21 ++- src/heretic/reproduce.py | 273 ++++++++++++++++++++++++++++++++++++++- 2 files changed, 290 insertions(+), 4 deletions(-) diff --git a/src/heretic/main.py b/src/heretic/main.py index 431a81af..1fb49739 100644 --- a/src/heretic/main.py +++ b/src/heretic/main.py @@ -65,7 +65,11 @@ def _is_help_invocation() -> bool: from .config import QuantizationMethod from .evaluator import Evaluator from .model import AbliterationParameters, Model, get_model_class -from .reproduce import collect_reproducibles, load_reproduction_information +from .reproduce import ( + check_environment, + collect_reproducibles, + load_reproduction_information, +) from .system import empty_cache, get_accelerator_info from .utils import ( format_duration, @@ -213,8 +217,21 @@ def run(): if settings.reproduce is not None: print(f"Loading reproduction information from [bold]{settings.reproduce}[/]...") + # FIXME: "Reproduction"/"reproducibility" name inconsistency! reproduction_information = load_reproduction_information(settings.reproduce) - print(reproduction_information) + + if reproduction_information["version"] not in ["1"]: + print( + ( + f"[red]Unsupported file format version: [bold]{reproduction_information['version']}[/].[/] " + "Try loading the file with a newer version of Heretic." + ) + ) + return + + if not check_environment(reproduction_information): + return + return if settings.seed is None: diff --git a/src/heretic/reproduce.py b/src/heretic/reproduce.py index 7717cee6..7717dee9 100644 --- a/src/heretic/reproduce.py +++ b/src/heretic/reproduce.py @@ -2,15 +2,28 @@ # Copyright (C) 2025-2026 Philipp Emanuel Weidmann + contributors import json +import platform +import random import shutil +from dataclasses import asdict +from enum import IntEnum from pathlib import Path -from typing import Any +from typing import Any, cast from urllib.request import urlopen +import cpuinfo +import torch from huggingface_hub import HfApi, hf_hub_download from huggingface_hub.utils import disable_progress_bars, enable_progress_bars +from questionary import Choice +from rich.table import Table -from .utils import print +from .system import ( + get_accelerator_info_dict, + get_heretic_version_info, + get_requirements_dict, +) +from .utils import print, prompt_select def collect_reproducibles(path: str): @@ -100,3 +113,259 @@ def load_reproduction_information(path: str) -> dict[str, Any]: json_str = Path(path).read_text(encoding="utf-8") return json.loads(json_str) + + +class MismatchSeverity(IntEnum): + LOW = 1 + MEDIUM = 2 + HIGH = 3 + CRITICAL = 4 + + def __rich__(self) -> str: + match self: + case MismatchSeverity.LOW: + return "[green]low[/]" + case MismatchSeverity.MEDIUM: + return "[yellow]medium[/]" + case MismatchSeverity.HIGH: + return "[red]high[/]" + case MismatchSeverity.CRITICAL: + return "[bold red]critical[/]" + case _: + raise ValueError(f"unknown MismatchSeverity value: {self}") + + +def get_package_mismatch_severity(package_name: str) -> MismatchSeverity: + if package_name in [ + "heretic-llm", + ]: + return MismatchSeverity.CRITICAL + elif package_name in [ + "torch", + "transformers", + ]: + return MismatchSeverity.HIGH + elif package_name in [ + "accelerate", + "bitsandbytes", + "kernels", + "optuna", + "peft", + "tokenizers", + "triton", + ]: + return MismatchSeverity.MEDIUM + else: + return MismatchSeverity.LOW + + +def format_version_information(version_information: dict[str, Any]) -> str: + version = version_information["version"] + metadata = version_information["metadata"] + + if "type" in metadata: + match metadata["type"]: + case "pypi": + return version + case "git": + return f"{version}-git+{metadata['url']}@{metadata['commit_hash']}" + case "local": + # Append a random number to ensure that two local installations + # are always considered to be different versions. + return f"{version}-local-{random.randint(2**16, 2**17)}" + case _: + raise ValueError( + f"unknown metadata.type value in version information: {metadata['type']}" + ) + else: + return f"{version}-unknown-{random.randint(2**16, 2**17)}" + + +def check_environment(reproduction_information: dict[str, Any]) -> bool: + mismatch_severity: MismatchSeverity | None = None + + system_mismatches = [] + package_mismatches = [] + + def verify( + mismatch_list: list[tuple[str, Any, Any, MismatchSeverity]], + name: str, + this: Any, + original: Any, + severity: MismatchSeverity, + ): + nonlocal mismatch_severity + if this != original: + mismatch_list.append((name, this, original, severity)) + if mismatch_severity is None: + mismatch_severity = severity + else: + mismatch_severity = max(severity, mismatch_severity) + + if "system" in reproduction_information: + system = reproduction_information["system"] + + verify( + system_mismatches, + "Python version", + platform.python_version(), + system["python"]["version"], + MismatchSeverity.LOW, + ) + + verify( + system_mismatches, + "Operating system", + platform.platform(), + system["os"]["platform"], + MismatchSeverity.LOW, + ) + + verify( + system_mismatches, + "CPU", + cpuinfo.get_cpu_info().get("brand_raw"), + system["cpu"]["brand"], + MismatchSeverity.LOW, + ) + + accelerators = get_accelerator_info_dict() + + verify( + system_mismatches, + "Accelerator type", + accelerators["type"], + system["accelerators"]["type"], + MismatchSeverity.HIGH, + ) + + if ( + accelerators["type"] + and accelerators["type"] == system["accelerators"]["type"] + ): + verify( + system_mismatches, + accelerators["api_name"], + accelerators["api_version"], + system["accelerators"]["api_version"], + MismatchSeverity.MEDIUM, + ) + verify( + system_mismatches, + "Driver version", + accelerators["driver_version"], + system["accelerators"]["driver_version"], + MismatchSeverity.MEDIUM, + ) + verify( + system_mismatches, + "Devices", + "\n".join([device["name"] for device in accelerators["devices"]]), + "\n".join( + [device["name"] for device in system["accelerators"]["devices"]] + ), + MismatchSeverity.MEDIUM, + ) + + else: + print( + ( + "[yellow]The provided JSON file does not contain system information. " + "Some system parameters can affect reproducibility, but due to the lack of system information, " + "Heretic is unable to verify that those parameters match the original environment. " + "Reproduction may or may not produce a byte-for-byte identical model.[/]" + ) + ) + + requirements = get_requirements_dict() + requirements["heretic-llm"] = format_version_information( + asdict(get_heretic_version_info()) + ) + requirements["torch"] = torch.__version__ + + original_requirements = reproduction_information["environment"]["requirements"] + original_requirements["heretic-llm"] = format_version_information( + reproduction_information["environment"]["heretic"] + ) + original_requirements["torch"] = reproduction_information["environment"][ + "pytorch_version" + ] + + package_names = sorted(requirements.keys() | original_requirements.keys()) + + for package_name in package_names: + verify( + package_mismatches, + package_name, + requirements.get(package_name), + original_requirements.get(package_name), + get_package_mismatch_severity(package_name), + ) + + if system_mismatches or package_mismatches: + print() + print( + ( + "[yellow]Your local environment doesn't perfectly match the environment " + "used to produce the original model. The following components differ:[/]" + ) + ) + + if system_mismatches: + table = Table() + table.add_column("Component") + table.add_column("This system", overflow="fold") + table.add_column("Original system", overflow="fold") + table.add_column("Severity", width=8) + + for component, this, original, severity in system_mismatches: + table.add_row(f"[bold]{component}[/]", this, original, severity) + + print() + print("[bold]System Mismatches[/]") + print(table) + + if package_mismatches: + table = Table() + table.add_column("Package") + table.add_column("This system", overflow="fold") + table.add_column("Original system", overflow="fold") + table.add_column("Severity", width=8) + + for package, this, original, severity in package_mismatches: + table.add_row(f"[bold]{package}[/]", this, original, severity) + + print() + print("[bold]Package Mismatches[/]") + print(table) + + if system_mismatches or package_mismatches: + print() + print( + ( + f"There is a {cast(MismatchSeverity, mismatch_severity).__rich__()} chance " + "that reproduction won't produce a byte-for-byte identical model. " + "However, the resulting model will very likely still behave similarly " + "to the original model." + ) + ) + + print() + choice = prompt_select( + "How would you like to proceed?", + [ + Choice( + title="Attempt to reproduce the model anyway", + value=True, + ), + Choice( + title="Exit program", + value=False, + ), + ], + ) + + return choice + else: + # There are no mismatches at all, so there is nothing to confirm. + return True From 96e4d2b381b7f6cc88ec019f7ad433664a33e5c7 Mon Sep 17 00:00:00 2001 From: Philipp Emanuel Weidmann Date: Fri, 29 May 2026 11:38:04 +0530 Subject: [PATCH 3/4] fix: remove `trust_remote_code` setting This improves security when running Heretic with an untrusted config file. The prompt is now always shown. This is NOT a breaking change, because we currently ignore values for unknown settings, so existing configs continue to work. --- src/heretic/config.py | 7 ------- src/heretic/main.py | 4 +++- src/heretic/model.py | 23 +++++++++++++---------- 3 files changed, 16 insertions(+), 18 deletions(-) diff --git a/src/heretic/config.py b/src/heretic/config.py index e12821bf..cfaf44e4 100644 --- a/src/heretic/config.py +++ b/src/heretic/config.py @@ -170,13 +170,6 @@ class Settings(BaseSettings): ), ) - trust_remote_code: bool | None = Field( - default=None, - description="Whether to trust remote code when loading the model.", - # For security reasons, we don't store this setting. - exclude=True, - ) - batch_size: int = Field( default=0, # auto description="Number of input sequences to process in parallel (0 = auto).", diff --git a/src/heretic/main.py b/src/heretic/main.py index 1fb49739..17fcd5ca 100644 --- a/src/heretic/main.py +++ b/src/heretic/main.py @@ -117,7 +117,9 @@ def obtain_merge_strategy(settings: Settings, model: Model) -> str | None: settings.model, device_map="meta", torch_dtype=torch.bfloat16, - trust_remote_code=model.trusted_models.get(settings.model), + trust_remote_code=True + if settings.model in model.trusted_models + else None, **model.revision_kwargs, ) footprint_bytes = meta_model.get_memory_footprint() diff --git a/src/heretic/model.py b/src/heretic/model.py index 41a8e71c..06e17116 100644 --- a/src/heretic/model.py +++ b/src/heretic/model.py @@ -71,7 +71,6 @@ def __init__(self, settings: Settings): self.tokenizer = AutoTokenizer.from_pretrained( settings.model, - trust_remote_code=settings.trust_remote_code, **self.revision_kwargs, ) @@ -90,10 +89,8 @@ def __init__(self, settings: Settings): if settings.max_memory else None ) - self.trusted_models = {settings.model: settings.trust_remote_code} - if self.settings.evaluate_model is not None: - self.trusted_models[settings.evaluate_model] = settings.trust_remote_code + self.trusted_models = set() for dtype in settings.dtypes: print(f"* Trying dtype [bold]{dtype}[/]...") @@ -112,15 +109,17 @@ def __init__(self, settings: Settings): dtype=dtype, device_map=settings.device_map, max_memory=self.max_memory, - trust_remote_code=self.trusted_models.get(settings.model), + trust_remote_code=True + if settings.model in self.trusted_models + else None, **self.revision_kwargs, **extra_kwargs, ) # If we reach this point and the model requires trust_remote_code, - # either the user accepted, or settings.trust_remote_code is True. - if self.trusted_models.get(settings.model) is None: - self.trusted_models[settings.model] = True + # the user must have agreed when prompted to execute remote code, + # because from_pretrained raises an exception otherwise. + self.trusted_models.add(settings.model) # A test run can reveal dtype-related problems such as the infamous # "RuntimeError: probability tensor contains either `inf`, `nan` or element < 0" @@ -264,7 +263,9 @@ def get_merged_model(self) -> PreTrainedModel: self.settings.model, torch_dtype=self.model.dtype, device_map="cpu", - trust_remote_code=self.trusted_models.get(self.settings.model), + trust_remote_code=True + if self.settings.model in self.trusted_models + else None, **self.revision_kwargs, ) @@ -326,7 +327,9 @@ def reset_model(self): dtype=dtype, device_map=self.settings.device_map, max_memory=self.max_memory, - trust_remote_code=self.trusted_models.get(self.settings.model), + trust_remote_code=True + if self.settings.model in self.trusted_models + else None, **self.revision_kwargs, **extra_kwargs, ) From 3bcb2f6d2b5fd419af855e3d52690bc7cdc58d77 Mon Sep 17 00:00:00 2001 From: Philipp Emanuel Weidmann Date: Thu, 4 Jun 2026 11:28:02 +0530 Subject: [PATCH 4/4] feat: reproduce model from JSON file --- src/heretic/main.py | 312 +++++++++++++++++++++++++------------------ src/heretic/utils.py | 13 +- 2 files changed, 186 insertions(+), 139 deletions(-) diff --git a/src/heretic/main.py b/src/heretic/main.py index 17fcd5ca..15237cd4 100644 --- a/src/heretic/main.py +++ b/src/heretic/main.py @@ -55,7 +55,7 @@ def _is_help_invocation() -> bool: from optuna.storages import JournalStorage from optuna.storages.journal import JournalFileBackend, JournalFileOpenLock from optuna.study import StudyDirection -from optuna.trial import TrialState +from optuna.trial import TrialState, create_trial from pydantic import ValidationError from questionary import Choice, Style from rich.table import Table @@ -217,6 +217,8 @@ def run(): collect_reproducibles(settings.collect_reproducibles) return + reproduction_mode = settings.reproduce is not None + if settings.reproduce is not None: print(f"Loading reproduction information from [bold]{settings.reproduce}[/]...") # FIXME: "Reproduction"/"reproducibility" name inconsistency! @@ -234,7 +236,9 @@ def run(): if not check_environment(reproduction_information): return - return + print() + + settings = Settings.model_validate(reproduction_information["settings"]) if settings.seed is None: settings.seed = random.randint(0, 2**32 - 1) @@ -285,7 +289,11 @@ def run(): except IndexError: existing_study = None - if existing_study is not None and settings.evaluate_model is None: + if ( + existing_study is not None + and settings.evaluate_model is None + and not reproduction_mode + ): choices = [] if existing_study.user_attrs["finished"]: @@ -625,155 +633,183 @@ def objective_wrapper(trial: Trial) -> tuple[float, float]: trial.study.stop() raise TrialPruned() - study = optuna.create_study( - sampler=TPESampler( - n_startup_trials=settings.n_startup_trials, - n_ei_candidates=128, - multivariate=True, - seed=settings.seed, - ), - directions=[StudyDirection.MINIMIZE, StudyDirection.MINIMIZE], - storage=storage, - study_name="heretic", - load_if_exists=True, - ) + if not reproduction_mode: + study = optuna.create_study( + sampler=TPESampler( + n_startup_trials=settings.n_startup_trials, + n_ei_candidates=128, + multivariate=True, + seed=settings.seed, + ), + directions=[StudyDirection.MINIMIZE, StudyDirection.MINIMIZE], + storage=storage, + study_name="heretic", + load_if_exists=True, + ) - study.set_user_attr("settings", settings.model_dump_json()) - study.set_user_attr("finished", False) + study.set_user_attr("settings", settings.model_dump_json()) + study.set_user_attr("finished", False) - def count_completed_trials() -> int: - # Count number of complete trials to compute trials to run. - return sum([(1 if t.state == TrialState.COMPLETE else 0) for t in study.trials]) + def count_completed_trials() -> int: + # Count number of complete trials to compute trials to run. + return sum( + [(1 if t.state == TrialState.COMPLETE else 0) for t in study.trials] + ) - start_index = trial_index = count_completed_trials() - if start_index > 0: - print() - print("Resuming existing study.") + start_index = trial_index = count_completed_trials() + if start_index > 0: + print() + print("Resuming existing study.") - try: - study.optimize( - objective_wrapper, - n_trials=settings.n_trials - count_completed_trials(), - ) - except KeyboardInterrupt: - # This additional handler takes care of the small chance that KeyboardInterrupt - # is raised just between trials, which wouldn't be caught by the handler - # defined in objective_wrapper above. - pass + try: + study.optimize( + objective_wrapper, + n_trials=settings.n_trials - count_completed_trials(), + ) + except KeyboardInterrupt: + # This additional handler takes care of the small chance that KeyboardInterrupt + # is raised just between trials, which wouldn't be caught by the handler + # defined in objective_wrapper above. + pass - if count_completed_trials() == settings.n_trials: - study.set_user_attr("finished", True) + if count_completed_trials() == settings.n_trials: + study.set_user_attr("finished", True) while True: - # If no trials at all have been evaluated, the study must have been stopped - # by pressing Ctrl+C while the first trial was running. In this case, we just - # re-raise the interrupt to invoke the standard handler defined below. - completed_trials = [t for t in study.trials if t.state == TrialState.COMPLETE] - if not completed_trials: - raise KeyboardInterrupt - - # Get the Pareto front of trials. We can't use study.best_trials directly - # as get_score() doesn't return the pure KL divergence and refusal count. - # Note: Unlike study.best_trials, this does not handle objective constraints. - sorted_trials = sorted( - completed_trials, - key=lambda trial: ( - trial.user_attrs["refusals"], - trial.user_attrs["kl_divergence"], - ), - ) - min_divergence = math.inf - best_trials = [] - for trial in sorted_trials: - kl_divergence = trial.user_attrs["kl_divergence"] - if kl_divergence < min_divergence: - min_divergence = kl_divergence - best_trials.append(trial) - - choices = [ - Choice( - title=( - f"[Trial {trial.user_attrs['index']:>3}] " - f"Refusals: {trial.user_attrs['refusals']:>2}/{len(evaluator.bad_prompts)}, " - f"KL divergence: {trial.user_attrs['kl_divergence']:.4f}" + if not reproduction_mode: + # If no trials at all have been evaluated, the study must have been stopped + # by pressing Ctrl+C while the first trial was running. In this case, we just + # re-raise the interrupt to invoke the standard handler defined below. + completed_trials = [ + t for t in study.trials if t.state == TrialState.COMPLETE + ] + if not completed_trials: + raise KeyboardInterrupt + + # Get the Pareto front of trials. We can't use study.best_trials directly + # as get_score() doesn't return the pure KL divergence and refusal count. + # Note: Unlike study.best_trials, this does not handle objective constraints. + sorted_trials = sorted( + completed_trials, + key=lambda trial: ( + trial.user_attrs["refusals"], + trial.user_attrs["kl_divergence"], ), - value=trial, ) - for trial in best_trials - ] + min_divergence = math.inf + best_trials = [] + for trial in sorted_trials: + kl_divergence = trial.user_attrs["kl_divergence"] + if kl_divergence < min_divergence: + min_divergence = kl_divergence + best_trials.append(trial) + + choices = [ + Choice( + title=( + f"[Trial {trial.user_attrs['index']:>3}] " + f"Refusals: {trial.user_attrs['refusals']:>2}/{len(evaluator.bad_prompts)}, " + f"KL divergence: {trial.user_attrs['kl_divergence']:.4f}" + ), + value=trial, + ) + for trial in best_trials + ] - choices.append( - Choice( - title="Run additional trials", - value="continue", + choices.append( + Choice( + title="Run additional trials", + value="continue", + ) ) - ) - choices.append( - Choice( - title="Exit program", - value="", + choices.append( + Choice( + title="Exit program", + value="", + ) ) - ) - print() - print("[bold green]Optimization finished![/]") - print() - print( - ( - "The following trials resulted in Pareto optimal combinations of refusals and KL divergence. " - "After selecting a trial, you will be able to save the model, upload it to Hugging Face, " - "chat with it to test how well it works, or run standard benchmarks on it. " - "You can return to this menu later to select a different trial. " - "[yellow]Note that KL divergence values above 0.5 usually indicate significant damage to the original model's capabilities.[/]" + print() + print("[bold green]Optimization finished![/]") + print() + print( + ( + "The following trials resulted in Pareto optimal combinations of refusals and KL divergence. " + "After selecting a trial, you will be able to save the model, upload it to Hugging Face, " + "chat with it to test how well it works, or run standard benchmarks on it. " + "You can return to this menu later to select a different trial. " + "[yellow]Note that KL divergence values above 0.5 usually indicate significant damage to the original model's capabilities.[/]" + ) ) - ) while True: - print() - trial = prompt_select("Which trial do you want to use?", choices) + if reproduction_mode: + parameters = reproduction_information["parameters"] + metrics = reproduction_information["metrics"] + + trial = create_trial( + values=[], + user_attrs={ + "direction_index": parameters["direction_index"], + "parameters": parameters["abliteration_parameters"], + "kl_divergence": metrics["kl_divergence"], + "refusals": metrics["refusals"], + "base_refusals": metrics["base_refusals"], + "n_bad_prompts": metrics["n_bad_prompts"], + }, + ) + + print() + print("Restoring model from reproduction information...") + else: + print() + trial = prompt_select("Which trial do you want to use?", choices) + + if trial is None or trial == "": + return + + if trial == "continue": + while True: + try: + n_additional_trials = prompt_text( + "How many additional trials do you want to run?" + ) + if n_additional_trials is None or n_additional_trials == "": + n_additional_trials = 0 + break + n_additional_trials = int(n_additional_trials) + if n_additional_trials > 0: + break + print("[red]Please enter a number greater than 0.[/]") + except ValueError: + print("[red]Please enter a number.[/]") + + if n_additional_trials == 0: + continue + + settings.n_trials += n_additional_trials + study.set_user_attr("settings", settings.model_dump_json()) + study.set_user_attr("finished", False) - if trial == "continue": - while True: try: - n_additional_trials = prompt_text( - "How many additional trials do you want to run?" + study.optimize( + objective_wrapper, + n_trials=settings.n_trials - count_completed_trials(), ) - if n_additional_trials is None or n_additional_trials == "": - n_additional_trials = 0 - break - n_additional_trials = int(n_additional_trials) - if n_additional_trials > 0: - break - print("[red]Please enter a number greater than 0.[/]") - except ValueError: - print("[red]Please enter a number.[/]") - - if n_additional_trials == 0: - continue - - settings.n_trials += n_additional_trials - study.set_user_attr("settings", settings.model_dump_json()) - study.set_user_attr("finished", False) - - try: - study.optimize( - objective_wrapper, - n_trials=settings.n_trials - count_completed_trials(), - ) - except KeyboardInterrupt: - pass + except KeyboardInterrupt: + pass - if count_completed_trials() == settings.n_trials: - study.set_user_attr("finished", True) + if count_completed_trials() == settings.n_trials: + study.set_user_attr("finished", True) - break + break - elif trial is None or trial == "": - return + print() + print( + f"Restoring model from trial [bold]{trial.user_attrs['index']}[/]..." + ) - print() - print(f"Restoring model from trial [bold]{trial.user_attrs['index']}[/]...") print("* Parameters:") for name, value in get_trial_parameters(trial).items(): print(f" * {name} = [bold]{value}[/]") @@ -804,12 +840,20 @@ def reset_trial_model() -> None: "Upload the model to Hugging Face", "Chat with the model", "Benchmark the model", - "Return to the trial selection menu", + Choice( + title="Exit program" + if reproduction_mode + else "Return to the trial selection menu", + value="", + ), ], ) - if action is None or action == "Return to the trial selection menu": - break + if action is None or action == "": + if reproduction_mode: + return + else: + break # All actions are wrapped in a try/except block so that if an error occurs, # another action can be tried, instead of the program crashing and losing @@ -891,8 +935,10 @@ def reset_trial_model() -> None: settings.good_evaluation_prompts.dataset, settings.bad_evaluation_prompts.dataset, ] - is_reproducible = is_hf_path(settings.model) and all( - is_hf_path(dataset) for dataset in datasets + is_reproducible = ( + is_hf_path(settings.model) + and all(is_hf_path(dataset) for dataset in datasets) + and not reproduction_mode ) if is_reproducible: diff --git a/src/heretic/utils.py b/src/heretic/utils.py index 27dd697f..fdd5cf1b 100644 --- a/src/heretic/utils.py +++ b/src/heretic/utils.py @@ -23,6 +23,7 @@ from datasets.download.download_manager import DownloadMode from datasets.utils.info_utils import VerificationMode from optuna import Trial +from optuna.trial import FrozenTrial from psutil import Process from questionary import Choice, Style from rich.console import Console @@ -258,7 +259,7 @@ def batchify(items: list[T], batch_size: int) -> list[list[T]]: return [items[i : i + batch_size] for i in range(0, len(items), batch_size)] -def get_trial_parameters(trial: Trial) -> dict[str, str]: +def get_trial_parameters(trial: Trial | FrozenTrial) -> dict[str, str]: params = {} direction_index = trial.user_attrs["direction_index"] @@ -275,7 +276,7 @@ def get_trial_parameters(trial: Trial) -> dict[str, str]: def get_readme_intro( settings: Settings, - trial: Trial, + trial: Trial | FrozenTrial, contains_reproducibility_information: bool, ) -> str: if is_hf_path(settings.model): @@ -367,7 +368,7 @@ def format_hf_link( def generate_reproduce_readme( settings: Settings, checkpoint_filename: str, - trial: Trial, + trial: Trial | FrozenTrial, include_system_information: bool, ) -> str: """Generates the contents of a README.md for the reproduce/ folder.""" @@ -536,7 +537,7 @@ def generate_reproduce_readme( def generate_reproduce_json( settings: Settings, - trial: Trial, + trial: Trial | FrozenTrial, timestamp: str, uploaded_model_hashes: dict[str, str], include_system_information: bool, @@ -604,7 +605,7 @@ def create_reproduce_folder( path: Path, settings: Settings, checkpoint_path: str | Path, - trial: Trial, + trial: Trial | FrozenTrial, uploaded_model_hashes: dict[str, str], include_system_information: bool, ): @@ -678,7 +679,7 @@ def upload_reproduce_folder( settings: Settings, token: str, checkpoint_path: str | Path, - trial: Trial, + trial: Trial | FrozenTrial, include_system_information: bool, ): api = huggingface_hub.HfApi()