diff --git a/config.default.toml b/config.default.toml index 6ec8e8ed..1be2d80e 100644 --- a/config.default.toml +++ b/config.default.toml @@ -199,3 +199,45 @@ column = "text" dataset = "mlabonne/harmful_behaviors" split = "test[:100]" column = "text" + +# The parameters used to suggest settings for abliteration. With the +# exception of direction_scope and max_weight, all values are between 0 and 1. +# min_weight is set relative to max_weight, and the other values are relative +# to the number of layers in the model. +# By default, the same ranges are used across all components. For parameters +# that are sampled per component, the ranges are specified for each component. +[parameters] + +# The different refusal direction scopes that can be applied to each trial. +direction_scope = [ + # Choose a refusal direction by interpolating between 2 layers and apply it globally. + "global", + # For each layer within range, apply the layer's own refusal direction to itself. + "per layer", +] + +# For the global direction scope, the layer from which to choose the refusal direction. +direction_fraction = { low = 0.4, high = 0.9 } + +# The maximum weight with which to apply the abliteration. Set log = true to +# sample from the log space, which will select lower values more frequently. +# Note that low must be greater than 0 when log = true. +[parameters.max_weight] +"attn.o_proj" = { low = 0.8, high = 1.5, log = false } +"mlp.down_proj" = { low = 0.8, high = 1.5, log = false } + +# The position (layer) at which the maximum weight should be applied. +[parameters.max_weight_position_fraction] +"attn.o_proj" = { low = 0.6, high = 1.0 } +"mlp.down_proj" = { low = 0.6, high = 1.0 } + +# The minimum weight as a fraction of the maximum weight. +[parameters.min_weight_relative] +"attn.o_proj" = { low = 0.0, high = 1.0 } +"mlp.down_proj" = { low = 0.0, high = 1.0 } + +# The distance from max_weight_position across which the weight drops from +# max_weight to min_weight. 
Beyond this distance, the weight is set to 0. +[parameters.min_weight_distance_fraction] +"attn.o_proj" = { low = 0.0, high = 0.6 } +"mlp.down_proj" = { low = 0.0, high = 0.6 } diff --git a/pyproject.toml b/pyproject.toml index 3a45e0d1..68311487 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -73,6 +73,9 @@ heretic = "heretic.main:main" requires = ["uv_build>=0.8.11,<0.9.0"] build-backend = "uv_build" +[tool.ty.environment] +python-version = "3.10" + [tool.uv] exclude-newer = "7 days" diff --git a/src/heretic/config.py b/src/heretic/config.py index c2879f32..0ae289ee 100644 --- a/src/heretic/config.py +++ b/src/heretic/config.py @@ -13,6 +13,8 @@ TomlConfigSettingsSource, ) +from .parameters import ParameterSpecification + # !!!IMPORTANT!!! # # Any settings added to the classes defined in this module @@ -493,6 +495,11 @@ class Settings(BaseSettings): description="Dataset of prompts that tend to result in refusals (used for evaluating model performance).", ) + parameters: ParameterSpecification = Field( + default=ParameterSpecification(), + description="The parameter specifications, per parameter or per component within each parameter.", + ) + @classmethod def settings_customise_sources( cls, diff --git a/src/heretic/main.py b/src/heretic/main.py index 48eece75..c9a4b13b 100644 --- a/src/heretic/main.py +++ b/src/heretic/main.py @@ -58,6 +58,7 @@ def _is_help_invocation() -> bool: from optuna.trial import TrialState from pydantic import ValidationError from questionary import Choice, Style +from rich.markup import escape from rich.table import Table from rich.traceback import install @@ -65,6 +66,7 @@ def _is_help_invocation() -> bool: from .config import QuantizationMethod from .evaluator import Evaluator from .model import AbliterationParameters, Model, get_model_class +from .parameters import Parameters, ParamKind from .reproduce import collect_reproducibles from .system import empty_cache, get_accelerator_info from .utils import ( @@ -196,7 +198,19 @@ def 
run(): print(f"[red]Configuration contains [bold]{error.error_count()}[/] errors:[/]") for error in error.errors(): - print(f"[bold]{error['loc'][0]}[/]: [yellow]{error['msg']}[/]") + full_loc = str(error["loc"][0]) + for loc in error["loc"][1:]: + if loc in ParamKind: + continue # Skip param kinds added to loc by discriminators. + if not isinstance(loc, str): + full_loc += f"[{loc}]" + elif loc == "[key]": + full_loc += " (key error)" + elif "." in loc: + full_loc += f'["{loc}"]' + else: + full_loc += f".{loc}" + print(f"[bold]{escape(full_loc)}[/]: [yellow]{error['msg']}[/]") print() print( @@ -322,6 +336,7 @@ def run(): return model = Model(settings) + params = Parameters(settings.parameters) print() print_memory_usage() @@ -490,69 +505,47 @@ def objective(trial: Trial) -> tuple[float, float]: trial_index += 1 trial.set_user_attr("index", trial_index) - direction_scope = trial.suggest_categorical( - "direction_scope", - [ - "global", - "per layer", - ], - ) + direction_scope_param = params.direction_scope.get() + direction_scope = direction_scope_param.suggest(trial) last_layer_index = len(model.get_layers()) - 1 - # Discrimination between "harmful" and "harmless" inputs is usually strongest - # in layers slightly past the midpoint of the layer stack. See the original - # abliteration paper (https://arxiv.org/abs/2406.11717) for a deeper analysis. - # - # Note that we always sample this parameter even though we only need it for - # the "global" direction scope. The reason is that multivariate TPE doesn't - # work with conditional or variable-range parameters. - direction_index = trial.suggest_float( - "direction_index", - 0.4 * last_layer_index, - 0.9 * last_layer_index, - ) - - if direction_scope == "per layer": - direction_index = None + # Note that we always sample this parameter when the "global" direction + # scope is included in the choices, even though we only need it for the + # "global" direction scope itself. 
The reason is that multivariate TPE + # doesn't work with conditional or variable-range parameters. + if "global" in direction_scope_param.choices: + direction_fraction = params.direction_fraction.suggest(trial) + else: + direction_fraction = None - parameters = {} + parameters: dict[str, AbliterationParameters] = {} for component in model.get_abliterable_components(): - # The parameter ranges are based on experiments with various models - # and much wider ranges. They are not set in stone and might have to be - # adjusted for future models. - max_weight = trial.suggest_float( - f"{component}.max_weight", - 0.8, - 1.5, - ) - max_weight_position = trial.suggest_float( - f"{component}.max_weight_position", - 0.6 * last_layer_index, - 1.0 * last_layer_index, - ) - # For sampling purposes, min_weight is expressed as a fraction of max_weight, - # again because multivariate TPE doesn't support variable-range parameters. - # The value is transformed into the actual min_weight value below. - min_weight = trial.suggest_float( - f"{component}.min_weight", - 0.0, - 1.0, + max_weight = params.max_weight.suggest(trial, component) + + max_weight_position_fraction = params.max_weight_position_fraction.suggest( + trial, component ) - min_weight_distance = trial.suggest_float( - f"{component}.min_weight_distance", - 1.0, - 0.6 * last_layer_index, + + min_weight_relative = params.min_weight_relative.suggest(trial, component) + + min_weight_distance_fraction = params.min_weight_distance_fraction.suggest( + trial, component ) - parameters[component] = AbliterationParameters( + parameters[component.value] = AbliterationParameters( max_weight=max_weight, - max_weight_position=max_weight_position, - min_weight=(min_weight * max_weight), - min_weight_distance=min_weight_distance, + max_weight_position=max_weight_position_fraction * last_layer_index, + min_weight=min_weight_relative * max_weight, + min_weight_distance=min_weight_distance_fraction * last_layer_index, ) + if 
direction_fraction is None or direction_scope != "global": + direction_index = None + else: + direction_index = direction_fraction * last_layer_index + trial.set_user_attr("direction_index", direction_index) trial.set_user_attr("parameters", {k: asdict(v) for k, v in parameters.items()}) diff --git a/src/heretic/model.py b/src/heretic/model.py index 41a8e71c..76bce04f 100644 --- a/src/heretic/model.py +++ b/src/heretic/model.py @@ -30,6 +30,7 @@ ) from .config import QuantizationMethod, RowNormalization, Settings +from .parameters import ModelComponent from .system import empty_cache from .utils import Prompt, batchify, print @@ -155,7 +156,7 @@ def __init__(self, settings: Settings): print(f"* Transformer model with [bold]{len(self.get_layers())}[/] layers") - all_components = {} + all_components: dict[ModelComponent, int] = {} for layer_index in range(len(self.get_layers())): for component, modules in self.get_layer_modules(layer_index).items(): if component not in all_components: @@ -164,7 +165,7 @@ def __init__(self, settings: Settings): print("* Abliterable components:") for component, count in all_components.items(): - print(f" * [bold]{component}[/]: [bold]{count}[/] modules total") + print(f" * [bold]{component.value}[/]: [bold]{count}[/] modules total") def _apply_lora(self): # Guard against calling this method at the wrong time. @@ -349,12 +350,12 @@ def get_layers(self) -> ModuleList: # Text-only models. 
return model.model.layers - def get_layer_modules(self, layer_index: int) -> dict[str, list[Module]]: + def get_layer_modules(self, layer_index: int) -> dict[ModelComponent, list[Module]]: layer = self.get_layers()[layer_index] - modules = {} + modules: dict[ModelComponent, list[Module]] = {} - def try_add(component: str, module: Any): + def try_add(component: ModelComponent, module: Any): # Only add if it's a proper nn.Module (PEFT can wrap these with LoRA) if isinstance(module, Module): if component not in modules: @@ -363,40 +364,40 @@ def try_add(component: str, module: Any): else: # Assert for unexpected types (catches architecture changes) assert not isinstance(module, Tensor), ( - f"Unexpected Tensor in {component} - expected nn.Module" + f"Unexpected Tensor in {component.value} - expected nn.Module" ) # Standard self-attention out-projection (most models). with suppress(Exception): - try_add("attn.o_proj", layer.self_attn.o_proj) # ty:ignore[possibly-missing-attribute] + try_add(ModelComponent.ATTN_O_PROJ, layer.self_attn.o_proj) # ty:ignore[possibly-missing-attribute] # Qwen3.5 MoE hybrid layers use GatedDeltaNet (linear attention) instead of # standard self-attention, so self_attn.o_proj doesn't exist on those layers. with suppress(Exception): - try_add("attn.o_proj", layer.linear_attn.out_proj) # ty:ignore[possibly-missing-attribute] + try_add(ModelComponent.ATTN_O_PROJ, layer.linear_attn.out_proj) # ty:ignore[possibly-missing-attribute] # Most dense models. with suppress(Exception): - try_add("mlp.down_proj", layer.mlp.down_proj) # ty:ignore[possibly-missing-attribute] + try_add(ModelComponent.MLP_DOWN_PROJ, layer.mlp.down_proj) # ty:ignore[possibly-missing-attribute] # Some MoE models (e.g. Qwen3). 
with suppress(Exception): for expert in layer.mlp.experts: # ty:ignore[possibly-missing-attribute, not-iterable] - try_add("mlp.down_proj", expert.down_proj) # ty:ignore[possibly-missing-attribute] + try_add(ModelComponent.MLP_DOWN_PROJ, expert.down_proj) # ty:ignore[possibly-missing-attribute] # Phi-3.5-MoE (and possibly others). with suppress(Exception): for expert in layer.block_sparse_moe.experts: # ty:ignore[possibly-missing-attribute, not-iterable] - try_add("mlp.down_proj", expert.w2) # ty:ignore[possibly-missing-attribute] + try_add(ModelComponent.MLP_DOWN_PROJ, expert.w2) # ty:ignore[possibly-missing-attribute] # Granite MoE Hybrid - attention layers with shared_mlp. with suppress(Exception): - try_add("mlp.down_proj", layer.shared_mlp.output_linear) # ty:ignore[possibly-missing-attribute] + try_add(ModelComponent.MLP_DOWN_PROJ, layer.shared_mlp.output_linear) # ty:ignore[possibly-missing-attribute] # Granite MoE Hybrid - MoE layers with experts. with suppress(Exception): for expert in layer.moe.experts: # ty:ignore[possibly-missing-attribute, not-iterable] - try_add("mlp.down_proj", expert.output_linear) # ty:ignore[possibly-missing-attribute] + try_add(ModelComponent.MLP_DOWN_PROJ, expert.output_linear) # ty:ignore[possibly-missing-attribute] # We need at least one module across all components for abliteration to work. total_modules = sum(len(mods) for mods in modules.values()) @@ -404,8 +405,8 @@ def try_add(component: str, module: Any): return modules - def get_abliterable_components(self) -> list[str]: - components: set[str] = set() + def get_abliterable_components(self) -> list[ModelComponent]: + components: set[ModelComponent] = set() # Scan all layers because hybrid models (e.g. Qwen3.5 MoE) have different # components on different layers (some have self_attn, others linear_attn). 
diff --git a/src/heretic/parameters.py b/src/heretic/parameters.py
new file mode 100644
index 00000000..f00f5234
--- /dev/null
+++ b/src/heretic/parameters.py
@@ -0,0 +1,493 @@
+# SPDX-License-Identifier: AGPL-3.0-or-later
+# Copyright (C) 2025-2026 Philipp Emanuel Weidmann + contributors
+
+from enum import Enum
+from typing import Annotated, Any, Generic, Literal, TypeVar, cast
+
+from annotated_types import Interval, IsFinite, MinLen
+from optuna import Trial
+from pydantic import (
+    BaseModel,
+    ConfigDict,
+    Discriminator,
+    Field,
+    Tag,
+    model_validator,
+)
+from typing_extensions import Self, TypeAliasType
+
+
+class ModelComponent(str, Enum):
+    """Model components for which abliteration parameters can be specified."""
+
+    ATTN_O_PROJ = "attn.o_proj"
+    MLP_DOWN_PROJ = "mlp.down_proj"
+
+
+class ParamKind(str, Enum):
+    """Tags that the discriminated unions below use to pick a parameter form."""
+
+    SCALAR = "param_kind_scalar"
+    LIST = "param_kind_list"
+    DICT = "param_kind_dict"
+    FLAT = "param_kind_flat"
+    NESTED = "param_kind_nested"
+
+
+Scalar = TypeVar("Scalar", bound=None | bool | int | float | str)
+UnitFloat = Annotated[float, Interval(ge=0.0, le=1.0)]
+
+# Plain string values of all model components. Raw configuration keys are checked
+# against this set instead of using "key in ModelComponent", because before
+# Python 3.12 that containment check raises TypeError for non-member operands.
+_MODEL_COMPONENT_VALUES = frozenset(component.value for component in ModelComponent)
+
+# Keys that identify a dict as a range specification (FloatParamSpec/UnitParamSpec).
+_RANGE_SPEC_KEYS = frozenset({"low", "high", "log"})
+
+
+class FloatParamSpec(BaseModel):
+    # A finite sampling range [low, high], optionally sampled in the log domain.
+    # Unknown keys are rejected (extra="forbid").
+
+    model_config = ConfigDict(extra="forbid")
+
+    low: Annotated[
+        IsFinite[float],
+        Field(
+            description="Lower endpoint of the range of suggested values (inclusive).",
+        ),
+    ]
+
+    high: Annotated[
+        IsFinite[float],
+        Field(
+            description="Upper endpoint of the range of suggested values (inclusive).",
+        ),
+    ]
+
+    log: Annotated[
+        bool,
+        Field(
+            description="If true, the value is sampled from the range in the log domain.",
+        ),
+    ] = False
+
+    @model_validator(mode="after")
+    def _validate(self) -> Self:
+        """Reject inverted ranges and non-positive lower bounds for log sampling."""
+        if self.low > self.high:
+            raise ValueError(f"low ({self.low}) must be ≤ high ({self.high}).")
+        if self.log and self.low <= 0:
+            raise ValueError(f"low ({self.low}) must be positive when log = true.")
+        return self
+
+
+class UnitParamSpec(BaseModel):
+    # A sampling range [low, high] whose endpoints must both lie in [0, 1].
+
+    model_config = ConfigDict(extra="forbid")
+
+    low: Annotated[
+        UnitFloat,
+        Field(
+            description="Lower endpoint of the range of suggested values (inclusive).",
+        ),
+    ]
+
+    high: Annotated[
+        UnitFloat,
+        Field(
+            description="Upper endpoint of the range of suggested values (inclusive).",
+        ),
+    ]
+
+    # Included to make constructor compatible with FloatParamSpec.
+    log: Literal[False] = False
+
+    @model_validator(mode="after")
+    def _validate(self) -> Self:
+        """Reject inverted ranges."""
+        if self.low > self.high:
+            raise ValueError(f"low ({self.low}) must be ≤ high ({self.high}).")
+        return self
+
+
+def discriminate_flat(param: Any) -> ParamKind | None:
+    """Return the tag of the flat parameter form that matches the input.
+
+    Scalars (None, bool, int, float, str) map to SCALAR, lists and tuples to
+    LIST, and dicts or already-built range specifications to DICT. None is
+    returned for any other input.
+    """
+    if param is None or isinstance(param, (bool, int, float, str)):
+        return ParamKind.SCALAR
+    if isinstance(param, (list, tuple)):
+        return ParamKind.LIST
+    if isinstance(param, (dict, FloatParamSpec, UnitParamSpec)):
+        return ParamKind.DICT
+    return None
+
+
+def discriminate_nested(param: Any) -> ParamKind | None:
+    """Return FLAT or NESTED depending on whether the input is keyed by component.
+
+    Everything that is not a dict (scalars, lists, tuples, range specifications)
+    is FLAT. A dict is NESTED if any key names a model component, FLAT if it
+    has any range specification key, and NESTED otherwise. None is returned for
+    unsupported input types.
+    """
+    if param is None or isinstance(
+        param, (bool, int, float, str, list, tuple, FloatParamSpec, UnitParamSpec)
+    ):
+        return ParamKind.FLAT
+    if not isinstance(param, dict):
+        return None
+    if any(key in _MODEL_COMPONENT_VALUES for key in param):
+        return ParamKind.NESTED
+    if any(key in _RANGE_SPEC_KEYS for key in param):
+        return ParamKind.FLAT
+    # Assume it's nested for a more helpful error message.
+    return ParamKind.NESTED
+
+
+FlatCategoricalParamType = TypeAliasType(
+    "FlatCategoricalParamType",
+    Annotated[
+        Annotated[Scalar, Tag(ParamKind.SCALAR)]
+        | Annotated[list[Scalar], MinLen(1), Tag(ParamKind.LIST)],
+        Discriminator(discriminate_flat),
+    ],
+    type_params=(Scalar,),
+)
+NestedCategoricalParamType = TypeAliasType(
+    "NestedCategoricalParamType",
+    dict[ModelComponent, FlatCategoricalParamType[Scalar]],
+    type_params=(Scalar,),
+)
+CategoricalParamType = TypeAliasType(
+    "CategoricalParamType",
+    Annotated[
+        Annotated[FlatCategoricalParamType[Scalar], Tag(ParamKind.FLAT)]
+        | Annotated[NestedCategoricalParamType[Scalar], Tag(ParamKind.NESTED)],
+        Discriminator(discriminate_nested),
+    ],
+    type_params=(Scalar,),
+)
+
+FlatFloatParamType = Annotated[
+    Annotated[IsFinite[float], Tag(ParamKind.SCALAR)]
+    | Annotated[FloatParamSpec, Tag(ParamKind.DICT)],
+    Discriminator(discriminate_flat),
+]
+NestedFloatParamType = dict[ModelComponent, FlatFloatParamType]
+FloatParamType = Annotated[
+    Annotated[FlatFloatParamType, Tag(ParamKind.FLAT)]
+    | Annotated[NestedFloatParamType, Tag(ParamKind.NESTED)],
+    Discriminator(discriminate_nested),
+]
+
+FlatUnitParamType = Annotated[
+    Annotated[UnitFloat, Tag(ParamKind.SCALAR)]
+    | Annotated[UnitParamSpec, Tag(ParamKind.DICT)],
+    Discriminator(discriminate_flat),
+]
+NestedUnitParamType = dict[ModelComponent, FlatUnitParamType]
+UnitParamType = Annotated[
+    Annotated[FlatUnitParamType, Tag(ParamKind.FLAT)]
+    | Annotated[NestedUnitParamType, Tag(ParamKind.NESTED)],
+    Discriminator(discriminate_nested),
+]
+
+
+def _as_choices(param: Any) -> list[Any]:
+    """Normalize a flat categorical parameter to a list of choices.
+
+    A scalar becomes a single-element list; anything else is returned unchanged.
+    """
+    if param is None or isinstance(param, (bool, int, float, str)):
+        return [param]
+    return param
+
+
+def _as_range(param: Any, log: bool = False) -> tuple[float, float, bool]:
+    """Normalize a flat float parameter to a (low, high, log) triple.
+
+    A number becomes the degenerate range (value, value) combined with the log
+    flag passed in; a range specification supplies its own low, high and log.
+    """
+    if isinstance(param, (int, float)):
+        return param, param, log
+    return param.low, param.high, param.log
+
+
+class CategoricalParam(Generic[Scalar]):
+    """A named categorical parameter with a fixed list of choices."""
+
+    name: str
+    choices: list[Scalar]
+
+    def __init__(self, name: str, choices: list[Scalar]) -> None:
+        self.name = name
+        self.choices = choices
+
+    def suggest(self, trial: Trial) -> Scalar:
+        """Return a choice for this trial.
+
+        With fewer than two choices there is nothing to sample, so the first
+        choice is returned without registering the parameter with the trial.
+        """
+        if len(self.choices) < 2:
+            return self.choices[0]
+        return cast(Scalar, trial.suggest_categorical(self.name, self.choices))
+
+
+class FloatParam:
+    """A named float parameter with a sampling range and log-domain flag."""
+
+    name: str
+    low: float
+    high: float
+    log: bool
+
+    def __init__(self, name: str, low: float, high: float, log: bool) -> None:
+        self.name = name
+        self.low = low
+        self.high = high
+        self.log = log
+
+    def suggest(self, trial: Trial) -> float:
+        """Return a value for this trial.
+
+        A degenerate range (low == high) is returned as-is without registering
+        the parameter with the trial.
+        """
+        if self.low == self.high:
+            return self.low
+        return trial.suggest_float(self.name, self.low, self.high, log=self.log)
+
+
+class FlatCategoricalParamProxy(Generic[Scalar]):
+    """Resolves a categorical parameter that is not specified per component."""
+
+    name: str
+    current: FlatCategoricalParamType[Scalar]
+    default: FlatCategoricalParamType[Scalar]
+
+    def __init__(
+        self,
+        name: str,
+        current: FlatCategoricalParamType[Scalar],
+        default: FlatCategoricalParamType[Scalar],
+    ) -> None:
+        self.name = name
+        self.current = current
+        self.default = default
+
+    def get(self) -> CategoricalParam[Scalar]:
+        """Build the categorical parameter from the current specification."""
+        choices: list[Scalar] = _as_choices(self.current)
+        return CategoricalParam[Scalar](self.name, choices)
+
+    def suggest(self, trial: Trial) -> Scalar:
+        """Suggest a value for this parameter in the given trial."""
+        return self.get().suggest(trial)
+
+
+class CategoricalParamProxy(Generic[Scalar]):
+    """Resolves a categorical parameter that may be specified per component."""
+
+    name: str
+    current: CategoricalParamType[Scalar]
+    default: CategoricalParamType[Scalar]
+
+    def __init__(
+        self,
+        name: str,
+        current: CategoricalParamType[Scalar],
+        default: CategoricalParamType[Scalar],
+    ) -> None:
+        self.name = name
+        self.current = current
+        self.default = default
+
+    def get(self, component: ModelComponent) -> CategoricalParam[Scalar]:
+        """Build the categorical parameter for one component.
+
+        The result is named after the component value and this parameter, joined by a dot.
+        """
+        # First, get the choices from the defaults, which should be fully specified.
+        default = self.default
+        if isinstance(default, dict):
+            default = default[component]
+        choices: list[Scalar] = _as_choices(default)
+        # Now check for a user-specified override. In the per-component form,
+        # this may not be fully specified, in which case the defaults are kept.
+        current = self.current
+        if not isinstance(current, dict):
+            choices = _as_choices(current)
+        elif component in current:
+            choices = _as_choices(current[component])
+        return CategoricalParam[Scalar](f"{component.value}.{self.name}", choices)
+
+    def suggest(self, trial: Trial, component: ModelComponent) -> Scalar:
+        """Suggest a value for this parameter and component in the given trial."""
+        return self.get(component).suggest(trial)
+
+
+class FlatFloatParamProxy:
+    """Resolves a float parameter that is not specified per component."""
+
+    name: str
+    current: FlatFloatParamType | FlatUnitParamType
+    default: FlatFloatParamType | FlatUnitParamType
+
+    def __init__(
+        self,
+        name: str,
+        current: FlatFloatParamType | FlatUnitParamType,
+        default: FlatFloatParamType | FlatUnitParamType,
+    ) -> None:
+        self.name = name
+        self.current = current
+        self.default = default
+
+    def get(self) -> FloatParam:
+        """Build the float parameter from the current specification."""
+        low, high, log = _as_range(self.current)
+        return FloatParam(self.name, low, high, log)
+
+    def suggest(self, trial: Trial) -> float:
+        """Suggest a value for this parameter in the given trial."""
+        return self.get().suggest(trial)
+
+
+class FloatParamProxy:
+    """Resolves a float parameter that may be specified per component."""
+
+    name: str
+    current: FloatParamType | UnitParamType
+    default: FloatParamType | UnitParamType
+
+    def __init__(
+        self,
+        name: str,
+        current: FloatParamType | UnitParamType,
+        default: FloatParamType | UnitParamType,
+    ) -> None:
+        self.name = name
+        self.current = current
+        self.default = default
+
+    def get(self, component: ModelComponent) -> FloatParam:
+        """Build the float parameter for one component.
+
+        The result is named after the component value and this parameter, joined by a dot.
+        """
+        # First, get the range from the defaults, which should be fully specified.
+        default = self.default
+        if isinstance(default, dict):
+            default = default[component]
+        low, high, log = _as_range(default)
+        # Now check for a user-specified override. In the per-component form,
+        # this may not be fully specified, in which case the defaults are kept.
+        current = self.current
+        if not isinstance(current, dict):
+            low, high, log = _as_range(current, log)
+        elif component in current:
+            low, high, log = _as_range(current[component], log)
+        return FloatParam(f"{component.value}.{self.name}", low, high, log)
+
+    def suggest(self, trial: Trial, component: ModelComponent) -> float:
+        """Suggest a value for this parameter and component in the given trial."""
+        return self.get(component).suggest(trial)
+
+
+class DirectionScope(str, Enum):
+    """The scopes in which a refusal direction can be applied."""
+
+    GLOBAL = "global"
+    PER_LAYER = "per layer"
+
+
+class ParameterSpecification(BaseModel):
+    # The parameter specifications, per parameter or per component within each parameter.
+
+    model_config = ConfigDict(extra="forbid")
+
+    direction_scope: FlatCategoricalParamType[DirectionScope] = Field(
+        default=[DirectionScope.GLOBAL, DirectionScope.PER_LAYER],
+        description=(
+            "The different refusal direction scopes that can be applied to each trial. "
+            '"global": Choose a refusal direction by interpolating between 2 layers and apply it globally. '
+            '"per layer": For each layer within range, apply the layer\'s own refusal direction to itself.'
+        ),
+    )
+
+    # Discrimination between "harmful" and "harmless" inputs is usually strongest in layers slightly past the midpoint
+    # of the layer stack. See the original abliteration paper (https://arxiv.org/abs/2406.11717) for a deeper analysis.
+    direction_fraction: FlatUnitParamType = Field(
+        default=UnitParamSpec(low=0.4, high=0.9),
+        description="For the global direction scope, the layer from which to choose the refusal direction.",
+    )
+
+    # The parameter ranges are based on experiments with various models and much wider ranges.
+    # They are not set in stone and might have to be adjusted for future models.
+    max_weight: FloatParamType = Field(
+        default=FloatParamSpec(low=0.8, high=1.5),
+        description=(
+            "The maximum weight with which to apply the abliteration. Set log = true to "
+            "sample from the log space, which will select lower values more frequently. "
+            "Note that low must be greater than 0 when log = true."
+        ),
+    )
+
+    max_weight_position_fraction: UnitParamType = Field(
+        default=UnitParamSpec(low=0.6, high=1.0),
+        description="The position (layer) at which the maximum weight should be applied.",
+    )
+
+    # For sampling purposes, min_weight is expressed as a fraction of max_weight,
+    # because multivariate TPE doesn't support variable-range parameters.
+    min_weight_relative: UnitParamType = Field(
+        default=UnitParamSpec(low=0.0, high=1.0),
+        description="The minimum weight as a fraction of the maximum weight.",
+    )
+
+    min_weight_distance_fraction: UnitParamType = Field(
+        default=UnitParamSpec(low=0.0, high=0.6),
+        description=(
+            "The distance from max_weight_position across which the weight drops from "
+            "max_weight to min_weight. Beyond this distance, the weight is set to 0."
+        ),
+    )
+
+
+class Parameters:
+    """Proxy objects for all parameters of a specification.
+
+    Each proxy receives the value from the given specification together with
+    the value from a default-constructed ParameterSpecification.
+    """
+
+    direction_scope: FlatCategoricalParamProxy[DirectionScope]
+    direction_fraction: FlatFloatParamProxy
+    max_weight: FloatParamProxy
+    max_weight_position_fraction: FloatParamProxy
+    min_weight_relative: FloatParamProxy
+    min_weight_distance_fraction: FloatParamProxy
+
+    def __init__(self, current_params: ParameterSpecification) -> None:
+        default_params = ParameterSpecification()
+        proxy_classes: dict[str, Any] = {
+            "direction_scope": FlatCategoricalParamProxy[DirectionScope],
+            "direction_fraction": FlatFloatParamProxy,
+            "max_weight": FloatParamProxy,
+            "max_weight_position_fraction": FloatParamProxy,
+            "min_weight_relative": FloatParamProxy,
+            "min_weight_distance_fraction": FloatParamProxy,
+        }
+        for name, proxy_cls in proxy_classes.items():
+            current = getattr(current_params, name)
+            default = getattr(default_params, name)
+            setattr(self, name, proxy_cls(name, current, default))