opik-optimizer 0.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1100 @@
1
+ from typing import List, Dict, Any, Optional, Union
2
+ import opik
3
+ from opik import Dataset
4
+ import litellm
5
+ from litellm.caching import Cache
6
+ import logging
7
+ import json
8
+ import os
9
+ from string import Template
10
+
11
+ from .optimization_config import mappers
12
+ from .optimization_config.configs import MetricConfig, TaskConfig
13
+ from .base_optimizer import BaseOptimizer, OptimizationRound
14
+ from .optimization_result import OptimizationResult
15
+ from opik_optimizer import task_evaluator
16
+ from opik.api_objects import opik_client
17
+ from opik.evaluation.models.litellm import opik_monitor as opik_litellm_monitor
18
+ from opik.environment import get_tqdm_for_current_environment
19
+
20
+ tqdm = get_tqdm_for_current_environment()
21
+
22
+ # Using disk cache for LLM calls
23
+ disk_cache_dir = os.path.expanduser("~/.litellm_cache")
24
+ litellm.cache = Cache(type="disk", disk_cache_dir=disk_cache_dir)
25
+
26
+ # Set up logging
27
+ logger = logging.getLogger(__name__) # Gets logger configured by setup_logging
28
+
29
+
30
+ class MetaPromptOptimizer(BaseOptimizer):
31
+ """Optimizer that uses meta-prompting to improve prompts based on examples and performance."""
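+ # Illustrative usage sketch (the model name and the dataset/metric/task objects are placeholders):
+ #     optimizer = MetaPromptOptimizer(model="openai/gpt-4o-mini")
+ #     result = optimizer.optimize_prompt(dataset=dataset, metric_config=metric_config, task_config=task_config)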
32
+
33
+ # --- Constants for Default Configuration ---
34
+ DEFAULT_MAX_ROUNDS = 3
35
+ DEFAULT_PROMPTS_PER_ROUND = 4
36
+ DEFAULT_IMPROVEMENT_THRESHOLD = 0.05
37
+ DEFAULT_INITIAL_TRIALS = 3
38
+ DEFAULT_MAX_TRIALS = 6
39
+ DEFAULT_ADAPTIVE_THRESHOLD = 0.8 # Set to None to disable adaptive trials
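+ # Example: with DEFAULT_ADAPTIVE_THRESHOLD = 0.8 and a current best score of 0.75, a candidate
+ # averaging below 0.75 * 0.8 = 0.60 after its initial trials is not given the additional trials.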
40
+
41
+ # --- Reasoning System Prompt ---
42
+ _REASONING_SYSTEM_PROMPT = """You are an expert prompt engineer. Your task is to improve prompts for any type of task.
43
+ Focus on making the prompt more effective by:
44
+ 1. Being clear and specific about what is expected
45
+ 2. Providing necessary context and constraints
46
+ 3. Guiding the model to produce the desired output format
47
+ 4. Removing ambiguity and unnecessary elements
48
+ 5. Maintaining conciseness while being complete
49
+
50
+ Return a JSON object with the following structure:
51
+ {
52
+ "prompts": [
53
+ {
54
+ "prompt": "the improved prompt text",
55
+ "improvement_focus": "what aspect this prompt improves",
56
+ "reasoning": "why this improvement should help"
57
+ }
58
+ ]
59
+ }"""
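+ # Note: _generate_candidate_prompts() parses the reasoning model's reply as JSON (with a regex
+ # fallback) and expects exactly this object shape; unparseable replies fall back to the current prompt.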
88
+
89
+ def __init__(
90
+ self,
91
+ model: str,
92
+ reasoning_model: Optional[str] = None,
93
+ max_rounds: int = DEFAULT_MAX_ROUNDS,
94
+ num_prompts_per_round: int = DEFAULT_PROMPTS_PER_ROUND,
95
+ improvement_threshold: float = DEFAULT_IMPROVEMENT_THRESHOLD,
96
+ initial_trials_per_candidate: int = DEFAULT_INITIAL_TRIALS,
97
+ max_trials_per_candidate: int = DEFAULT_MAX_TRIALS,
98
+ adaptive_trial_threshold: Optional[float] = DEFAULT_ADAPTIVE_THRESHOLD,
99
+ num_threads: int = 12,
100
+ project_name: Optional[str] = None,
101
+ **model_kwargs,
102
+ ):
103
+ """
104
+ Initialize the MetaPromptOptimizer.
105
+
106
+ Args:
107
+ model: The model to use for evaluation
108
+ reasoning_model: The model to use for reasoning and prompt generation
109
+ max_rounds: Maximum number of optimization rounds
110
+ num_prompts_per_round: Number of prompts to generate per round
111
+ improvement_threshold: Minimum improvement required to continue
112
+ initial_trials_per_candidate: Number of initial evaluation trials for each candidate prompt.
113
+ max_trials_per_candidate: Maximum number of evaluation trials if adaptive trials are enabled and score is promising.
114
+ adaptive_trial_threshold: If not None, prompts scoring below `best_score * adaptive_trial_threshold` after initial trials won't get max trials.
115
+ num_threads: Number of threads for parallel evaluation
116
+ project_name: Optional project name for tracking
117
+ **model_kwargs: Additional model parameters
118
+ """
119
+ super().__init__(model=model, project_name=project_name, **model_kwargs)
120
+ self.reasoning_model = reasoning_model if reasoning_model is not None else model
121
+ self.max_rounds = max_rounds
122
+ self.num_prompts_per_round = num_prompts_per_round
123
+ self.improvement_threshold = improvement_threshold
124
+ self.initial_trials = initial_trials_per_candidate
125
+ self.max_trials = max_trials_per_candidate
126
+ self.adaptive_threshold = adaptive_trial_threshold
127
+ self.num_threads = num_threads
128
+ self.dataset = None
129
+ self.task_config = None
130
+ self._opik_client = opik_client.get_client_cached()
131
+ logger.debug(
132
+ f"Initialized MetaPromptOptimizer with model={model}, reasoning_model={self.reasoning_model}"
133
+ )
134
+ logger.debug(
135
+ f"Optimization rounds: {max_rounds}, Prompts/round: {num_prompts_per_round}"
136
+ )
137
+ logger.debug(
138
+ f"Trials config: Initial={self.initial_trials}, Max={self.max_trials}, Adaptive Threshold={self.adaptive_threshold}"
139
+ )
140
+
141
+ def evaluate_prompt(
142
+ self,
143
+ dataset: opik.Dataset,
144
+ metric_config: MetricConfig,
145
+ task_config: TaskConfig,
146
+ prompt: str,
147
+ use_full_dataset: bool = False,
148
+ experiment_config: Optional[Dict] = None,
149
+ n_samples: Optional[int] = None,
150
+ optimization_id: Optional[str] = None,
151
+ ) -> float:
152
+ """
153
+ Evaluate a prompt using the given dataset and metric configuration.
154
+
155
+ Args:
156
+ dataset: The dataset to evaluate against
157
+ metric_config: The metric configuration to use for evaluation
158
+ task_config: The task configuration containing input/output fields
159
+ prompt: The prompt to evaluate
160
+ use_full_dataset: Whether to use the full dataset or a subset for evaluation
161
+ experiment_config: A dictionary to log with the experiments
162
+ n_samples: The number of dataset items to use for evaluation
163
+ optimization_id: Optional ID for tracking the optimization run
164
+
165
+ Returns:
166
+ float: The evaluation score
167
+ """
168
+ return self._evaluate_prompt(
169
+ dataset=dataset,
170
+ metric_config=metric_config,
171
+ task_config=task_config,
172
+ prompt=prompt,
173
+ use_full_dataset=use_full_dataset,
174
+ experiment_config=experiment_config,
175
+ n_samples=n_samples,
176
+ optimization_id=optimization_id,
177
+ )
178
+
179
+ def _call_model(
180
+ self,
181
+ prompt: str,
182
+ system_prompt: Optional[str] = None,
183
+ is_reasoning: bool = False,
184
+ optimization_id: Optional[str] = None,
185
+ ) -> str:
186
+ """Call the model with the given prompt and return the response."""
187
+ # Note: Basic retry logic could be added here using tenacity
188
+ try:
189
+ # Basic LLM parameters (e.g., temperature, max_tokens)
190
+ llm_config_params = {
191
+ "temperature": getattr(self, "temperature", 0.3),
192
+ "max_tokens": getattr(self, "max_tokens", 1000),
193
+ "top_p": getattr(self, "top_p", 1.0),
194
+ "frequency_penalty": getattr(self, "frequency_penalty", 0.0),
195
+ "presence_penalty": getattr(self, "presence_penalty", 0.0),
196
+ }
197
+
198
+ # Prepare metadata that we want to be part of the LLM call context.
199
+ metadata_for_opik = {}
200
+ if self.project_name:
201
+ metadata_for_opik["project_name"] = (
202
+ self.project_name
203
+ ) # Top-level for general use
204
+ metadata_for_opik["opik"] = {"project_name": self.project_name}
205
+
206
+ if optimization_id:
207
+ # Also add to opik-specific structure if project_name was added
208
+ if "opik" in metadata_for_opik:
209
+ metadata_for_opik["opik"]["optimization_id"] = optimization_id
210
+
211
+ metadata_for_opik["optimizer_name"] = self.__class__.__name__
212
+ metadata_for_opik["opik_call_type"] = (
213
+ "reasoning" if is_reasoning else "evaluation_llm_task_direct"
214
+ )
215
+
216
+ if metadata_for_opik:
217
+ llm_config_params["metadata"] = metadata_for_opik
218
+
219
+ messages = []
220
+ if system_prompt and (
221
+ is_reasoning or getattr(self.task_config, "use_chat_prompt", False)
222
+ ):
223
+ messages.append({"role": "system", "content": system_prompt})
224
+ messages.append({"role": "user", "content": prompt})
225
+
226
+ model_to_use = self.reasoning_model if is_reasoning else self.model
227
+
228
+ # Pass llm_config_params (which now includes our metadata) to the Opik monitor.
229
+ # The monitor is expected to return a dictionary suitable for spreading into litellm.completion,
230
+ # having handled our metadata and added any Opik-specific configurations.
231
+ final_call_params = opik_litellm_monitor.try_add_opik_monitoring_to_params(
232
+ llm_config_params.copy()
233
+ )
234
+
235
+ logger.debug(
236
+ f"Calling model '{model_to_use}' with messages: {messages}, "
237
+ f"final params for litellm (from monitor): {final_call_params}"
238
+ )
239
+
240
+ response = litellm.completion(
241
+ model=model_to_use, messages=messages, **final_call_params
242
+ )
243
+ return response.choices[0].message.content
244
+ except litellm.exceptions.RateLimitError as e:
245
+ logger.error(f"LiteLLM Rate Limit Error: {e}")
246
+ raise
247
+ except litellm.exceptions.APIConnectionError as e:
248
+ logger.error(f"LiteLLM API Connection Error: {e}")
249
+ raise
250
+ except litellm.exceptions.ContextWindowExceededError as e:
251
+ logger.error(f"LiteLLM Context Window Exceeded Error: {e}")
252
+ logger.error(f"Prompt length (chars): {len(prompt)}")
253
+ raise
254
+ except Exception as e:
255
+ logger.error(
256
+ f"Error calling model '{self.reasoning_model if is_reasoning else self.model}': {type(e).__name__} - {e}"
257
+ )
258
+ raise
259
+
260
+ def _evaluate_prompt(
261
+ self,
262
+ dataset: opik.Dataset,
263
+ metric_config: MetricConfig,
264
+ task_config: TaskConfig,
265
+ prompt: str,
266
+ use_full_dataset: bool,
267
+ experiment_config: Optional[Dict],
268
+ n_samples: Optional[int],
269
+ optimization_id: Optional[str] = None,
270
+ ) -> float:
271
+ # Calculate subset size for trials
272
+ if not use_full_dataset:
273
+ total_items = len(dataset.get_items())
274
+ if n_samples is not None:
275
+ if n_samples > total_items:
276
+ logger.warning(
277
+ f"Requested n_samples ({n_samples}) is larger than dataset size ({total_items}). Using full dataset."
278
+ )
279
+ subset_size = None
280
+ else:
281
+ subset_size = n_samples
282
+ logger.debug(f"Using specified n_samples: {subset_size} items")
283
+ else:
284
+ # Use roughly 20% of the dataset, clamped to between 10 and 20 items and capped at the dataset size
285
+ subset_size = min(total_items, min(20, max(10, int(total_items * 0.2))))
286
+ logger.debug(
287
+ f"Using automatic subset size calculation: {subset_size} items (~20% of {total_items} total items, clamped to 10-20)"
288
+ )
289
+ else:
290
+ subset_size = None # Use all items for final checks
291
+ logger.debug("Using full dataset for evaluation")
292
+ experiment_config = experiment_config or {}
293
+ experiment_config = {
294
+ **experiment_config,
295
+ **{
296
+ "optimizer": self.__class__.__name__,
297
+ "metric": metric_config.metric.name,
298
+ "dataset": dataset.name,
299
+ "configuration": {
300
+ "prompt": prompt,
301
+ "n_samples": subset_size,
302
+ "use_full_dataset": use_full_dataset,
303
+ },
304
+ },
305
+ }
306
+
307
+ def llm_task(dataset_item: Dict[str, Any]) -> Dict[str, str]:
308
+ # Convert DatasetItem to dict if needed
309
+ if hasattr(dataset_item, "to_dict"):
310
+ dataset_item = dataset_item.to_dict()
311
+
312
+ # Validate that input and output fields are in the dataset_item
313
+ for input_key in task_config.input_dataset_fields:
314
+ if input_key not in dataset_item:
315
+ logger.error(
316
+ f"Input field '{input_key}' not found in dataset sample: {dataset_item}"
317
+ )
318
+ raise ValueError(
319
+ f"Input field '{input_key}' not found in dataset sample"
320
+ )
321
+ if task_config.output_dataset_field not in dataset_item:
322
+ logger.error(
323
+ f"Output field '{task_config.output_dataset_field}' not found in dataset sample: {dataset_item}"
324
+ )
325
+ raise ValueError(
326
+ f"Output field '{task_config.output_dataset_field}' not found in dataset sample"
327
+ )
328
+
329
+ # --- Step 1: Prepare the prompt for the LLM ---
330
+ prompt_for_llm: str
331
+ field_mapping = {
332
+ field: dataset_item[field]
333
+ for field in task_config.input_dataset_fields
334
+ if field in dataset_item
335
+ }
336
+
337
+ if getattr(task_config, "use_chat_prompt", False):
338
+ # For chat prompts, the candidate prompt `prompt` is expected to be a template for the user message.
339
+ # We assume it contains placeholders like {question} or {text}.
340
+ candidate_template = Template(prompt)
341
+ prompt_for_llm = candidate_template.safe_substitute(field_mapping)
342
+ else:
343
+ # For non-chat prompts, `prompt` (the candidate/initial prompt) is the base instruction.
344
+ # Append the actual data fields to it.
345
+ input_clauses = []
346
+ for field_name in task_config.input_dataset_fields:
347
+ if field_name in dataset_item:
348
+ input_clauses.append(
349
+ f"{field_name.capitalize()}: {dataset_item[field_name]}"
350
+ )
351
+ item_specific_inputs_str = "\n".join(input_clauses)
352
+ prompt_for_llm = f"{prompt}\n\n{item_specific_inputs_str}"
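+ # Illustrative example: if input_dataset_fields were ["question"], the model would receive
+ # "<candidate prompt>\n\nQuestion: <dataset_item['question']>"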
353
+
354
+ logger.debug(f"Evaluating with inputs: {field_mapping}")
355
+ logger.debug(f"Prompt for LLM: {prompt_for_llm}")
356
+
357
+ # --- Step 2: Call the model ---
358
+ try:
359
+ logger.debug(f"Calling LLM with prompt length: {len(prompt_for_llm)}")
360
+ raw_model_output = self._call_model(
361
+ prompt=prompt_for_llm,
362
+ system_prompt=None,
363
+ is_reasoning=False,
364
+ optimization_id=optimization_id,
365
+ )
366
+ logger.debug(f"LLM raw response length: {len(raw_model_output)}")
367
+ logger.debug(f"LLM raw output: {raw_model_output}")
368
+ except Exception as e:
369
+ logger.error(f"Error calling model with prompt: {e}")
370
+ logger.error(f"Failed prompt: {prompt_for_llm}")
371
+ logger.error(f"Prompt length: {len(prompt_for_llm)}")
372
+ raise
373
+
374
+ # --- Step 3: Clean the model's output before metric evaluation ---
375
+ cleaned_model_output = raw_model_output.strip()
376
+ original_cleaned_output = cleaned_model_output # For logging if changed
377
+
378
+ # Dynamically generate prefixes based on the output field name
379
+ output_field = task_config.output_dataset_field # e.g., "answer" or "label"
380
+ dynamic_prefixes = [
381
+ f"{output_field.capitalize()}:",
382
+ f"{output_field.capitalize()} :",
383
+ f"{output_field}:", # Also check lowercase field name
384
+ f"{output_field} :",
385
+ ]
386
+
387
+ # Add common generic prefixes
388
+ generic_prefixes = ["Answer:", "Answer :", "A:"]
389
+
390
+ # Combine and remove duplicates (if any)
391
+ prefixes_to_strip = list(set(dynamic_prefixes + generic_prefixes))
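+ # e.g. for output_field "answer": {"Answer:", "Answer :", "answer:", "answer :", "A:"} (set order is arbitrary)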
392
+ logger.debug(f"Prefixes to strip: {prefixes_to_strip}")
393
+
394
+ for prefix_to_check in prefixes_to_strip:
395
+ # Perform case-insensitive check for robustness
396
+ if cleaned_model_output.lower().startswith(prefix_to_check.lower()):
397
+ # Strip based on the actual length of the found prefix
398
+ cleaned_model_output = cleaned_model_output[
399
+ len(prefix_to_check) :
400
+ ].strip()
401
+ logger.debug(
402
+ f"Stripped prefix '{prefix_to_check}', new output for metric: {cleaned_model_output}"
403
+ )
404
+ break # Stop after stripping the first found prefix
405
+
406
+ if original_cleaned_output != cleaned_model_output:
407
+ logger.debug(
408
+ f"Raw model output: '{original_cleaned_output}' -> Cleaned for metric: '{cleaned_model_output}'"
409
+ )
410
+ result = {
411
+ mappers.EVALUATED_LLM_TASK_OUTPUT: cleaned_model_output,
412
+ }
413
+ return result
414
+
415
+ # Use dataset's get_items with limit for sampling
416
+ logger.info(
417
+ f"Starting evaluation with {subset_size if subset_size else 'all'} samples for metric: {metric_config.metric.name}"
418
+ )
419
+ score = task_evaluator.evaluate(
420
+ dataset=dataset,
421
+ metric_config=metric_config,
422
+ evaluated_task=llm_task,
423
+ num_threads=self.num_threads,
424
+ project_name=self.project_name,
425
+ n_samples=subset_size, # Use subset_size for trials, None for full dataset
426
+ experiment_config=experiment_config,
427
+ optimization_id=optimization_id,
428
+ )
429
+ logger.debug(f"Evaluation score: {score:.4f}")
430
+ return score
431
+
432
+ def optimize_prompt(
433
+ self,
434
+ dataset: Union[str, Dataset],
435
+ metric_config: MetricConfig,
436
+ task_config: TaskConfig,
437
+ experiment_config: Optional[Dict] = None,
438
+ n_samples: Optional[int] = None,
439
+ auto_continue: bool = False,
440
+ **kwargs,
441
+ ) -> OptimizationResult:
442
+ """
443
+ Optimize a prompt using meta-reasoning.
444
+
445
+ Args:
446
+ dataset: The dataset to evaluate against
447
+ metric_config: The metric configuration to use for evaluation
448
+ task_config: The task configuration containing input/output fields
449
+ experiment_config: A dictionary to log with the experiments
450
+ n_samples: The number of dataset items to use for evaluation
451
+ auto_continue: If True, the algorithm may continue if the goal is not met
452
+ **kwargs: Additional arguments for evaluation
453
+
454
+ Returns:
455
+ OptimizationResult: Structured result containing optimization details
456
+ """
457
+ total_items = len(dataset.get_items())
458
+ if n_samples is not None and n_samples > total_items:
459
+ logger.warning(
460
+ f"Requested n_samples ({n_samples}) is larger than dataset size ({total_items}). Using full dataset."
461
+ )
462
+ n_samples = None
463
+
464
+ logger.info(
465
+ f"Starting optimization with n_samples={n_samples}, auto_continue={auto_continue}"
466
+ )
467
+ logger.info(f"Dataset size: {total_items} items")
468
+ logger.info(f"Initial prompt: {task_config.instruction_prompt}")
469
+
470
+ optimization = None
471
+ try:
472
+ optimization = self._opik_client.create_optimization(
473
+ dataset_name=dataset.name, objective_name=metric_config.metric.name
474
+ )
475
+ logger.info(f"Created optimization with ID: {optimization.id}")
476
+ except Exception as e:
477
+ logger.warning(
478
+ f"Opik server does not support optimizations: {e}. Please upgrade opik."
479
+ )
480
+ optimization = None
481
+
482
+ try:
483
+ result = self._optimize_prompt(
484
+ optimization_id=optimization.id if optimization is not None else None,
485
+ dataset=dataset,
486
+ metric_config=metric_config,
487
+ task_config=task_config,
488
+ experiment_config=experiment_config,
489
+ n_samples=n_samples,
490
+ auto_continue=auto_continue,
491
+ **kwargs,
492
+ )
493
+ if optimization:
494
+ self.update_optimization(optimization, status="completed")
495
+ logger.info("Optimization completed successfully")
496
+ return result
497
+ except Exception as e:
498
+ logger.error(f"Optimization failed: {e}")
499
+ if optimization:
500
+ self.update_optimization(optimization, status="cancelled")
501
+ logger.info("Optimization marked as cancelled")
502
+ raise e
503
+
504
+ def _optimize_prompt(
505
+ self,
506
+ optimization_id: Optional[str],
507
+ dataset: Union[str, Dataset],
508
+ metric_config: MetricConfig,
509
+ task_config: TaskConfig,
510
+ experiment_config: Optional[Dict],
511
+ n_samples: Optional[int],
512
+ auto_continue: bool,
513
+ **kwargs,
514
+ ) -> OptimizationResult:
515
+ self.auto_continue = auto_continue
516
+ self.dataset = dataset
517
+ self.task_config = task_config
518
+
519
+ current_prompt = task_config.instruction_prompt
520
+ experiment_config = experiment_config or {}
521
+ experiment_config = {
522
+ **experiment_config,
523
+ **{
524
+ "optimizer": self.__class__.__name__,
525
+ "metric": metric_config.metric.name,
526
+ "dataset": self.dataset.name,
527
+ "configuration": {
528
+ "prompt": current_prompt,
529
+ "max_rounds": self.max_rounds,
530
+ "num_prompts_per_round": self.num_prompts_per_round,
531
+ "improvement_threshold": self.improvement_threshold,
532
+ "initial_trials": self.initial_trials,
533
+ "max_trials": self.max_trials,
534
+ "adaptive_threshold": self.adaptive_threshold,
535
+ },
536
+ },
537
+ }
538
+
539
+ logger.info("Evaluating initial prompt")
540
+ initial_score = self.evaluate_prompt(
541
+ optimization_id=optimization_id,
542
+ dataset=dataset,
543
+ metric_config=metric_config,
544
+ task_config=task_config,
545
+ prompt=current_prompt,
546
+ n_samples=n_samples,
547
+ experiment_config=experiment_config,
548
+ use_full_dataset=n_samples is None,
549
+ )
550
+ best_score = initial_score
551
+ best_prompt = current_prompt
552
+ rounds = []
553
+ stopped_early = False
554
+
555
+ logger.info(f"Initial score: {initial_score:.4f}")
556
+
557
+ # Initialize TQDM with postfix placeholder
558
+ pbar = tqdm(
559
+ total=self.max_rounds,
560
+ desc="Optimizing Prompt",
561
+ unit="round",
562
+ bar_format="{l_bar}{bar:20}{r_bar} | {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]",
563
+ position=0,
564
+ leave=True,
565
+ postfix={
566
+ "best_score": f"{initial_score:.4f}",
567
+ "llm_calls": self.llm_call_counter,
568
+ },
569
+ )
570
+
571
+ for round_num in range(self.max_rounds):
572
+ logger.info(f"\n{'='*50}")
573
+ logger.info(f"Starting Round {round_num + 1}/{self.max_rounds}")
574
+ logger.info(f"Current best score: {best_score:.4f}")
575
+ logger.info(f"Current best prompt: {best_prompt}")
576
+
577
+ previous_best_score = best_score
578
+ try:
579
+ logger.info("Generating candidate prompts")
580
+ candidate_prompts = self._generate_candidate_prompts(
581
+ current_prompt=best_prompt,
582
+ best_score=best_score,
583
+ round_num=round_num,
584
+ previous_rounds=rounds,
585
+ metric_config=metric_config,
586
+ optimization_id=optimization_id,
587
+ )
588
+ logger.info(f"Generated {len(candidate_prompts)} candidate prompts")
589
+ except Exception as e:
590
+ logger.error(f"Error generating candidate prompts: {e}")
591
+ break
592
+
593
+ prompt_scores = []
594
+ for candidate_count, prompt in enumerate(candidate_prompts):
595
+ logger.info(
596
+ f"\nEvaluating candidate {candidate_count + 1}/{len(candidate_prompts)}"
597
+ )
598
+ logger.info(f"Prompt: {prompt}")
599
+
600
+ scores = []
601
+ should_run_max_trials = True
602
+
603
+ # Initial trials
604
+ logger.debug(f"Running initial {self.initial_trials} trials...")
605
+ for trial in range(self.initial_trials):
606
+ try:
607
+ logger.debug(f"Trial {trial + 1}/{self.initial_trials}")
608
+ score = self.evaluate_prompt(
609
+ dataset=dataset,
610
+ metric_config=metric_config,
611
+ task_config=task_config,
612
+ prompt=prompt,
613
+ n_samples=n_samples,
614
+ use_full_dataset=False,
615
+ experiment_config=experiment_config,
616
+ )
617
+ scores.append(score)
618
+ logger.debug(f"Trial {trial+1} score: {score:.4f}")
619
+ except Exception as e:
620
+ logger.error(f"Error in trial {trial + 1}: {e}")
621
+ continue
622
+
623
+ if not scores:
624
+ logger.warning(
625
+ "All initial trials failed for this prompt, skipping"
626
+ )
627
+ continue
628
+
629
+ # Adaptive trials: skip the additional trials when the initial average falls below best_score * adaptive_threshold
630
+ avg_score_initial = sum(scores) / len(scores)
631
+ if (
632
+ self.adaptive_threshold is not None
633
+ and self.max_trials > self.initial_trials
634
+ and avg_score_initial < best_score * self.adaptive_threshold
635
+ ):
636
+ should_run_max_trials = False
637
+ logger.debug("Skipping additional trials...")
638
+
639
+ # Run additional trials
640
+ if should_run_max_trials and self.max_trials > self.initial_trials:
641
+ num_additional_trials = self.max_trials - self.initial_trials
642
+ logger.debug(
643
+ f"Running {num_additional_trials} additional trials..."
644
+ )
645
+ for trial in range(self.initial_trials, self.max_trials):
646
+ try:
647
+ logger.debug(
648
+ f"Additional trial {trial + 1}/{self.max_trials}"
649
+ )
650
+ score = self.evaluate_prompt(
651
+ dataset=dataset,
652
+ metric_config=metric_config,
653
+ task_config=task_config,
654
+ prompt=prompt,
655
+ n_samples=n_samples,
656
+ use_full_dataset=False,
657
+ experiment_config=experiment_config,
658
+ )
659
+ scores.append(score)
660
+ logger.debug(
661
+ f"Additional trial {trial+1} score: {score:.4f}"
662
+ )
663
+ except Exception as e:
664
+ logger.error(f"Error in additional trial {trial + 1}: {e}")
665
+ continue
666
+
667
+ # Calculate final average score
668
+ if scores:
669
+ final_avg_score = sum(scores) / len(scores)
670
+ prompt_scores.append((prompt, final_avg_score, scores))
671
+ logger.info(f"Completed {len(scores)} trials for prompt.")
672
+ logger.info(f"Final average score: {final_avg_score:.4f}")
673
+ logger.debug(
674
+ f"Individual trial scores: {[f'{s:.4f}' for s in scores]}"
675
+ )
676
+ else:
677
+ # This case should be rare now due to the initial check, but good practice
678
+ logger.warning("No successful trials completed for this prompt.")
679
+
680
+ if not prompt_scores:
681
+ logger.warning("No prompts were successfully evaluated in this round")
682
+ break
683
+
684
+ # Sort candidates by average score, best first
685
+ prompt_scores.sort(key=lambda x: x[1], reverse=True)
686
+ best_candidate_this_round, best_cand_score_avg, best_cand_trials = (
687
+ prompt_scores[0]
688
+ )
689
+
690
+ logger.info(
691
+ f"\nBest candidate from this round (avg score {metric_config.metric.name}): {best_cand_score_avg:.4f}"
692
+ )
693
+ logger.info(f"Prompt: {best_candidate_this_round}")
694
+
695
+ # Re-evaluate the best candidate from the round using the full dataset (if n_samples is None)
696
+ # or the specified n_samples subset for a more stable score comparison.
697
+ # This uses use_full_dataset flag appropriately.
698
+ if best_cand_score_avg > best_score:
699
+ logger.info("Running final evaluation on best candidate...")
700
+ final_score_best_cand = self.evaluate_prompt(
701
+ optimization_id=optimization_id,
702
+ dataset=dataset,
703
+ metric_config=metric_config,
704
+ task_config=task_config,
705
+ prompt=best_candidate_this_round,
706
+ experiment_config=experiment_config,
707
+ n_samples=n_samples,
708
+ use_full_dataset=n_samples is None,
709
+ )
710
+ logger.info(
711
+ f"Final evaluation score for best candidate: {final_score_best_cand:.4f}"
712
+ )
713
+
714
+ if final_score_best_cand > best_score:
715
+ logger.info("New best prompt found!")
716
+ best_score = final_score_best_cand
717
+ best_prompt = best_candidate_this_round
718
+ logger.info(f"New Best Prompt: {best_prompt}")
719
+ logger.info(
720
+ f"New Best Score ({metric_config.metric.name}): {best_score:.4f}"
721
+ )
722
+ else:
723
+ logger.info(
724
+ "Best candidate score did not improve upon final evaluation."
725
+ )
726
+ # Decide what prompt to carry to the next round's generation step.
727
+ # Option 1: Carry the best scoring prompt overall (best_prompt)
728
+ # Option 2: Carry the best candidate from this round (best_candidate_this_round) even if it didn't beat the overall best after final eval.
729
+ # Let's stick with Option 1 for now - always generate from the overall best.
730
+ # current_prompt = best_prompt # Implicitly done as best_prompt is updated
731
+
732
+ improvement = self._calculate_improvement(best_score, previous_best_score)
733
+ logger.info(
734
+ f"Improvement in score ({metric_config.metric.name}) this round: {improvement:.2%}"
735
+ )
736
+
737
+ # Create round data
738
+ round_data = self._create_round_data(
739
+ round_num,
740
+ best_prompt,
741
+ best_score,
742
+ best_prompt,
743
+ prompt_scores,
744
+ previous_best_score,
745
+ improvement,
746
+ )
747
+ rounds.append(round_data)
748
+ self._add_to_history(round_data.dict())
749
+
750
+ if (
751
+ improvement < self.improvement_threshold and round_num > 0
752
+ ): # Avoid stopping after first round if threshold is low
753
+ logger.info(
754
+ f"Improvement below threshold ({improvement:.2%} < {self.improvement_threshold:.2%}), stopping early"
755
+ )
756
+ stopped_early = True
757
+ break
758
+
759
+ # Update TQDM postfix
760
+ pbar.set_postfix(
761
+ {
762
+ "best_score": f"{best_score:.4f}",
763
+ "improvement": f"{improvement:.2%}",
764
+ "llm_calls": self.llm_call_counter,
765
+ }
766
+ )
767
+ pbar.update(1)
768
+
769
+ pbar.close()
770
+
771
+ logger.info("\n" + "=" * 80)
772
+ logger.info("OPTIMIZATION COMPLETE")
773
+ logger.info("=" * 80)
774
+ logger.info(f"Initial score: {initial_score:.4f}")
775
+ logger.info(f"Final best score: {best_score:.4f}")
776
+ if initial_score != 0: # Avoid division by zero if initial score was 0
777
+ total_improvement_pct = (best_score - initial_score) / abs(
778
+ initial_score
779
+ ) # Use abs for safety
780
+ logger.info(f"Total improvement: {total_improvement_pct:.2%}")
781
+ elif best_score > 0:
782
+ logger.info("Total improvement: infinite (initial score was 0)")
783
+ else:
784
+ logger.info("Total improvement: 0.00% (scores did not improve from 0)")
785
+ logger.info("\nFINAL OPTIMIZED PROMPT:")
786
+ logger.info("-" * 80)
787
+ logger.info(best_prompt)
788
+ logger.info("-" * 80)
789
+ logger.info("=" * 80)
790
+
791
+ return self._create_result(
792
+ metric_config,
793
+ task_config,
794
+ best_prompt,
795
+ best_score,
796
+ initial_score,
797
+ rounds,
798
+ stopped_early,
799
+ )
800
+
801
+ def _calculate_improvement(
802
+ self, current_score: float, previous_score: float
803
+ ) -> float:
804
+ """Calculate the improvement percentage between scores."""
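+ # Example: previous_score=0.50, current_score=0.55 -> (0.55 - 0.50) / 0.50 = 0.10, i.e. a 10% improvement.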
805
+ return (
806
+ (current_score - previous_score) / previous_score
807
+ if previous_score > 0
808
+ else 0.0
809
+ )
810
+
811
+ def _create_round_data(
812
+ self,
813
+ round_num: int,
814
+ current_best_prompt: str,
815
+ current_best_score: float,
816
+ best_prompt_overall: str,
817
+ evaluated_candidates: List[tuple[str, float, List[float]]],
818
+ previous_best_score: float,
819
+ improvement_this_round: float,
820
+ ) -> OptimizationRound:
821
+ """Create an OptimizationRound object with the current round's data."""
822
+ generated_prompts_log = []
823
+ for prompt, avg_score, trial_scores in evaluated_candidates:
824
+ improvement_vs_prev = self._calculate_improvement(
825
+ avg_score, previous_best_score
826
+ )
827
+ generated_prompts_log.append(
828
+ {
829
+ "prompt": prompt,
830
+ "score": avg_score,
831
+ "trial_scores": trial_scores,
832
+ "improvement": improvement_vs_prev,
833
+ }
834
+ )
835
+
836
+ return OptimizationRound(
837
+ round_number=round_num + 1,
838
+ current_prompt=current_best_prompt,
839
+ current_score=current_best_score,
840
+ generated_prompts=generated_prompts_log,
841
+ best_prompt=best_prompt_overall,
842
+ best_score=current_best_score,
843
+ improvement=improvement_this_round,
844
+ )
845
+
846
+ def _create_result(
847
+ self,
848
+ metric_config: MetricConfig,
849
+ task_config: TaskConfig,
850
+ best_prompt: str,
851
+ best_score: float,
852
+ initial_score: float,
853
+ rounds: List[OptimizationRound],
854
+ stopped_early: bool,
855
+ ) -> OptimizationResult:
856
+ """Create the final OptimizationResult object."""
857
+ details = {
858
+ "prompt_type": "chat" if task_config.use_chat_prompt else "non-chat",
859
+ "initial_prompt": task_config.instruction_prompt,
860
+ "initial_score": initial_score,
861
+ "final_prompt": best_prompt,
862
+ "final_score": best_score,
863
+ "rounds": rounds,
864
+ "total_rounds": len(rounds),
865
+ "stopped_early": stopped_early,
866
+ "metric_config": metric_config.dict(),
867
+ "task_config": task_config.dict(),
868
+ "model": self.model,
869
+ "temperature": self.model_kwargs.get("temperature"),
870
+ }
871
+
872
+ return OptimizationResult(
873
+ prompt=best_prompt,
874
+ score=best_score,
875
+ metric_name=metric_config.metric.name,
876
+ details=details,
877
+ )
878
+
879
+ def _get_task_context(self, metric_config: MetricConfig) -> str:
880
+ """Get task-specific context from the dataset and metric configuration."""
881
+ if self.dataset is None or self.task_config is None:
882
+ return ""
883
+
884
+ input_fields = self.task_config.input_dataset_fields
885
+ output_field = self.task_config.output_dataset_field
886
+
887
+ # Describe Single Metric
888
+ metric_name = metric_config.metric.name
889
+ description = getattr(
890
+ metric_config.metric, "description", "No description available."
891
+ )
892
+ goal = (
893
+ "higher is better"
894
+ if getattr(metric_config.metric, "higher_is_better", True)
895
+ else "lower is better"
896
+ )
897
+ metrics_str = f"- {metric_name}: {description} ({goal})"
898
+
899
+ context = "\nTask Context:\n"
900
+ context += f"Input fields: {', '.join(input_fields)}\n"
901
+ context += f"Output field: {output_field}\n"
902
+ context += f"Evaluation Metric:\n{metrics_str}\n"
903
+
904
+ try:
905
+ # Try get_items() first as it's the preferred method
906
+ items = self.dataset.get_items()
907
+ if items:
908
+ sample = items[0] # Get first sample
909
+ else:
910
+ # Fallback to other methods if get_items() fails or returns empty
911
+ if hasattr(self.dataset, "samples") and self.dataset.samples:
912
+ sample = self.dataset.samples[0] # Get first sample
913
+ elif hasattr(self.dataset, "__iter__"):
914
+ sample = next(iter(self.dataset))
915
+ else:
916
+ logger.warning(
917
+ "Dataset does not have a samples attribute or is not iterable"
918
+ )
919
+ return context
920
+
921
+ if sample is not None:
922
+ context += "\nExample:\n"
923
+ for field in input_fields:
924
+ if field in sample:
925
+ context += f"Input '{field}': {sample[field]}\n"
926
+ if output_field in sample:
927
+ context += f"Output '{output_field}': {sample[output_field]}\n"
928
+ except Exception as e:
929
+ logger.warning(f"Could not get sample from dataset: {e}")
930
+
931
+ return context
932
+
933
+ def _generate_candidate_prompts(
934
+ self,
935
+ current_prompt: str,
936
+ best_score: float,
937
+ round_num: int,
938
+ previous_rounds: List[OptimizationRound],
939
+ metric_config: MetricConfig,
940
+ optimization_id: Optional[str] = None,
941
+ ) -> List[str]:
942
+ """Generate candidate prompts using meta-prompting."""
943
+
944
+ logger.debug(f"\nGenerating candidate prompts for round {round_num + 1}")
945
+ logger.debug(f"Generating from prompt: {current_prompt}")
946
+ logger.debug(f"Current best score: {best_score:.4f}")
947
+
948
+ # Pass single metric_config
949
+ history_context = self._build_history_context(previous_rounds)
950
+ task_context = self._get_task_context(metric_config=metric_config)
951
+
952
+ user_prompt = f"""Current prompt: {current_prompt}
953
+ Current score: {best_score}
954
+ {history_context}
955
+ {task_context}
956
+
957
+ Analyze the example provided, the metric description, and the history of scores.
958
+ Generate {self.num_prompts_per_round} improved versions of this prompt.
959
+ Focus on improving the score for the metric: {metric_config.metric.name}.
960
+ Each version should aim to:
961
+ 1. Be more specific and clear about expectations based on the metric and task.
962
+ 2. Provide necessary context and constraints.
963
+ 3. Guide the model to produce the desired output format suitable for the metric.
964
+ 4. Remove ambiguity and unnecessary elements.
965
+ 5. Maintain conciseness while being complete.
966
+
967
+ Return a valid JSON object as specified."""
968
+
969
+ try:
970
+ # Use _call_model which handles selecting reasoning_model
971
+ content = self._call_model(
972
+ prompt=user_prompt,
973
+ system_prompt=self._REASONING_SYSTEM_PROMPT,
974
+ is_reasoning=True,
975
+ optimization_id=optimization_id,
976
+ )
977
+ logger.debug(f"Raw response from reasoning model: {content}")
978
+
979
+ # --- Robust JSON Parsing and Validation ---
980
+ json_result = None
981
+ try:
982
+ # Try direct JSON parsing
983
+ json_result = json.loads(content)
984
+ except json.JSONDecodeError:
985
+ # If direct fails, try regex extraction
986
+ logger.warning(
987
+ "Direct JSON parsing failed, attempting regex extraction."
988
+ )
989
+ import re
990
+
991
+ json_match = re.search(r"\{.*\}", content, re.DOTALL)
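+ # re.DOTALL lets '.' match newlines, so this greedily captures from the first '{' to the last '}' in the reply.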
992
+ if json_match:
993
+ try:
994
+ json_result = json.loads(json_match.group())
995
+ except json.JSONDecodeError as e:
996
+ logger.error(f"Could not parse JSON extracted via regex: {e}")
997
+ return [current_prompt] # Fallback
998
+ else:
999
+ logger.error("No JSON object found in response via regex.")
1000
+ return [current_prompt] # Fallback
1001
+
1002
+ # Validate the parsed JSON structure
1003
+ if not isinstance(json_result, dict) or "prompts" not in json_result:
1004
+ logger.error(
1005
+ "Parsed JSON is not a dictionary or missing 'prompts' key."
1006
+ )
1007
+ logger.debug(f"Parsed JSON content: {json_result}")
1008
+ return [current_prompt] # Fallback
1009
+
1010
+ if not isinstance(json_result["prompts"], list):
1011
+ logger.error("'prompts' key does not contain a list.")
1012
+ logger.debug(f"Content of 'prompts': {json_result.get('prompts')}")
1013
+ return [current_prompt] # Fallback
1014
+
1015
+ # Extract and log valid prompts
1016
+ valid_prompts = []
1017
+ for item in json_result["prompts"]:
1018
+ if (
1019
+ isinstance(item, dict)
1020
+ and "prompt" in item
1021
+ and isinstance(item["prompt"], str)
1022
+ ):
1023
+ prompt_text = item["prompt"]
1024
+ valid_prompts.append(prompt_text)
1025
+ # Log details
1026
+ focus = item.get("improvement_focus", "N/A")
1027
+ reasoning = item.get("reasoning", "N/A")
1028
+ logger.info(f"Generated prompt: {prompt_text}")
1029
+ logger.info(f" Improvement focus: {focus}")
1030
+ logger.info(f" Reasoning: {reasoning}")
1031
+ else:
1032
+ logger.warning(
1033
+ f"Skipping invalid prompt item structure in JSON response: {item}"
1034
+ )
1035
+
1036
+ if not valid_prompts:
1037
+ logger.warning(
1038
+ "No valid prompts found in the parsed JSON response after validation."
1039
+ )
1040
+ return [current_prompt] # Fallback
1041
+
1042
+ return valid_prompts
1043
+ # --- End Robust Parsing ---
1044
+
1045
+ except Exception as e:
1046
+ # Catch other errors during model call or processing
1047
+ logger.error(f"Unexpected error during candidate prompt generation: {e}")
1048
+ logger.error("Falling back to current prompt.")
1049
+ return [current_prompt]
1050
+
1051
+ def _build_history_context(self, previous_rounds: List[OptimizationRound]) -> str:
1052
+ """Build context from previous optimization rounds."""
1053
+ if not previous_rounds:
1054
+ return ""
1055
+
1056
+ context = "\nPrevious rounds (latest first):\n"
1057
+ for round_data in reversed(previous_rounds[-3:]):
1058
+ context += f"\nRound {round_data.round_number}:\n"
1059
+ context += f"Best score this round: {round_data.best_score:.4f}\n"
1060
+ context += "Generated prompts this round (best first):\n"
1061
+
1062
+ sorted_generated = sorted(
1063
+ round_data.generated_prompts,
1064
+ key=lambda p: p.get("score", -float("inf")),
1065
+ reverse=True,
1066
+ )
1067
+
1068
+ for p in sorted_generated[:3]:
1069
+ prompt_text = p.get("prompt", "N/A")
1070
+ score = p.get("score", float("nan"))
1071
+ context += f"- Prompt: {prompt_text[:150]}...\n"
1072
+ context += f" Avg Score: {score:.4f}\n"
1073
+ return context
1074
+
1075
+ def _get_evaluation_subset(
1076
+ self, dataset: opik.Dataset, min_size: int = 20, max_size: int = 100
1077
+ ) -> List[Dict[str, Any]]:
1078
+ """Get a random subset of the dataset for evaluation.
1079
+
1080
+ Returns:
1081
+ List[Dict[str, Any]]: A list of dataset items to evaluate against
1082
+ """
1083
+ try:
1084
+ # Get all items from the dataset
1085
+ all_items = dataset.get_items()
1086
+ if not all_items:
1087
+ return all_items
1088
+
1089
+ # Calculate subset size
1090
+ total_size = len(all_items)
1091
+ subset_size = min(max(min_size, int(total_size * 0.2)), max_size, total_size)
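+ # Example: a 200-item dataset with the defaults gives min(max(20, 40), 100, 200) = 40 items.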
1092
+
1093
+ # Get random subset of items
1094
+ import random
1095
+
1096
+ return random.sample(all_items, subset_size)
1097
+
1098
+ except Exception as e:
1099
+ logger.warning(f"Could not create evaluation subset: {e}")
1100
+ return []  # all_items may be unbound if get_items() failed, so return an empty subset