opik-optimizer 2.0.1__py3-none-any.whl → 2.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- opik_optimizer/__init__.py +12 -0
- opik_optimizer/base_optimizer.py +33 -0
- opik_optimizer/hierarchical_reflective_optimizer/__init__.py +5 -0
- opik_optimizer/hierarchical_reflective_optimizer/hierarchical_reflective_optimizer.py +718 -0
- opik_optimizer/hierarchical_reflective_optimizer/hierarchical_root_cause_analyzer.py +355 -0
- opik_optimizer/hierarchical_reflective_optimizer/prompts.py +91 -0
- opik_optimizer/hierarchical_reflective_optimizer/reporting.py +679 -0
- opik_optimizer/hierarchical_reflective_optimizer/types.py +49 -0
- opik_optimizer/optimization_result.py +227 -6
- opik_optimizer/parameter_optimizer/__init__.py +11 -0
- opik_optimizer/parameter_optimizer/parameter_optimizer.py +382 -0
- opik_optimizer/parameter_optimizer/parameter_search_space.py +125 -0
- opik_optimizer/parameter_optimizer/parameter_spec.py +214 -0
- opik_optimizer/parameter_optimizer/search_space_types.py +24 -0
- opik_optimizer/parameter_optimizer/sensitivity_analysis.py +71 -0
- {opik_optimizer-2.0.1.dist-info → opik_optimizer-2.1.1.dist-info}/METADATA +4 -2
- {opik_optimizer-2.0.1.dist-info → opik_optimizer-2.1.1.dist-info}/RECORD +20 -8
- {opik_optimizer-2.0.1.dist-info → opik_optimizer-2.1.1.dist-info}/WHEEL +0 -0
- {opik_optimizer-2.0.1.dist-info → opik_optimizer-2.1.1.dist-info}/licenses/LICENSE +0 -0
- {opik_optimizer-2.0.1.dist-info → opik_optimizer-2.1.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,382 @@
|
|
1
|
+
"""Simple Optuna-based optimizer for model parameter tuning."""
|
2
|
+
|
3
|
+
from collections.abc import Callable, Mapping
|
4
|
+
from typing import Any
|
5
|
+
|
6
|
+
import copy
|
7
|
+
import logging
|
8
|
+
from datetime import datetime
|
9
|
+
|
10
|
+
import optuna
|
11
|
+
from optuna import importance as optuna_importance
|
12
|
+
from optuna.trial import Trial, TrialState
|
13
|
+
|
14
|
+
from opik import Dataset
|
15
|
+
|
16
|
+
from ..base_optimizer import BaseOptimizer
|
17
|
+
from ..optimizable_agent import OptimizableAgent
|
18
|
+
from ..optimization_config import chat_prompt
|
19
|
+
from ..optimization_result import OptimizationResult
|
20
|
+
from .parameter_search_space import ParameterSearchSpace
|
21
|
+
from .search_space_types import ParameterType
|
22
|
+
from .sensitivity_analysis import compute_sensitivity_from_trials
|
23
|
+
|
24
|
+
logger = logging.getLogger(__name__)
|
25
|
+
|
26
|
+
|
27
|
+
class ParameterOptimizer(BaseOptimizer):
    """Optimizer that tunes model call parameters (temperature, top_p, etc.).

    The search runs on a single Optuna study in up to two stages: a "global"
    stage over the full parameter space, optionally followed by a "local"
    stage over a space narrowed around the best values found so far.
    """

    def __init__(
        self,
        model: str,
        *,
        default_n_trials: int = 20,
        n_threads: int = 4,
        seed: int = 42,
        verbose: int = 1,
        local_search_ratio: float = 0.3,
        local_search_scale: float = 0.2,
        **model_kwargs: Any,
    ) -> None:
        """Configure the optimizer.

        Args:
            model: Forwarded to ``BaseOptimizer.__init__``.
            default_n_trials: Trial budget used when ``optimize_parameter``
                is called without ``n_trials``.
            n_threads: Thread count passed to ``evaluate_prompt``.
            seed: Forwarded to ``BaseOptimizer.__init__``.
            verbose: Forwarded to ``BaseOptimizer.__init__``; also selects
                this module's logger level (0 -> WARNING, 1 -> INFO,
                anything else -> DEBUG).
            local_search_ratio: Fraction of the trial budget reserved for the
                local stage; clamped to ``[0.0, 1.0]``.
            local_search_scale: Scale handed to
                ``ParameterSearchSpace.narrow_around``; clamped to ``>= 0.0``.
            **model_kwargs: Forwarded to ``BaseOptimizer.__init__``.
        """
        super().__init__(model=model, verbose=verbose, seed=seed, **model_kwargs)
        self.default_n_trials = default_n_trials
        self.n_threads = n_threads
        self.local_search_ratio = max(0.0, min(local_search_ratio, 1.0))
        self.local_search_scale = max(0.0, local_search_scale)

        # NOTE: this adjusts the module-level logger, so the most recently
        # constructed optimizer decides the level for the whole module.
        if self.verbose == 0:
            logger.setLevel(logging.WARNING)
        elif self.verbose == 1:
            logger.setLevel(logging.INFO)
        else:
            logger.setLevel(logging.DEBUG)

    def optimize_prompt(
        self,
        prompt: chat_prompt.ChatPrompt,
        dataset: Dataset,
        metric: Callable[[Any, Any], float],
        experiment_config: dict | None = None,
        n_samples: int | None = None,
        auto_continue: bool = False,
        agent_class: type[OptimizableAgent] | None = None,
        **kwargs: Any,
    ) -> OptimizationResult:
        """Unsupported entry point; always raises.

        Raises:
            NotImplementedError: Always. Use ``optimize_parameter`` instead.
        """
        raise NotImplementedError(
            "ParameterOptimizer.optimize_prompt is not supported. "
            "Use optimize_parameter(prompt, dataset, metric, parameter_space) instead, "
            "where parameter_space is a ParameterSearchSpace or dict defining the parameters to optimize."
        )

    @staticmethod
    def _utc_now() -> datetime:
        """Return the current UTC time as a naive ``datetime``.

        Equivalent to the deprecated ``datetime.utcnow()``; the tzinfo is
        dropped so ``isoformat()`` output keeps the same shape as before.
        """
        from datetime import timezone

        return datetime.now(timezone.utc).replace(tzinfo=None)

    @staticmethod
    def _completed_trials(study: Any) -> list[Any]:
        """Return the study's trials that completed and carry a value."""
        return [
            trial
            for trial in study.trials
            if trial.state == TrialState.COMPLETE and trial.value is not None
        ]

    @classmethod
    def _history_entry(cls, trial: Any, default_stage: str) -> dict[str, Any]:
        """Build the history record for one completed trial."""
        timestamp = trial.datetime_complete or trial.datetime_start or cls._utc_now()
        return {
            "iteration": trial.number + 1,
            "timestamp": timestamp.isoformat(),
            "parameters": trial.user_attrs.get("parameters", {}),
            "score": float(trial.value),
            "model_kwargs": trial.user_attrs.get("model_kwargs"),
            "model": trial.user_attrs.get("model"),
            "stage": trial.user_attrs.get("stage", default_stage),
        }

    @staticmethod
    def _best_improvement(
        completed_trials: list[Any], best_score: float, fallback_model: Any
    ) -> tuple[float, Any, Any, Any] | None:
        """Return ``(score, parameters, model_kwargs, model)`` of the top trial.

        Returns ``None`` when there are no trials or the top-scoring trial
        does not strictly beat ``best_score``.
        """
        if not completed_trials:
            return None
        top = max(completed_trials, key=lambda t: t.value)  # type: ignore[arg-type]
        if top.value is None or not top.value > best_score:
            return None
        return (
            float(top.value),
            top.user_attrs.get("parameters", {}),
            top.user_attrs.get("model_kwargs", {}),
            top.user_attrs.get("model", fallback_model),
        )

    def optimize_parameter(
        self,
        prompt: chat_prompt.ChatPrompt,
        dataset: Dataset,
        metric: Callable[[Any, Any], float],
        parameter_space: ParameterSearchSpace | Mapping[str, Any],
        experiment_config: dict | None = None,
        n_trials: int | None = None,
        n_samples: int | None = None,
        agent_class: type[OptimizableAgent] | None = None,
        **kwargs: Any,
    ) -> OptimizationResult:
        """Search ``parameter_space`` for the values that maximize ``metric``.

        The prompt is first evaluated unchanged to obtain a baseline score.
        An Optuna study (direction ``"maximize"``) then runs a global stage
        and, when possible, a local stage over a space narrowed around the
        best parameters (or the prompt's original ``model_kwargs`` when no
        trial beat the baseline).

        Args:
            prompt: Prompt whose model parameters are tuned.
            dataset: Dataset passed to ``evaluate_prompt``.
            metric: Scoring callable passed to ``evaluate_prompt``.
            parameter_space: A ``ParameterSearchSpace`` or any input accepted
                by ``ParameterSearchSpace.model_validate``.
            experiment_config: Passed through to ``evaluate_prompt``.
            n_trials: Total trial budget; defaults to ``default_n_trials``.
                Negative values are treated as 0.
            n_samples: Passed through to ``evaluate_prompt``.
            agent_class: Passed to ``setup_agent_class``.
            **kwargs: Optional ``sampler``, ``callbacks``, ``timeout``,
                ``local_trials`` and ``local_search_scale`` overrides.
                ``timeout`` and ``callbacks`` are handed to every
                ``study.optimize`` call, i.e. once per stage.

        Returns:
            An ``OptimizationResult`` carrying the baseline and best scores,
            the per-trial history and a ``details`` dictionary.

        Raises:
            TypeError: If ``kwargs`` contains any unsupported key.
        """
        if not isinstance(parameter_space, ParameterSearchSpace):
            parameter_space = ParameterSearchSpace.model_validate(parameter_space)

        # After validation, parameter_space is guaranteed to be ParameterSearchSpace
        assert isinstance(parameter_space, ParameterSearchSpace)  # for mypy

        sampler = kwargs.pop("sampler", None)
        callbacks = kwargs.pop("callbacks", None)
        timeout = kwargs.pop("timeout", None)
        local_trials_override = kwargs.pop("local_trials", None)
        local_search_scale_override = kwargs.pop("local_search_scale", None)
        if kwargs:
            extra_keys = ", ".join(sorted(kwargs.keys()))
            raise TypeError(f"Unsupported keyword arguments: {extra_keys}")

        self.validate_optimization_inputs(prompt, dataset, metric)
        self.configure_prompt_model(prompt)

        # Work on copies so the caller's prompt and its kwargs stay untouched.
        base_model_kwargs = copy.deepcopy(prompt.model_kwargs or {})
        base_prompt = prompt.copy()
        base_prompt.model_kwargs = copy.deepcopy(base_model_kwargs)

        metric_name = getattr(metric, "__name__", str(metric))

        self.agent_class = self.setup_agent_class(base_prompt, agent_class)
        baseline_score = self.evaluate_prompt(
            prompt=base_prompt,
            dataset=dataset,
            metric=metric,
            n_threads=self.n_threads,
            verbose=self.verbose,
            experiment_config=experiment_config,
            n_samples=n_samples,
            agent_class=self.agent_class,
        )

        # Iteration 0 is reserved for the baseline; trials use number + 1.
        history: list[dict[str, Any]] = [
            {
                "iteration": 0,
                "timestamp": self._utc_now().isoformat(),
                "parameters": {},
                "score": baseline_score,
                "model_kwargs": copy.deepcopy(base_prompt.model_kwargs or {}),
                "model": base_prompt.model,
                "type": "baseline",
                "stage": "baseline",
            }
        ]
        recorded_iterations: set[int] = {0}

        try:
            optuna.logging.disable_default_handler()
            optuna_logger = logging.getLogger("optuna")
            optuna_logger.setLevel(logger.getEffectiveLevel())
            optuna_logger.propagate = False
        except Exception as exc:  # pragma: no cover - defensive safety
            logger.warning("Could not configure Optuna logging: %s", exc)

        sampler = sampler or optuna.samplers.TPESampler(seed=self.seed)
        study = optuna.create_study(direction="maximize", sampler=sampler)

        total_trials = self.default_n_trials if n_trials is None else n_trials
        if total_trials < 0:
            total_trials = 0

        if local_trials_override is not None:
            local_trials = min(max(int(local_trials_override), 0), total_trials)
        else:
            local_trials = int(total_trials * self.local_search_ratio)

        # Always keep at least one global trial when there is any budget.
        global_trials = total_trials - local_trials
        if total_trials > 0 and global_trials <= 0:
            global_trials = 1
            local_trials = max(0, total_trials - global_trials)

        # `objective` reads these two through its closure; they are rebound
        # below when the search switches to the local stage.
        current_space = parameter_space
        current_stage = "global"
        stage_records: list[dict[str, Any]] = []
        search_ranges: dict[str, dict[str, Any]] = {}

        def objective(trial: Trial) -> float:
            sampled_values = current_space.suggest(trial)
            tuned_prompt = parameter_space.apply(
                prompt,
                sampled_values,
                base_model_kwargs=base_model_kwargs,
            )
            tuned_agent_class = self.setup_agent_class(tuned_prompt, agent_class)
            score = self.evaluate_prompt(
                prompt=tuned_prompt,
                dataset=dataset,
                metric=metric,
                n_threads=self.n_threads,
                verbose=self.verbose,
                experiment_config=experiment_config,
                n_samples=n_samples,
                agent_class=tuned_agent_class,
            )
            trial.set_user_attr("parameters", sampled_values)
            trial.set_user_attr(
                "model_kwargs", copy.deepcopy(tuned_prompt.model_kwargs)
            )
            trial.set_user_attr("model", tuned_prompt.model)
            trial.set_user_attr("stage", current_stage)
            return float(score)

        def record_new_trials(default_stage: str) -> None:
            # Append completed trials that are not in the history yet.
            for finished in self._completed_trials(study):
                iteration = finished.number + 1
                if iteration in recorded_iterations:
                    continue
                recorded_iterations.add(iteration)
                history.append(self._history_entry(finished, default_stage))

        global_range = parameter_space.describe()
        stage_records.append(
            {
                "stage": "global",
                "trials": global_trials,
                "scale": 1.0,
                "parameters": global_range,
            }
        )
        search_ranges["global"] = global_range

        if global_trials > 0:
            study.optimize(
                objective,
                n_trials=global_trials,
                timeout=timeout,
                callbacks=callbacks,
                show_progress_bar=False,
            )

        record_new_trials("global")

        # Start from the baseline; a trial only wins if it strictly beats it.
        best_score = baseline_score
        best_parameters: dict[str, Any] = {}
        best_model_kwargs = copy.deepcopy(base_prompt.model_kwargs or {})
        best_model = base_prompt.model

        completed_trials = self._completed_trials(study)
        improvement = self._best_improvement(
            completed_trials, best_score, prompt.model
        )
        if improvement is not None:
            best_score, best_parameters, best_model_kwargs, best_model = improvement

        local_space: ParameterSearchSpace | None = None
        ran_local_stage = False
        if (
            local_trials > 0
            and completed_trials
            and any(
                spec.distribution in {ParameterType.FLOAT, ParameterType.INT}
                for spec in parameter_space.parameters
            )
        ):
            local_scale = (
                self.local_search_scale
                if local_search_scale_override is None
                else max(0.0, float(local_search_scale_override))
            )

            # Centre on the best trial, else on the prompt's original kwargs.
            center_values = best_parameters or base_model_kwargs or {}

            if local_scale > 0 and center_values:
                ran_local_stage = True
                current_stage = "local"
                local_space = parameter_space.narrow_around(center_values, local_scale)
                local_range = local_space.describe()
                stage_records.append(
                    {
                        "stage": "local",
                        "trials": local_trials,
                        "scale": local_scale,
                        "parameters": local_range,
                    }
                )
                search_ranges["local"] = local_range

                current_space = local_space
                study.optimize(
                    objective,
                    n_trials=local_trials,
                    timeout=timeout,
                    callbacks=callbacks,
                    show_progress_bar=False,
                )

                completed_trials = self._completed_trials(study)
                improvement = self._best_improvement(
                    completed_trials, best_score, prompt.model
                )
                if improvement is not None:
                    (
                        best_score,
                        best_parameters,
                        best_model_kwargs,
                        best_model,
                    ) = improvement

        if not ran_local_stage:
            # Report zero local trials on every path where the local stage
            # was skipped, so `details` matches what actually ran.
            local_trials = 0

        record_new_trials(current_stage)

        rounds_summary = [
            {
                "iteration": trial.number + 1,
                "parameters": trial.user_attrs.get("parameters", {}),
                "score": float(trial.value) if trial.value is not None else None,
                "model": trial.user_attrs.get("model"),
                "stage": trial.user_attrs.get("stage"),
            }
            for trial in completed_trials
        ]

        try:
            importance = optuna_importance.get_param_importances(study)
        except (ValueError, RuntimeError, ImportError):
            # Falls back to custom sensitivity analysis if:
            # - Study has insufficient data (ValueError/RuntimeError)
            # - scikit-learn not installed (ImportError)
            importance = {}

        if not importance or all(value == 0 for value in importance.values()):
            importance = compute_sensitivity_from_trials(
                completed_trials, parameter_space.parameters
            )

        details = {
            "initial_score": baseline_score,
            "optimized_parameters": best_parameters,
            "optimized_model_kwargs": best_model_kwargs,
            "optimized_model": best_model,
            "trials": history,
            "parameter_space": parameter_space.model_dump(by_alias=True),
            "n_trials": total_trials,
            "model": best_model,
            "rounds": rounds_summary,
            "baseline_parameters": base_model_kwargs,
            "temperature": best_model_kwargs.get("temperature"),
            "local_trials": local_trials,
            "global_trials": global_trials,
            "search_stages": stage_records,
            "search_ranges": search_ranges,
            "parameter_importance": importance,
            "parameter_precision": 6,
        }

        # The messages are not modified by this optimizer, so the same
        # messages are reported as both the initial and the final prompt.
        messages = prompt.get_messages() if hasattr(prompt, "get_messages") else []
        initial_messages = (
            prompt.get_messages() if hasattr(prompt, "get_messages") else []
        )

        return OptimizationResult(
            optimizer=self.__class__.__name__,
            prompt=messages,
            initial_prompt=initial_messages,
            initial_score=baseline_score,
            score=best_score,
            metric_name=metric_name,
            details=details,
            history=history,
            llm_calls=self.llm_call_counter,
            tool_calls=self.tool_call_counter,
        )
|
@@ -0,0 +1,125 @@
|
|
1
|
+
"""Parameter search space for collections of tunable parameters."""
|
2
|
+
|
3
|
+
from __future__ import annotations
|
4
|
+
|
5
|
+
import copy
|
6
|
+
from typing import Any
|
7
|
+
from collections.abc import Mapping, Sequence
|
8
|
+
|
9
|
+
from optuna.trial import Trial
|
10
|
+
from pydantic import BaseModel, Field, model_validator
|
11
|
+
|
12
|
+
from .parameter_spec import ParameterSpec
|
13
|
+
|
14
|
+
|
15
|
+
class ParameterSearchSpace(BaseModel):
    """Collection of parameters to explore during optimization.

    Instances can be validated from a ``{"parameters": [...]}`` mapping, a
    ``{name: spec}`` shorthand mapping, or a plain sequence of specs.
    """

    # Tunable parameter definitions; must be non-empty with unique names.
    parameters: list[ParameterSpec] = Field(default_factory=list)

    model_config = {
        "extra": "forbid",
    }

    @model_validator(mode="before")
    @classmethod
    def _normalize(cls, data: Any) -> Any:
        """Coerce the accepted shorthand inputs into ``{"parameters": [...]}``.

        Raises:
            TypeError: If a shorthand mapping value is neither a mapping nor
                a ``ParameterSpec``.
        """
        if isinstance(data, ParameterSearchSpace):
            return data
        if isinstance(data, Mapping):
            # NOTE(review): a mapping containing a "parameters" key is always
            # taken as the canonical form, never as a shorthand entry for a
            # parameter that happens to be called "parameters".
            if "parameters" in data:
                return data
            parameters = []
            for name, spec in data.items():
                if isinstance(spec, Mapping):
                    spec_dict = dict(spec)
                elif isinstance(spec, ParameterSpec):
                    spec_dict = spec.model_dump()
                else:
                    raise TypeError(
                        "Parameter definitions must be mappings or ParameterSpec instances"
                    )
                # The mapping key only supplies the name when the spec has none.
                spec_dict.setdefault("name", name)
                parameters.append(spec_dict)
            return {"parameters": parameters}
        # str/bytes are Sequences too, but splitting them into characters
        # never yields parameter specs; leave them for pydantic to reject.
        if isinstance(data, Sequence) and not isinstance(
            data, (str, bytes, bytearray)
        ):
            return {"parameters": list(data)}
        return data

    @model_validator(mode="after")
    def _validate(self) -> ParameterSearchSpace:
        """Reject duplicate parameter names and empty search spaces.

        Raises:
            ValueError: If two parameters share a name, or there are none.
        """
        seen: set[str] = set()
        duplicates: set[str] = set()
        for spec in self.parameters:
            if spec.name in seen:
                duplicates.add(spec.name)
            else:
                seen.add(spec.name)
        if duplicates:
            raise ValueError(
                f"Duplicate parameter names detected: {', '.join(sorted(duplicates))}"
            )
        if not self.parameters:
            raise ValueError("Parameter search space cannot be empty")
        return self

    def suggest(self, trial: Trial) -> dict[str, Any]:
        """Sample a set of parameter values using an Optuna trial."""
        return {spec.name: spec.suggest(trial) for spec in self.parameters}

    def apply(
        self,
        prompt: Any,  # ChatPrompt type
        values: Mapping[str, Any],
        *,
        base_model_kwargs: dict[str, Any] | None = None,
    ) -> Any:  # Returns ChatPrompt
        """Return a prompt copy with sampled values applied.

        Args:
            prompt: Prompt to copy; the original is not modified here.
            values: Sampled values keyed by parameter name. Parameters with
                no entry in ``values`` are left untouched.
            base_model_kwargs: When given, a deep copy replaces the copy's
                ``model_kwargs`` before any value is applied.
        """
        prompt_copy = prompt.copy()
        if base_model_kwargs is not None:
            prompt_copy.model_kwargs = copy.deepcopy(base_model_kwargs)
        for spec in self.parameters:
            if spec.name in values:
                spec.apply_to_prompt(prompt_copy, values[spec.name])
        return prompt_copy

    def values_to_model_kwargs(
        self,
        values: Mapping[str, Any],
        *,
        base: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """Produce a model_kwargs dictionary with sampled values applied.

        Starts from a deep copy of ``base`` (or an empty dict) and applies
        every parameter that has an entry in ``values``.
        """
        model_kwargs = copy.deepcopy(base) if base is not None else {}
        for spec in self.parameters:
            if spec.name in values:
                spec.apply_to_model_kwargs(model_kwargs, values[spec.name])
        return model_kwargs

    def model_dump(self, *args: Any, **kwargs: Any) -> dict[str, Any]:
        """Ensure dumping keeps parameter definitions accessible."""
        # Pure pass-through to the pydantic implementation.
        return super().model_dump(*args, **kwargs)

    def narrow_around(
        self, values: Mapping[str, Any], scale: float
    ) -> ParameterSearchSpace:
        """Return a new search space narrowed around provided parameter values.

        Each spec's ``narrow`` receives its value from ``values`` (``None``
        when the name is absent) together with ``scale``.
        """
        narrowed = [
            spec.narrow(values.get(spec.name), scale) for spec in self.parameters
        ]
        return ParameterSearchSpace(parameters=narrowed)

    def describe(self) -> dict[str, dict[str, Any]]:
        """Return a human-friendly description of each parameter range.

        Float and int parameters report ``min``, ``max``, ``scale`` and, when
        set, ``step``; other parameters report ``choices`` when set.
        """
        summary: dict[str, dict[str, Any]] = {}
        for spec in self.parameters:
            entry: dict[str, Any] = {"type": spec.distribution.value}
            if spec.distribution.value in {"float", "int"}:
                entry["min"] = spec.low
                entry["max"] = spec.high
                if spec.step is not None:
                    entry["step"] = spec.step
                entry["scale"] = spec.scale
            elif spec.choices is not None:
                entry["choices"] = list(spec.choices)
            summary[spec.name] = entry
        return summary
|