dao-ai 0.1.19__py3-none-any.whl → 0.1.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,930 @@
+ """
+ Context-aware semantic cache threshold optimization using Optuna Bayesian optimization.
+
+ This module provides optimization for context-aware Genie cache thresholds using
+ Optuna's Tree-structured Parzen Estimator (TPE) algorithm with LLM-as-Judge
+ evaluation for semantic match validation.
+
+ The optimizer tunes these thresholds:
+ - similarity_threshold: Minimum similarity for question matching
+ - context_similarity_threshold: Minimum similarity for context matching
+ - question_weight: Weight for question similarity in combined score
+
+ Usage:
+     from dao_ai.genie.cache.context_aware.optimization import (
+         optimize_context_aware_cache_thresholds,
+         generate_eval_dataset_from_cache,
+     )
+
+     # Get entries from your cache
+     entries = cache_service.get_entries(include_embeddings=True, limit=100)
+
+     # Generate evaluation dataset
+     eval_dataset = generate_eval_dataset_from_cache(
+         cache_entries=entries,
+         dataset_name="my_cache_eval",
+     )
+
+     # Optimize thresholds
+     result = optimize_context_aware_cache_thresholds(
+         dataset=eval_dataset,
+         judge_model="databricks-meta-llama-3-3-70b-instruct",
+         n_trials=50,
+         metric="f1",
+     )
+
+     if result.improved:
+         print(f"Improved by {result.improvement:.1%}")
+         print(f"Best thresholds: {result.optimized_thresholds}")
+ """
+
+ import hashlib
+ import math
+ from dataclasses import dataclass, field
+ from datetime import datetime, timezone
+ from typing import TYPE_CHECKING, Any, Callable, Iterator, Literal, Sequence
+
+ import mlflow
+ from loguru import logger
+
+ from dao_ai.config import GenieContextAwareCacheParametersModel, LLMModel
+ from dao_ai.utils import dao_ai_version
+
+ # Type-only import for optuna.Trial to support type hints without runtime dependency
+ if TYPE_CHECKING:
+     import optuna
+
+ __all__ = [
+     "ContextAwareCacheEvalEntry",
+     "ContextAwareCacheEvalDataset",
+     "ThresholdOptimizationResult",
+     "optimize_context_aware_cache_thresholds",
+     "generate_eval_dataset_from_cache",
+     "semantic_match_judge",
+     "clear_judge_cache",
+ ]
+
+
+ @dataclass
+ class ContextAwareCacheEvalEntry:
+     """Single evaluation entry for threshold optimization.
+
+     Represents a pair of question/context combinations used to evaluate
+     whether the cache should return a hit or a miss.
+
+     Attributes:
+         question: Current question being asked
+         question_embedding: Pre-computed embedding of the question
+         context: Conversation context string
+         context_embedding: Pre-computed embedding of the context
+         cached_question: Question from the cache entry
+         cached_question_embedding: Embedding of the cached question
+         cached_context: Context from the cache entry
+         cached_context_embedding: Embedding of the cached context
+         expected_match: Whether this pair should be a cache hit (True),
+             a miss (False), or left to the LLM judge to decide (None)
+     """
+
+     question: str
+     question_embedding: list[float]
+     context: str
+     context_embedding: list[float]
+     cached_question: str
+     cached_question_embedding: list[float]
+     cached_context: str
+     cached_context_embedding: list[float]
+     expected_match: bool | None = None
+
+
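+ # A minimal ContextAwareCacheEvalEntry, as a rough sketch with illustrative
+ # question text and short placeholder vectors standing in for real embeddings:
+ #
+ #     entry = ContextAwareCacheEvalEntry(
+ #         question="How many stores are in Texas?",
+ #         question_embedding=[0.1, 0.3, 0.2],
+ #         context="Discussing regional store counts",
+ #         context_embedding=[0.2, 0.1, 0.4],
+ #         cached_question="What is the number of stores in TX?",
+ #         cached_question_embedding=[0.1, 0.3, 0.2],
+ #         cached_context="Discussing regional store counts",
+ #         cached_context_embedding=[0.2, 0.1, 0.4],
+ #         expected_match=True,
+ #     )
+
+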
+ @dataclass
+ class ContextAwareCacheEvalDataset:
+     """Dataset for semantic cache threshold optimization.
+
+     Attributes:
+         name: Name of the dataset for tracking
+         entries: List of evaluation entries
+         description: Optional description of the dataset
+     """
+
+     name: str
+     entries: list[ContextAwareCacheEvalEntry]
+     description: str = ""
+
+     def __len__(self) -> int:
+         return len(self.entries)
+
+     def __iter__(self) -> Iterator[ContextAwareCacheEvalEntry]:
+         return iter(self.entries)
+
+
+ @dataclass
+ class ThresholdOptimizationResult:
+     """Result of semantic cache threshold optimization.
+
+     Attributes:
+         optimized_thresholds: Dictionary of optimized threshold values
+         original_thresholds: Dictionary of original threshold values
+         original_score: Score with original thresholds
+         optimized_score: Score with optimized thresholds
+         improvement: Relative improvement over the original score
+             (e.g. 0.10 means 10% better)
+         n_trials: Number of optimization trials run
+         best_trial_number: Trial number that produced the best result
+         study_name: Name of the Optuna study
+         metadata: Additional optimization metadata
+     """
+
+     optimized_thresholds: dict[str, float]
+     original_thresholds: dict[str, float]
+     original_score: float
+     optimized_score: float
+     improvement: float
+     n_trials: int
+     best_trial_number: int
+     study_name: str
+     metadata: dict[str, Any] = field(default_factory=dict)
+
+     @property
+     def improved(self) -> bool:
+         """Whether the optimization improved the thresholds."""
+         return self.optimized_score > self.original_score
+
+
+ # Cache for LLM judge results to avoid redundant calls
+ _judge_cache: dict[str, bool] = {}
+
+
+ def _compute_cache_key(
+     question1: str, context1: str, question2: str, context2: str
+ ) -> str:
+     """Compute a cache key for judge results."""
+     content = f"{question1}|{context1}|{question2}|{context2}"
+     return hashlib.sha256(content.encode()).hexdigest()
+
+
+ def semantic_match_judge(
+     question1: str,
+     context1: str,
+     question2: str,
+     context2: str,
+     model: LLMModel | str,
+     use_cache: bool = True,
+ ) -> bool:
+     """
+     Use LLM to determine if two question/context pairs are semantically equivalent.
+
+     This function acts as a judge to determine whether two questions with their
+     respective conversation contexts are asking for the same information and
+     would expect the same SQL query response.
+
+     Args:
+         question1: First question
+         context1: Conversation context for first question
+         question2: Second question
+         context2: Conversation context for second question
+         model: LLM model to use for judging
+         use_cache: Whether to cache results (default True)
+
+     Returns:
+         True if the pairs are semantically equivalent, False otherwise
+     """
+     global _judge_cache
+
+     # Check cache first
+     if use_cache:
+         cache_key = _compute_cache_key(question1, context1, question2, context2)
+         if cache_key in _judge_cache:
+             return _judge_cache[cache_key]
+
+     # Convert model to LLMModel if string
+     llm_model: LLMModel = LLMModel(name=model) if isinstance(model, str) else model
+
+     # Create the chat model
+     chat = llm_model.as_chat_model()
+
+     # Construct the prompt for semantic equivalence judgment
+     prompt = f"""You are an expert at determining semantic equivalence between database queries.
+
+ Given two question-context pairs, determine if they are semantically equivalent - meaning they are asking for the same information and would expect the same SQL query result.
+
+ Consider:
+ 1. Are both questions asking for the same data/information?
+ 2. Do the conversation contexts provide similar filtering or constraints?
+ 3. Would answering both require the same SQL query?
+
+ IMPORTANT: Be strict. Only return "MATCH" if the questions are truly asking for the same thing in the same context. Similar but different questions should return "NO_MATCH".
+
+ Question 1: {question1}
+ Context 1: {context1 if context1 else "(no context)"}
+
+ Question 2: {question2}
+ Context 2: {context2 if context2 else "(no context)"}
+
+ Respond with ONLY one word: "MATCH" or "NO_MATCH"
+ """
+
+     try:
+         response = chat.invoke(prompt)
+         result_text = response.content.strip().upper()
+         is_match = "MATCH" in result_text and "NO_MATCH" not in result_text
+
+         # Cache the result
+         if use_cache:
+             _judge_cache[cache_key] = is_match
+
+         logger.trace(
+             "LLM judge result",
+             question1=question1[:50],
+             question2=question2[:50],
+             is_match=is_match,
+         )
+
+         return is_match
+
+     except Exception as e:
+         logger.warning("LLM judge failed", error=str(e))
+         # Default to not matching on error (conservative)
+         return False
+
+
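+ # Rough usage sketch for the judge; the endpoint name is only an example, and
+ # verdicts for a given question/context pair are memoized in _judge_cache:
+ #
+ #     is_same = semantic_match_judge(
+ #         question1="Total sales last month?",
+ #         context1="Discussing the sales table",
+ #         question2="What were sales for the previous month?",
+ #         context2="Discussing the sales table",
+ #         model="databricks-meta-llama-3-3-70b-instruct",
+ #     )
+
+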
+ def _compute_l2_similarity(embedding1: list[float], embedding2: list[float]) -> float:
+     """
+     Compute similarity from L2 (Euclidean) distance.
+
+     Uses the same formula as the semantic cache:
+         similarity = 1.0 / (1.0 + L2_distance)
+
+     This gives a value in range [0, 1] where 1 means identical.
+
+     Args:
+         embedding1: First embedding vector
+         embedding2: Second embedding vector
+
+     Returns:
+         Similarity score in range [0, 1]
+     """
+     if len(embedding1) != len(embedding2):
+         raise ValueError(
+             f"Embedding dimensions must match: {len(embedding1)} vs {len(embedding2)}"
+         )
+
+     # Compute L2 distance
+     squared_diff_sum = sum(
+         (a - b) ** 2 for a, b in zip(embedding1, embedding2, strict=True)
+     )
+     l2_distance = math.sqrt(squared_diff_sum)
+
+     # Convert to similarity
+     similarity = 1.0 / (1.0 + l2_distance)
+     return similarity
+
+
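+ # Worked example of the similarity mapping: identical vectors have L2 distance
+ # 0.0, so similarity = 1.0 / (1.0 + 0.0) = 1.0; vectors at L2 distance 1.0 map
+ # to similarity 0.5:
+ #
+ #     _compute_l2_similarity([1.0, 0.0], [1.0, 0.0])  # -> 1.0
+ #     _compute_l2_similarity([1.0, 0.0], [1.0, 1.0])  # -> 0.5
+
+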
+ def _evaluate_thresholds(
+     dataset: ContextAwareCacheEvalDataset,
+     similarity_threshold: float,
+     context_similarity_threshold: float,
+     question_weight: float,
+     judge_model: LLMModel | str | None = None,
+ ) -> tuple[float, float, float, dict[str, int]]:
+     """
+     Evaluate a set of thresholds against the dataset.
+
+     Args:
+         dataset: Evaluation dataset
+         similarity_threshold: Threshold for question similarity
+         context_similarity_threshold: Threshold for context similarity
+         question_weight: Weight for question in a combined score (currently
+             unused; question and context thresholds are applied independently)
+         judge_model: Optional LLM model for judging unlabeled entries
+
+     Returns:
+         Tuple of (precision, recall, f1, confusion_matrix_dict)
+     """
+     true_positives = 0
+     false_positives = 0
+     true_negatives = 0
+     false_negatives = 0
+
+     # Note: context_weight = 1.0 - question_weight could be used for weighted scoring
+     # but we currently use independent thresholds for question and context similarity
+
+     for entry in dataset.entries:
+         # Compute similarities
+         question_sim = _compute_l2_similarity(
+             entry.question_embedding, entry.cached_question_embedding
+         )
+         context_sim = _compute_l2_similarity(
+             entry.context_embedding, entry.cached_context_embedding
+         )
+
+         # Apply threshold logic (same as production cache)
+         predicted_match = (
+             question_sim >= similarity_threshold
+             and context_sim >= context_similarity_threshold
+         )
+
+         # Get expected match
+         expected_match = entry.expected_match
+         if expected_match is None:
+             if judge_model is None:
+                 # Skip entries without labels if no judge provided
+                 continue
+             # Use LLM judge to determine expected match
+             expected_match = semantic_match_judge(
+                 entry.question,
+                 entry.context,
+                 entry.cached_question,
+                 entry.cached_context,
+                 judge_model,
+             )
+
+         # Update confusion matrix
+         if predicted_match and expected_match:
+             true_positives += 1
+         elif predicted_match and not expected_match:
+             false_positives += 1
+         elif not predicted_match and expected_match:
+             false_negatives += 1
+         else:
+             true_negatives += 1
+
+     # Calculate metrics
+     total = true_positives + false_positives + true_negatives + false_negatives
+
+     precision = (
+         true_positives / (true_positives + false_positives)
+         if (true_positives + false_positives) > 0
+         else 0.0
+     )
+     recall = (
+         true_positives / (true_positives + false_negatives)
+         if (true_positives + false_negatives) > 0
+         else 0.0
+     )
+     f1 = (
+         2 * precision * recall / (precision + recall)
+         if (precision + recall) > 0
+         else 0.0
+     )
+
+     confusion = {
+         "true_positives": true_positives,
+         "false_positives": false_positives,
+         "true_negatives": true_negatives,
+         "false_negatives": false_negatives,
+         "total": total,
+     }
+
+     return precision, recall, f1, confusion
+
+
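+ # Metric arithmetic used above, on a small hypothetical confusion matrix with
+ # true_positives=8, false_positives=2, false_negatives=2:
+ #
+ #     precision = 8 / (8 + 2)                    # 0.8
+ #     recall    = 8 / (8 + 2)                    # 0.8
+ #     f1        = 2 * 0.8 * 0.8 / (0.8 + 0.8)    # 0.8
+
+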
+ def _create_objective(
+     dataset: ContextAwareCacheEvalDataset,
+     judge_model: LLMModel | str | None,
+     metric: Literal["f1", "precision", "recall", "fbeta"],
+     beta: float = 1.0,
+ ) -> Callable[["optuna.Trial"], float]:
+     """Create the Optuna objective function."""
+
+     def objective(trial: "optuna.Trial") -> float:
+         # Sample parameters
+         similarity_threshold = trial.suggest_float(
+             "similarity_threshold", 0.5, 0.99, log=False
+         )
+         context_similarity_threshold = trial.suggest_float(
+             "context_similarity_threshold", 0.5, 0.99, log=False
+         )
+         question_weight = trial.suggest_float("question_weight", 0.1, 0.9, log=False)
+
+         # Evaluate
+         precision, recall, f1, confusion = _evaluate_thresholds(
+             dataset=dataset,
+             similarity_threshold=similarity_threshold,
+             context_similarity_threshold=context_similarity_threshold,
+             question_weight=question_weight,
+             judge_model=judge_model,
+         )
+
+         # Log intermediate results
+         trial.set_user_attr("precision", precision)
+         trial.set_user_attr("recall", recall)
+         trial.set_user_attr("f1", f1)
+         trial.set_user_attr("confusion", confusion)
+
+         # Return selected metric
+         if metric == "f1":
+             return f1
+         elif metric == "precision":
+             return precision
+         elif metric == "recall":
+             return recall
+         elif metric == "fbeta":
+             # F-beta score: (1 + beta^2) * (precision * recall) / (beta^2 * precision + recall)
+             if precision + recall == 0:
+                 return 0.0
+             fbeta = (
+                 (1 + beta**2) * (precision * recall) / (beta**2 * precision + recall)
+             )
+             return fbeta
+         else:
+             raise ValueError(f"Unknown metric: {metric}")
+
+     return objective
+
+
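+ # The returned objective is consumed by Optuna roughly like this (the public
+ # entry point below does the same, adding a TPESampler and enqueuing the
+ # original thresholds as a baseline trial):
+ #
+ #     study = optuna.create_study(direction="maximize")
+ #     study.optimize(_create_objective(dataset, judge_model, "f1"), n_trials=50)
+ #     study.best_trial.params  # best thresholds found
+
+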
+ def optimize_context_aware_cache_thresholds(
+     dataset: ContextAwareCacheEvalDataset,
+     original_thresholds: dict[str, float]
+     | GenieContextAwareCacheParametersModel
+     | None = None,
+     judge_model: LLMModel | str = "databricks-meta-llama-3-3-70b-instruct",
+     n_trials: int = 50,
+     metric: Literal["f1", "precision", "recall", "fbeta"] = "f1",
+     beta: float = 1.0,
+     register_if_improved: bool = True,
+     study_name: str | None = None,
+     seed: int | None = None,
+     show_progress_bar: bool = True,
+ ) -> ThresholdOptimizationResult:
+     """
+     Optimize semantic cache thresholds using Bayesian optimization.
+
+     Uses Optuna's Tree-structured Parzen Estimator (TPE) to efficiently
+     search the parameter space and find optimal threshold values.
+
+     Args:
+         dataset: Evaluation dataset with question/context pairs
+         original_thresholds: Original thresholds to compare against.
+             Can be a dict or GenieContextAwareCacheParametersModel.
+             If None, uses default values.
+         judge_model: LLM model for semantic match judging (for unlabeled entries)
+         n_trials: Number of optimization trials to run
+         metric: Optimization metric ("f1", "precision", "recall", "fbeta")
+         beta: Beta value for fbeta metric (higher = favor recall)
+         register_if_improved: Log results to MLflow if improved
+         study_name: Optional name for the Optuna study
+         seed: Random seed for reproducibility
+         show_progress_bar: Whether to show progress bar during optimization
+
+     Returns:
+         ThresholdOptimizationResult with optimized thresholds and metrics
+
+     Example:
+         from dao_ai.genie.cache.context_aware.optimization import (
+             optimize_context_aware_cache_thresholds,
+             ContextAwareCacheEvalDataset,
+         )
+
+         result = optimize_context_aware_cache_thresholds(
+             dataset=my_dataset,
+             judge_model="databricks-meta-llama-3-3-70b-instruct",
+             n_trials=50,
+             metric="f1",
+         )
+
+         if result.improved:
+             print(f"New thresholds: {result.optimized_thresholds}")
+     """
+     # Lazy import optuna - only loaded when optimization is actually called
+     # This allows the cache module to be imported without optuna installed
+     try:
+         import optuna
+         from optuna.samplers import TPESampler
+     except ImportError as e:
+         raise ImportError(
+             "optuna is required for cache threshold optimization. "
+             "Install it with: pip install optuna"
+         ) from e
+
+     # Optional MLflow integration - requires optuna-integration[mlflow]
+     try:
+         from optuna.integration import MLflowCallback
+
+         mlflow_callback_available = True
+     except ModuleNotFoundError:
+         mlflow_callback_available = False
+         MLflowCallback = None  # type: ignore
+
+     logger.info(
+         "Starting semantic cache threshold optimization",
+         dataset_name=dataset.name,
+         dataset_size=len(dataset),
+         n_trials=n_trials,
+         metric=metric,
+     )
+
+     # Parse original thresholds
+     if original_thresholds is None:
+         orig_thresholds = {
+             "similarity_threshold": 0.85,
+             "context_similarity_threshold": 0.80,
+             "question_weight": 0.6,
+         }
+     elif isinstance(original_thresholds, GenieContextAwareCacheParametersModel):
+         orig_thresholds = {
+             "similarity_threshold": original_thresholds.similarity_threshold,
+             "context_similarity_threshold": original_thresholds.context_similarity_threshold,
+             "question_weight": original_thresholds.question_weight or 0.6,
+         }
+     else:
+         orig_thresholds = original_thresholds.copy()
+
+     # Evaluate original thresholds
+     orig_precision, orig_recall, orig_f1, orig_confusion = _evaluate_thresholds(
+         dataset=dataset,
+         similarity_threshold=orig_thresholds["similarity_threshold"],
+         context_similarity_threshold=orig_thresholds["context_similarity_threshold"],
+         question_weight=orig_thresholds["question_weight"],
+         judge_model=judge_model,
+     )
+
+     # Determine original score based on metric
+     if metric == "f1":
+         original_score = orig_f1
+     elif metric == "precision":
+         original_score = orig_precision
+     elif metric == "recall":
+         original_score = orig_recall
+     elif metric == "fbeta":
+         if orig_precision + orig_recall == 0:
+             original_score = 0.0
+         else:
+             original_score = (
+                 (1 + beta**2)
+                 * (orig_precision * orig_recall)
+                 / (beta**2 * orig_precision + orig_recall)
+             )
+     else:
+         raise ValueError(f"Unknown metric: {metric}")
+
+     logger.info(
+         "Evaluated original thresholds",
+         precision=f"{orig_precision:.4f}",
+         recall=f"{orig_recall:.4f}",
+         f1=f"{orig_f1:.4f}",
+         original_score=f"{original_score:.4f}",
+     )
+
+     # Create study name if not provided
+     if study_name is None:
+         timestamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
+         study_name = f"context_aware_cache_threshold_optimization_{timestamp}"
+
+     # Create Optuna study
+     sampler = TPESampler(seed=seed)
+     study = optuna.create_study(
+         study_name=study_name,
+         direction="maximize",
+         sampler=sampler,
+     )
+
+     # Add original thresholds as first trial for comparison
+     study.enqueue_trial(orig_thresholds)
+
+     # Create objective function
+     objective = _create_objective(
+         dataset=dataset,
+         judge_model=judge_model,
+         metric=metric,
+         beta=beta,
+     )
+
+     # Set up MLflow callback if available
+     callbacks = []
+     if mlflow_callback_available and MLflowCallback is not None:
+         try:
+             mlflow_callback = MLflowCallback(
+                 tracking_uri=mlflow.get_tracking_uri(),
+                 metric_name=metric,
+                 create_experiment=False,
+             )
+             callbacks.append(mlflow_callback)
+         except Exception as e:
+             logger.debug("MLflow callback not available", error=str(e))
+
+     # Run optimization
+     study.optimize(
+         objective,
+         n_trials=n_trials,
+         show_progress_bar=show_progress_bar,
+         callbacks=callbacks if callbacks else None,
+     )
+
+     # Get best trial
+     best_trial = study.best_trial
+     best_thresholds = {
+         "similarity_threshold": best_trial.params["similarity_threshold"],
+         "context_similarity_threshold": best_trial.params[
+             "context_similarity_threshold"
+         ],
+         "question_weight": best_trial.params["question_weight"],
+     }
+
+     # Get metrics from best trial
+     best_precision = best_trial.user_attrs.get("precision", 0.0)
+     best_recall = best_trial.user_attrs.get("recall", 0.0)
+     best_f1 = best_trial.user_attrs.get("f1", 0.0)
+     best_confusion = best_trial.user_attrs.get("confusion", {})
+     optimized_score = best_trial.value
+
+     # Calculate improvement
+     improvement = (
+         (optimized_score - original_score) / original_score
+         if original_score > 0
+         else 0.0
+     )
+
+     logger.success(
+         "Optimization complete",
+         best_trial_number=best_trial.number,
+         original_score=f"{original_score:.4f}",
+         optimized_score=f"{optimized_score:.4f}",
+         improvement=f"{improvement:.1%}",
+         best_thresholds=best_thresholds,
+     )
+
+     # Log to MLflow if improved
+     if register_if_improved and improvement > 0:
+         try:
+             _log_optimization_to_mlflow(
+                 study_name=study_name,
+                 dataset_name=dataset.name,
+                 dataset_size=len(dataset),
+                 original_thresholds=orig_thresholds,
+                 optimized_thresholds=best_thresholds,
+                 original_score=original_score,
+                 optimized_score=optimized_score,
+                 improvement=improvement,
+                 metric=metric,
+                 n_trials=n_trials,
+                 best_precision=best_precision,
+                 best_recall=best_recall,
+                 best_f1=best_f1,
+                 best_confusion=best_confusion,
+                 judge_model=judge_model,
+             )
+         except Exception as e:
+             logger.warning("Failed to log optimization to MLflow", error=str(e))
+
+     # Build result
+     result = ThresholdOptimizationResult(
+         optimized_thresholds=best_thresholds,
+         original_thresholds=orig_thresholds,
+         original_score=original_score,
+         optimized_score=optimized_score,
+         improvement=improvement,
+         n_trials=n_trials,
+         best_trial_number=best_trial.number,
+         study_name=study_name,
+         metadata={
+             "metric": metric,
+             "beta": beta if metric == "fbeta" else None,
+             "judge_model": str(judge_model),
+             "dataset_name": dataset.name,
+             "dataset_size": len(dataset),
+             "original_precision": orig_precision,
+             "original_recall": orig_recall,
+             "original_f1": orig_f1,
+             "optimized_precision": best_precision,
+             "optimized_recall": best_recall,
+             "optimized_f1": best_f1,
+             "confusion_matrix": best_confusion,
+         },
+     )
+
+     return result
+
+
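+ # Applying the optimized values is left to the caller; a hedged sketch, assuming
+ # GenieContextAwareCacheParametersModel accepts these fields as keyword arguments:
+ #
+ #     params = GenieContextAwareCacheParametersModel(
+ #         similarity_threshold=result.optimized_thresholds["similarity_threshold"],
+ #         context_similarity_threshold=result.optimized_thresholds[
+ #             "context_similarity_threshold"
+ #         ],
+ #         question_weight=result.optimized_thresholds["question_weight"],
+ #     )
+
+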
+ def _log_optimization_to_mlflow(
+     study_name: str,
+     dataset_name: str,
+     dataset_size: int,
+     original_thresholds: dict[str, float],
+     optimized_thresholds: dict[str, float],
+     original_score: float,
+     optimized_score: float,
+     improvement: float,
+     metric: str,
+     n_trials: int,
+     best_precision: float,
+     best_recall: float,
+     best_f1: float,
+     best_confusion: dict[str, int],
+     judge_model: LLMModel | str,
+ ) -> None:
+     """Log optimization results to MLflow."""
+     with mlflow.start_run(run_name=study_name):
+         # Log parameters
+         mlflow.log_params(
+             {
+                 "optimizer": "optuna_tpe",
+                 "metric": metric,
+                 "n_trials": n_trials,
+                 "dataset_name": dataset_name,
+                 "dataset_size": dataset_size,
+                 "judge_model": str(judge_model),
+                 "dao_ai_version": dao_ai_version(),
+             }
+         )
+
+         # Log original thresholds
+         for key, value in original_thresholds.items():
+             mlflow.log_param(f"original_{key}", value)
+
+         # Log optimized thresholds
+         for key, value in optimized_thresholds.items():
+             mlflow.log_param(f"optimized_{key}", value)
+
+         # Log metrics
+         mlflow.log_metrics(
+             {
+                 "original_score": original_score,
+                 "optimized_score": optimized_score,
+                 "improvement": improvement,
+                 "precision": best_precision,
+                 "recall": best_recall,
+                 "f1": best_f1,
+                 **{f"confusion_{k}": v for k, v in best_confusion.items()},
+             }
+         )
+
+         # Log thresholds as artifact
+         thresholds_artifact = {
+             "study_name": study_name,
+             "original": original_thresholds,
+             "optimized": optimized_thresholds,
+             "improvement": improvement,
+             "metric": metric,
+         }
+         mlflow.log_dict(thresholds_artifact, "optimized_thresholds.json")
+
+         logger.info(
+             "Logged optimization results to MLflow",
+             study_name=study_name,
+             improvement=f"{improvement:.1%}",
+         )
+
+
+ def generate_eval_dataset_from_cache(
+     cache_entries: Sequence[dict[str, Any]],
+     embedding_model: LLMModel | str = "databricks-gte-large-en",
+     num_positive_pairs: int = 50,
+     num_negative_pairs: int = 50,
+     paraphrase_model: LLMModel | str | None = None,
+     dataset_name: str = "generated_eval_dataset",
+ ) -> ContextAwareCacheEvalDataset:
+     """
+     Generate an evaluation dataset from existing cache entries.
+
+     Creates positive pairs (semantically equivalent questions) using LLM paraphrasing
+     and negative pairs (different questions) from random cache entry pairs.
+
+     Args:
+         cache_entries: List of cache entries with 'question', 'conversation_context',
+             'question_embedding', and 'context_embedding' keys. Use cache.get_entries()
+             with include_embeddings=True to retrieve these.
+         embedding_model: Model for generating embeddings for paraphrased
+             questions and contexts
+         num_positive_pairs: Number of positive (matching) pairs to generate
+         num_negative_pairs: Number of negative (non-matching) pairs to generate
+         paraphrase_model: LLM for generating paraphrases (defaults to
+             databricks-meta-llama-3-3-70b-instruct)
+         dataset_name: Name for the generated dataset
+
+     Returns:
+         ContextAwareCacheEvalDataset with generated entries
+
+     Example:
+         # Get entries from cache with embeddings
+         entries = cache_service.get_entries(include_embeddings=True, limit=100)
+
+         # Generate evaluation dataset
+         eval_dataset = generate_eval_dataset_from_cache(
+             cache_entries=entries,
+             num_positive_pairs=50,
+             num_negative_pairs=50,
+             dataset_name="my_cache_eval",
+         )
+     """
+     import random
+
+     if len(cache_entries) < 2:
+         raise ValueError("Need at least 2 cache entries to generate dataset")
+
+     # Convert embedding model
+     emb_model: LLMModel = (
+         LLMModel(name=embedding_model)
+         if isinstance(embedding_model, str)
+         else embedding_model
+     )
+     embeddings = emb_model.as_embeddings_model()
+
+     # Use paraphrase model or default to a capable LLM
+     para_model: LLMModel = (
+         LLMModel(name=paraphrase_model)
+         if isinstance(paraphrase_model, str)
+         else (
+             paraphrase_model
+             if paraphrase_model
+             else LLMModel(name="databricks-meta-llama-3-3-70b-instruct")
+         )
+     )
+     chat = para_model.as_chat_model()
+
+     entries: list[ContextAwareCacheEvalEntry] = []
+
+     # Generate positive pairs (paraphrases)
+     logger.info(
+         "Generating positive pairs using paraphrasing", count=num_positive_pairs
+     )
+
+     for i in range(min(num_positive_pairs, len(cache_entries))):
+         entry = cache_entries[i % len(cache_entries)]
+         original_question = entry.get("question", "")
+         original_context = entry.get("conversation_context", "")
+         original_q_emb = entry.get("question_embedding", [])
+         original_c_emb = entry.get("context_embedding", [])
+
+         if not original_question or not original_q_emb:
+             continue
+
+         # Generate paraphrase
+         try:
+             paraphrase_prompt = f"""Rephrase the following question to ask the same thing but using different words.
+ Keep the same meaning and intent. Only output the rephrased question, nothing else.
+
+ Original question: {original_question}
+
+ Rephrased question:"""
+             response = chat.invoke(paraphrase_prompt)
+             paraphrased_question = response.content.strip()
+
+             # Generate embedding for paraphrase
+             para_q_emb = embeddings.embed_query(paraphrased_question)
+             para_c_emb = (
+                 embeddings.embed_query(original_context)
+                 if original_context
+                 else original_c_emb
+             )
+
+             entries.append(
+                 ContextAwareCacheEvalEntry(
+                     question=paraphrased_question,
+                     question_embedding=para_q_emb,
+                     context=original_context,
+                     context_embedding=para_c_emb,
+                     cached_question=original_question,
+                     cached_question_embedding=original_q_emb,
+                     cached_context=original_context,
+                     cached_context_embedding=original_c_emb,
+                     expected_match=True,
+                 )
+             )
+         except Exception as e:
+             logger.warning("Failed to generate paraphrase", error=str(e))
+
+     # Generate negative pairs (random different questions)
+     logger.info(
+         "Generating negative pairs from different cache entries",
+         count=num_negative_pairs,
+     )
+
+     for _ in range(num_negative_pairs):
+         # Pick two different random entries
+         if len(cache_entries) < 2:
+             break
+         idx1, idx2 = random.sample(range(len(cache_entries)), 2)
+         entry1 = cache_entries[idx1]
+         entry2 = cache_entries[idx2]
+
+         # Use entry1 as the "question" and entry2 as the "cached" entry
+         entries.append(
+             ContextAwareCacheEvalEntry(
+                 question=entry1.get("question", ""),
+                 question_embedding=entry1.get("question_embedding", []),
+                 context=entry1.get("conversation_context", ""),
+                 context_embedding=entry1.get("context_embedding", []),
+                 cached_question=entry2.get("question", ""),
+                 cached_question_embedding=entry2.get("question_embedding", []),
+                 cached_context=entry2.get("conversation_context", ""),
+                 cached_context_embedding=entry2.get("context_embedding", []),
+                 expected_match=False,
+             )
+         )
+
+     logger.info(
+         "Generated evaluation dataset",
+         name=dataset_name,
+         total_entries=len(entries),
+         positive_pairs=sum(1 for e in entries if e.expected_match is True),
+         negative_pairs=sum(1 for e in entries if e.expected_match is False),
+     )
+
+     return ContextAwareCacheEvalDataset(
+         name=dataset_name,
+         entries=entries,
+         description=f"Generated from {len(cache_entries)} cache entries",
+     )
+
+
+ def clear_judge_cache() -> None:
+     """Clear the LLM judge result cache."""
+     global _judge_cache
+     _judge_cache.clear()
+     logger.debug("Cleared judge cache")
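+
+
+ # Judge verdicts are memoized in-process, so a re-run of the optimizer against a
+ # refreshed evaluation dataset can call clear_judge_cache() first to force fresh
+ # LLM judgments, e.g.:
+ #
+ #     clear_judge_cache()
+ #     result = optimize_context_aware_cache_thresholds(dataset=eval_dataset)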