opik-optimizer 2.0.1__py3-none-any.whl → 2.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- opik_optimizer/__init__.py +12 -0
- opik_optimizer/base_optimizer.py +33 -0
- opik_optimizer/hierarchical_reflective_optimizer/__init__.py +5 -0
- opik_optimizer/hierarchical_reflective_optimizer/hierarchical_reflective_optimizer.py +718 -0
- opik_optimizer/hierarchical_reflective_optimizer/hierarchical_root_cause_analyzer.py +355 -0
- opik_optimizer/hierarchical_reflective_optimizer/prompts.py +91 -0
- opik_optimizer/hierarchical_reflective_optimizer/reporting.py +679 -0
- opik_optimizer/hierarchical_reflective_optimizer/types.py +49 -0
- opik_optimizer/optimization_result.py +227 -6
- opik_optimizer/parameter_optimizer/__init__.py +11 -0
- opik_optimizer/parameter_optimizer/parameter_optimizer.py +382 -0
- opik_optimizer/parameter_optimizer/parameter_search_space.py +125 -0
- opik_optimizer/parameter_optimizer/parameter_spec.py +214 -0
- opik_optimizer/parameter_optimizer/search_space_types.py +24 -0
- opik_optimizer/parameter_optimizer/sensitivity_analysis.py +71 -0
- {opik_optimizer-2.0.1.dist-info → opik_optimizer-2.1.1.dist-info}/METADATA +4 -2
- {opik_optimizer-2.0.1.dist-info → opik_optimizer-2.1.1.dist-info}/RECORD +20 -8
- {opik_optimizer-2.0.1.dist-info → opik_optimizer-2.1.1.dist-info}/WHEEL +0 -0
- {opik_optimizer-2.0.1.dist-info → opik_optimizer-2.1.1.dist-info}/licenses/LICENSE +0 -0
- {opik_optimizer-2.0.1.dist-info → opik_optimizer-2.1.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,718 @@
|
|
1
|
+
from opik.environment import get_tqdm_for_current_environment
|
2
|
+
import os
|
3
|
+
import logging
|
4
|
+
|
5
|
+
import opik
|
6
|
+
import litellm
|
7
|
+
from litellm.caching import Cache
|
8
|
+
from litellm.types.caching import LiteLLMCacheType
|
9
|
+
from opik.evaluation.evaluation_result import EvaluationResult
|
10
|
+
from opik.evaluation.models.litellm import opik_monitor as opik_litellm_monitor
|
11
|
+
from opik.evaluation import evaluator as opik_evaluator
|
12
|
+
|
13
|
+
from typing import Any, TypeVar
|
14
|
+
from collections.abc import Callable
|
15
|
+
from pydantic import BaseModel
|
16
|
+
from .. import _throttle
|
17
|
+
from ..base_optimizer import BaseOptimizer
|
18
|
+
from ..optimization_config import chat_prompt, mappers
|
19
|
+
from ..optimizable_agent import OptimizableAgent
|
20
|
+
|
21
|
+
from opik_optimizer.task_evaluator import _create_metric_class
|
22
|
+
from opik_optimizer.optimization_result import OptimizationResult
|
23
|
+
from . import reporting
|
24
|
+
from .hierarchical_root_cause_analyzer import HierarchicalRootCauseAnalyzer
|
25
|
+
from .types import (
|
26
|
+
FailureMode,
|
27
|
+
ImprovedPrompt,
|
28
|
+
HierarchicalRootCauseAnalysis,
|
29
|
+
)
|
30
|
+
from .prompts import IMPROVE_PROMPT_TEMPLATE
|
31
|
+
|
32
|
+
# Progress-bar factory for the current runtime (presumably notebook vs. console
# tqdm — confirm in opik.environment). Not referenced elsewhere in this module.
tqdm = get_tqdm_for_current_environment()

# Using disk cache for LLM calls
# NOTE: this runs at import time and assigns litellm's module-level cache, so it
# affects every litellm call in the process, not only this optimizer's calls.
disk_cache_dir = os.path.expanduser("~/.litellm_cache")
litellm.cache = Cache(type=LiteLLMCacheType.DISK, disk_cache_dir=disk_cache_dir)

# Set up logging
logger = logging.getLogger(__name__)  # Gets logger configured by setup_logging

# Shared limiter consumed by the @_throttle.rate_limited decorators on the
# optimizer's model-call methods in this module.
_rate_limiter = _throttle.get_rate_limiter_for_current_opik_installation()

# Type variable for generic structured output (a Pydantic model type passed as
# `response_model`).
T = TypeVar("T", bound=BaseModel)
|
45
|
+
|
46
|
+
|
47
|
+
class HierarchicalReflectiveOptimizer(BaseOptimizer):
    """
    The Hierarchical Reflective Optimizer uses hierarchical root cause analysis to improve prompts
    based on failure modes identified during the evaluation process.

    This algorithm uses a two-stage hierarchical approach: analyzing failures in batches and then
    synthesizing findings to identify unified failure modes. It's best suited when you have a
    complex prompt that you want to systematically refine based on understanding why it fails.

    Args:
        reasoning_model: LiteLLM model name for reasoning and analysis (default: "openai/gpt-4.1")
        num_threads: Number of parallel threads for evaluation (default: 12)
        verbose: Controls internal logging/progress bars (0=off, 1=on) (default: 1)
        seed: Random seed for reproducibility (default: 42)
        max_parallel_batches: Maximum number of batches to process concurrently during
            hierarchical root cause analysis (default: 5)
        batch_size: Number of test cases per batch for root cause analysis (default: 25)
        **model_kwargs: Additional arguments passed to the LLM model
    """

    # Not referenced by this class's methods; presumably reserved for future
    # multi-round optimization (optimize_prompt currently runs one iteration).
    DEFAULT_ROUNDS = 10

    # Optimizer-level settings that may arrive through ``model_kwargs`` but must
    # not be forwarded to LiteLLM.
    _NON_LITELLM_KWARGS = (
        "n_trials",
        "n_samples",
        "n_iterations",
        "min_examples",
        "max_examples",
        "project_name",
    )

    def __init__(
        self,
        reasoning_model: str = "openai/gpt-4.1",
        num_threads: int = 12,
        verbose: int = 1,
        seed: int = 42,
        max_parallel_batches: int = 5,
        batch_size: int = 25,
        **model_kwargs: Any,
    ):
        super().__init__(
            model=reasoning_model, verbose=verbose, seed=seed, **model_kwargs
        )
        self.reasoning_model = reasoning_model
        self.num_threads = num_threads
        self.max_parallel_batches = max_parallel_batches
        self.batch_size = batch_size

        # The analyzer issues its LLM calls through this optimizer's async call
        # path, so those calls are rate limited and counted like our own.
        self._hierarchical_analyzer = HierarchicalRootCauseAnalyzer(
            call_model_fn=self._call_model_async,
            reasoning_model=self.reasoning_model,
            seed=self.seed,
            max_parallel_batches=self.max_parallel_batches,
            batch_size=self.batch_size,
            verbose=self.verbose,
        )

    def _prepare_model_params(
        self,
        model_kwargs: dict[str, Any],
        response_model: type[T] | None = None,
    ) -> dict[str, Any]:
        """
        Prepare parameters for LiteLLM call by filtering and adding monitoring.

        Per-call ``model_kwargs`` override the optimizer-wide ``self.model_kwargs``;
        optimizer-only keys (see ``_NON_LITELLM_KWARGS``) are removed before Opik
        monitoring parameters are added.

        Args:
            model_kwargs: Additional model parameters
            response_model: Optional Pydantic model for structured output

        Returns:
            Dictionary of parameters ready for litellm.completion/acompletion
        """
        # Merging into a new dict leaves self.model_kwargs untouched.
        call_kwargs = {**self.model_kwargs, **model_kwargs}
        for key in self._NON_LITELLM_KWARGS:
            call_kwargs.pop(key, None)

        final_params_for_litellm = (
            opik_litellm_monitor.try_add_opik_monitoring_to_params(call_kwargs)
        )

        # Add structured output support if response_model is provided
        # According to LiteLLM docs: https://docs.litellm.ai/docs/completion/json_mode
        # Pass the Pydantic model directly to response_format
        if response_model is not None:
            final_params_for_litellm["response_format"] = response_model

        return final_params_for_litellm

    def _parse_response(
        self,
        response: Any,
        response_model: type[T] | None = None,
    ) -> T | str:
        """
        Parse LiteLLM response, with optional structured output parsing.

        Args:
            response: The response from litellm.completion/acompletion
            response_model: Optional Pydantic model for structured output

        Returns:
            If response_model is provided, the message content validated as JSON
            into an instance of that model. Otherwise, the raw message content of
            the first choice.
        """
        content = response.choices[0].message.content

        # With structured output the content is a JSON string; validate it into
        # the requested Pydantic model.
        if response_model is not None:
            return response_model.model_validate_json(content)

        return content

    @_throttle.rate_limited(_rate_limiter)
    def _call_model(
        self,
        model: str,
        messages: list[dict[str, str]],
        seed: int,
        model_kwargs: dict[str, Any],
        response_model: type[T] | None = None,
    ) -> T | str:
        """
        Call the LLM model with optional structured output.

        The call is throttled by the module's rate limiter, increments the LLM
        call counter, and lets LiteLLM retry failed requests (``num_retries=6``).

        Args:
            model: The model to use for the call
            messages: List of message dictionaries with 'role' and 'content' keys
            seed: Random seed for reproducibility
            model_kwargs: Additional model parameters
            response_model: Optional Pydantic model for structured output

        Returns:
            If response_model is provided, returns an instance of that model.
            Otherwise, returns the raw string response.
        """
        self.increment_llm_counter()

        final_params_for_litellm = self._prepare_model_params(
            model_kwargs, response_model
        )

        response = litellm.completion(
            model=model,
            messages=messages,
            seed=seed,
            num_retries=6,
            **final_params_for_litellm,
        )

        return self._parse_response(response, response_model)

    @_throttle.rate_limited(_rate_limiter)
    async def _call_model_async(
        self,
        model: str,
        messages: list[dict[str, str]],
        seed: int,
        model_kwargs: dict[str, Any],
        response_model: type[T] | None = None,
    ) -> T | str:
        """
        Async version of _call_model using litellm.acompletion.

        Args:
            model: The model to use for the call
            messages: List of message dictionaries with 'role' and 'content' keys
            seed: Random seed for reproducibility
            model_kwargs: Additional model parameters
            response_model: Optional Pydantic model for structured output

        Returns:
            If response_model is provided, returns an instance of that model.
            Otherwise, returns the raw string response.
        """
        self.increment_llm_counter()

        final_params_for_litellm = self._prepare_model_params(
            model_kwargs, response_model
        )

        response = await litellm.acompletion(
            model=model,
            messages=messages,
            seed=seed,
            num_retries=6,
            **final_params_for_litellm,
        )

        return self._parse_response(response, response_model)

    def get_optimizer_metadata(self) -> dict[str, Any]:
        """
        Get metadata about the optimizer configuration.

        Returns:
            Dictionary containing optimizer-specific configuration
        """
        return {
            "reasoning_model": self.reasoning_model,
            "num_threads": self.num_threads,
            "max_parallel_batches": self.max_parallel_batches,
            "batch_size": self.batch_size,
            "seed": self.seed,
            "verbose": self.verbose,
        }

    def _calculate_improvement(
        self, current_score: float, previous_score: float
    ) -> float:
        """
        Return the relative improvement of ``current_score`` over ``previous_score``.

        The result is a fraction (0.25 means +25%). When ``previous_score`` is not
        positive the relative change is undefined, so 0.0 is returned.
        """
        if previous_score > 0:
            return (current_score - previous_score) / previous_score
        return 0.0

    @staticmethod
    def _compute_average_score(evaluation_result: EvaluationResult) -> float:
        """
        Average the first score of every test result in ``evaluation_result``.

        Returns 0.0 and logs a warning when the evaluation produced no test
        results, instead of failing with ZeroDivisionError.
        """
        test_results = evaluation_result.test_results
        if not test_results:
            logger.warning(
                "Evaluation returned no test results; using an average score of 0.0"
            )
            return 0.0
        total = sum(result.score_results[0].value for result in test_results)
        return total / len(test_results)

    @staticmethod
    def _extract_tool_prompts(tools: Any) -> dict[str, str] | None:
        """
        Map each tool's function name to its function description.

        Tools without a function name are keyed as ``tool_<index>``. Returns None
        when ``tools`` is None or empty.
        """
        if not tools:
            return None
        return {
            tool.get("function", {}).get("name", f"tool_{index}"): tool.get(
                "function", {}
            ).get("description", "")
            for index, tool in enumerate(tools)
        }

    def _evaluate_prompt(
        self,
        prompt: chat_prompt.ChatPrompt,
        dataset: opik.Dataset,
        metric: Callable,
        optimization_id: str,
        n_samples: int | None = None,
        experiment_config: dict | None = None,
        **kwargs: Any,
    ) -> EvaluationResult:
        """
        Evaluate ``prompt`` on ``dataset`` as an optimization trial.

        Each dataset item is rendered into the prompt's messages and run through a
        fresh ``self.agent_class`` instance; the stripped output is scored with
        ``metric``. ``self.agent_class`` is assigned in ``optimize_prompt``, so this
        is expected to run after that method has configured it.

        Args:
            prompt: Prompt to evaluate
            dataset: Opik Dataset to evaluate the prompt on
            metric: Metric function used to score each output
            optimization_id: ID of the optimization this trial belongs to
            n_samples: Optional number of items to test in the dataset
                (None evaluates the full dataset)
            experiment_config: Optional configuration for the experiment, useful
                to log additional metadata
            **kwargs: Accepted for interface compatibility; unused

        Returns:
            EvaluationResult: the raw evaluation result with per-item test results
        """
        if n_samples is None:
            logger.debug("Using full dataset for evaluation")
        else:
            logger.debug(f"Using {n_samples} dataset samples for evaluation")

        configuration_updates = self._drop_none({"n_samples": n_samples})
        meta_metadata = self._drop_none(
            {"optimization_id": optimization_id, "stage": "trial_evaluation"}
        )
        prepared_experiment_config = self._prepare_experiment_config(
            prompt=prompt,
            dataset=dataset,
            metric=metric,
            experiment_config=experiment_config,
            configuration_updates=configuration_updates,
            additional_metadata={"meta_prompt": meta_metadata}
            if meta_metadata
            else None,
        )

        def llm_task(dataset_item: dict[str, Any]) -> dict[str, str]:
            # Render the item into a private copy so concurrent tasks never
            # share mutable prompt state.
            new_prompt = prompt.copy()
            messages = new_prompt.get_messages(dataset_item)
            new_prompt.set_messages(messages)
            agent = self.agent_class(prompt=new_prompt)

            try:
                logger.debug(
                    f"Calling LLM with prompt length: {sum(len(msg['content']) for msg in messages)}"
                )
                raw_model_output = agent.invoke(messages)
                logger.debug(f"LLM raw response length: {len(raw_model_output)}")
                logger.debug(f"LLM raw output: {raw_model_output}")
            except Exception as e:
                # Log context for the failing item, then let the evaluator see
                # the original exception.
                logger.error(f"Error calling model with prompt: {e}")
                logger.error(f"Failed prompt: {messages}")
                logger.error(
                    f"Prompt length: {sum(len(msg['content']) for msg in messages)}"
                )
                raise

            return {mappers.EVALUATED_LLM_TASK_OUTPUT: raw_model_output.strip()}

        logger.debug(
            f"Starting evaluation with {n_samples if n_samples else 'all'} samples for metric: {getattr(metric, '__name__', str(metric))}"
        )
        return opik_evaluator.evaluate_optimization_trial(
            optimization_id=optimization_id,
            dataset=dataset,
            task=llm_task,
            scoring_metrics=[_create_metric_class(metric)],
            task_threads=self.num_threads,
            nb_samples=n_samples,
            experiment_config=prepared_experiment_config,
            verbose=self.verbose,
        )

    def _hierarchical_root_cause_analysis(
        self, evaluation_result: EvaluationResult
    ) -> HierarchicalRootCauseAnalysis:
        """
        Perform hierarchical root cause analysis on evaluation results.

        This method uses a two-stage hierarchical approach:
        1. Split results into batches and analyze each batch
        2. Synthesize batch analyses into unified failure modes

        Args:
            evaluation_result: The evaluation result to analyze

        Returns:
            HierarchicalRootCauseAnalysis containing batch analyses and overall synthesis
        """
        logger.debug("Performing hierarchical root cause analysis...")
        return self._hierarchical_analyzer.analyze(evaluation_result)

    def _improve_prompt(
        self, prompt: chat_prompt.ChatPrompt, root_cause: FailureMode, attempt: int = 1
    ) -> ImprovedPrompt:
        """
        Improve the prompt based on the root cause analysis.

        Args:
            prompt: Current prompt to improve
            root_cause: The failure mode to address
            attempt: Attempt number (1-indexed). Used to vary seed for retries.

        Returns:
            ImprovedPrompt with reasoning and improved messages
        """
        improve_prompt_prompt = IMPROVE_PROMPT_TEMPLATE.format(
            current_prompt=prompt.get_messages(),
            failure_mode_name=root_cause.name,
            failure_mode_description=root_cause.description,
            failure_mode_root_cause=root_cause.root_cause,
        )

        # Vary seed based on attempt to avoid cache hits and ensure different results
        # Each attempt gets a unique seed: base_seed, base_seed+1000, base_seed+2000, etc.
        attempt_seed = self.seed + (attempt - 1) * 1000

        if attempt > 1:
            logger.debug(
                f"Retry attempt {attempt}: Using seed {attempt_seed} (base seed: {self.seed})"
            )

        return self._call_model(
            model=self.reasoning_model,
            messages=[{"role": "user", "content": improve_prompt_prompt}],
            seed=attempt_seed,
            model_kwargs={},
            response_model=ImprovedPrompt,
        )

    def _generate_and_evaluate_improvement(
        self,
        root_cause: FailureMode,
        best_prompt: chat_prompt.ChatPrompt,
        best_score: float,
        prompt: chat_prompt.ChatPrompt,
        dataset: opik.Dataset,
        metric: Callable,
        optimization_id: str,
        n_samples: int | None,
        attempt: int,
        max_attempts: int,
        experiment_config: dict | None = None,
    ) -> tuple[chat_prompt.ChatPrompt, float]:
        """
        Generate and evaluate a single improvement attempt for a failure mode.

        Args:
            root_cause: The failure mode to address
            best_prompt: The current best prompt to improve upon
            best_score: The current best score (for comparison)
            prompt: The original prompt (for metadata like name and tools)
            dataset: Dataset to evaluate on
            metric: Metric function
            optimization_id: ID of the optimization
            n_samples: Optional number of samples
            attempt: Current attempt number (1-indexed)
            max_attempts: Total number of attempts
            experiment_config: Optional experiment metadata forwarded to the
                evaluation

        Returns:
            Tuple of (improved_prompt, improved_score)
        """
        # Generate improvement with progress indication
        with reporting.display_prompt_improvement(
            failure_mode_name=root_cause.name, verbose=self.verbose
        ) as improvement_reporter:
            improved_prompt_response = self._improve_prompt(
                prompt=best_prompt, root_cause=root_cause, attempt=attempt
            )
            improvement_reporter.set_reasoning(improved_prompt_response.reasoning)

        # Convert the structured response into a ChatPrompt that keeps the
        # original prompt's name and tools.
        improved_chat_prompt = chat_prompt.ChatPrompt(
            name=prompt.name,
            messages=[
                {"role": msg.role, "content": msg.content}
                for msg in improved_prompt_response.messages
            ],
            tools=prompt.tools,
        )

        # Evaluate improved prompt
        eval_message = f"Evaluating improvement for failure mode '{root_cause.name}'"
        if max_attempts > 1:
            eval_message += f" (attempt {attempt}/{max_attempts})"
        eval_message += ":"

        with reporting.display_evaluation(
            message=eval_message,
            verbose=self.verbose,
            indent="│ ",
            baseline_score=best_score,  # Pass baseline for comparison
        ) as improved_reporter:
            improved_experiment_result = self._evaluate_prompt(
                prompt=improved_chat_prompt,
                dataset=dataset,
                metric=metric,
                optimization_id=optimization_id,
                n_samples=n_samples,
                experiment_config=experiment_config,
            )
            improved_score = self._compute_average_score(improved_experiment_result)
            improved_reporter.set_score(improved_score)

        return improved_chat_prompt, improved_score

    def optimize_prompt(
        self,
        prompt: chat_prompt.ChatPrompt,
        dataset: opik.Dataset,
        metric: Callable[..., Any],
        experiment_config: dict | None = None,
        n_samples: int | None = None,
        auto_continue: bool = False,
        agent_class: type[OptimizableAgent] | None = None,
        max_retries: int = 2,
        **kwargs: Any,
    ) -> OptimizationResult:
        """
        Optimize ``prompt`` by addressing the failure modes found in its evaluation.

        Steps:
            1. Evaluate the starting prompt to get a baseline score.
            2. Run hierarchical root cause analysis on that evaluation.
            3. For each unified failure mode, ask the reasoning model for an
               improved prompt (up to ``max_retries + 1`` attempts) and keep it
               only if it scores strictly higher than the current best.

        Args:
            prompt: The prompt to optimize.
            dataset: Opik dataset used for every evaluation.
            metric: Scoring function; its ``__name__`` (or ``str(metric)`` when it
                has none) is used as the objective/metric name.
            experiment_config: Optional extra experiment metadata, forwarded to
                every evaluation.
            n_samples: Optional number of dataset items per evaluation
                (None evaluates all items).
            auto_continue: Recorded in the displayed configuration and in the
                result details; it does not otherwise change this method.
            agent_class: Optional agent class used to run the prompt; passed to
                ``setup_agent_class``.
            max_retries: Extra attempts per failure mode when an attempt does not
                improve the score.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            OptimizationResult with the best prompt found, its score, the initial
            prompt and score, and optimizer details.
        """
        # Reset counters at the start of optimization
        self.reset_counters()

        # Configure prompt model if not set
        self.configure_prompt_model(prompt)

        # Setup agent class
        self.agent_class = self.setup_agent_class(prompt, agent_class)

        # Callables such as functools.partial have no __name__; resolve once so
        # the objective name and the result's metric name always agree.
        metric_name = getattr(metric, "__name__", str(metric))

        optimization = self.opik_client.create_optimization(
            dataset_name=dataset.name,
            objective_name=metric_name,
            metadata={"optimizer": self.__class__.__name__},
        )
        logger.debug(f"Created optimization with ID: {optimization.id}")

        reporting.display_header(
            algorithm=self.__class__.__name__,
            optimization_id=optimization.id,
            dataset_id=dataset.id,
            verbose=self.verbose,
        )
        reporting.display_configuration(
            messages=prompt.get_messages(),
            optimizer_config={
                "optimizer": self.__class__.__name__,
                "n_samples": n_samples,
                "auto_continue": auto_continue,
                "max_retries": max_retries,
            },
            verbose=self.verbose,
            tools=getattr(prompt, "tools", None),
        )

        # First we will evaluate the prompt on the dataset
        with reporting.display_evaluation(verbose=self.verbose) as baseline_reporter:
            experiment_result = self._evaluate_prompt(
                prompt=prompt,
                dataset=dataset,
                metric=metric,
                optimization_id=optimization.id,
                n_samples=n_samples,
                experiment_config=experiment_config,
            )
            avg_scores = self._compute_average_score(experiment_result)
            baseline_reporter.set_score(avg_scores)

        # Track baseline and best scores
        initial_score = avg_scores
        best_score = initial_score
        best_prompt = prompt
        best_messages = prompt.get_messages()
        # Store copy of initial messages for diff
        initial_messages = list(prompt.get_messages())

        # Every failure mode gets one initial attempt plus `max_retries` retries.
        max_attempts = max_retries + 1

        # Iteration 1: Analyze and improve (structure ready for future multi-iteration support)
        with reporting.display_optimization_iteration(
            iteration=1, verbose=self.verbose
        ) as iteration_reporter:
            # Perform hierarchical root cause analysis
            with reporting.display_root_cause_analysis(
                verbose=self.verbose
            ) as analysis_reporter:
                hierarchical_analysis = self._hierarchical_root_cause_analysis(
                    experiment_result
                )
                analysis_reporter.set_completed(
                    total_test_cases=hierarchical_analysis.total_test_cases,
                    num_batches=hierarchical_analysis.num_batches,
                )

            failure_modes = hierarchical_analysis.unified_failure_modes

            # Display hierarchical synthesis and failure modes
            if self.verbose:
                reporting.display_hierarchical_synthesis(
                    total_test_cases=hierarchical_analysis.total_test_cases,
                    num_batches=hierarchical_analysis.num_batches,
                    synthesis_notes=hierarchical_analysis.synthesis_notes,
                    verbose=self.verbose,
                )

            reporting.display_failure_modes(
                failure_modes=failure_modes,
                verbose=self.verbose,
            )

            # Generate improved prompt for each failure mode; improvements are
            # cumulative because each one starts from the current best prompt.
            for idx, root_cause in enumerate(failure_modes, 1):
                logger.debug(
                    f"Addressing failure mode {idx}/{len(failure_modes)}: {root_cause.name}"
                )

                improved_chat_prompt = None
                improved_score = None

                for attempt in range(1, max_attempts + 1):
                    improved_chat_prompt, improved_score = (
                        self._generate_and_evaluate_improvement(
                            root_cause=root_cause,
                            best_prompt=best_prompt,
                            best_score=best_score,
                            prompt=prompt,
                            dataset=dataset,
                            metric=metric,
                            optimization_id=optimization.id,
                            n_samples=n_samples,
                            attempt=attempt,
                            max_attempts=max_attempts,
                            experiment_config=experiment_config,
                        )
                    )

                    if improved_score > best_score:
                        logger.info(
                            f"Improvement found for '{root_cause.name}' on attempt {attempt}"
                        )
                        break

                    # No improvement - retry unless this was the last attempt.
                    if attempt < max_attempts:
                        reporting.display_retry_attempt(
                            attempt=attempt,
                            max_attempts=max_attempts,
                            failure_mode_name=root_cause.name,
                            verbose=self.verbose,
                        )
                    else:
                        logger.debug(
                            f"No improvement after {attempt} attempts for '{root_cause.name}'"
                        )

                # None values remain only when no attempt ran (max_retries < 0).
                if (
                    improved_score is not None
                    and improved_chat_prompt is not None
                    and improved_score > best_score
                ):
                    improvement = self._calculate_improvement(
                        improved_score, best_score
                    )

                    # Display improvement for this iteration
                    reporting.display_iteration_improvement(
                        improvement=improvement,
                        current_score=improved_score,
                        best_score=best_score,
                        verbose=self.verbose,
                    )

                    # Update best
                    best_score = improved_score
                    best_prompt = improved_chat_prompt
                    best_messages = improved_chat_prompt.get_messages()
                    logger.info(
                        f"Updated best prompt after addressing '{root_cause.name}'"
                    )
                else:
                    logger.debug(
                        f"Keeping previous best prompt, no improvement from '{root_cause.name}'"
                    )

            # Mark iteration complete
            improved_since_start = best_score > initial_score
            iteration_reporter.iteration_complete(
                best_score=best_score, improved=improved_since_start
            )

        # Display final optimization result with diff
        reporting.display_optimized_prompt_diff(
            initial_messages=initial_messages,
            optimized_messages=best_messages,
            initial_score=initial_score,
            best_score=best_score,
            verbose=self.verbose,
        )

        # Prepare details for the result
        details = {
            "reasoning_model": self.reasoning_model,
            "num_threads": self.num_threads,
            "max_parallel_batches": self.max_parallel_batches,
            "batch_size": self.batch_size,
            "max_retries": max_retries,
            "n_samples": n_samples,
            "auto_continue": auto_continue,
        }

        return OptimizationResult(
            optimizer=self.__class__.__name__,
            prompt=best_messages,
            score=best_score,
            metric_name=metric_name,
            initial_prompt=prompt.get_messages(),
            initial_score=initial_score,
            details=details,
            llm_calls=self.llm_call_counter,
            tool_calls=self.tool_call_counter,
            optimization_id=optimization.id,
            dataset_id=dataset.id,
            tool_prompts=self._extract_tool_prompts(
                getattr(best_prompt, "tools", None)
            ),
        )
|