dao-ai 0.1.2__py3-none-any.whl → 0.1.20__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. dao_ai/apps/__init__.py +24 -0
  2. dao_ai/apps/handlers.py +105 -0
  3. dao_ai/apps/model_serving.py +29 -0
  4. dao_ai/apps/resources.py +1122 -0
  5. dao_ai/apps/server.py +39 -0
  6. dao_ai/cli.py +546 -37
  7. dao_ai/config.py +1179 -139
  8. dao_ai/evaluation.py +543 -0
  9. dao_ai/genie/__init__.py +55 -7
  10. dao_ai/genie/cache/__init__.py +34 -7
  11. dao_ai/genie/cache/base.py +143 -2
  12. dao_ai/genie/cache/context_aware/__init__.py +31 -0
  13. dao_ai/genie/cache/context_aware/base.py +1151 -0
  14. dao_ai/genie/cache/context_aware/in_memory.py +609 -0
  15. dao_ai/genie/cache/context_aware/persistent.py +802 -0
  16. dao_ai/genie/cache/context_aware/postgres.py +1166 -0
  17. dao_ai/genie/cache/core.py +1 -1
  18. dao_ai/genie/cache/lru.py +257 -75
  19. dao_ai/genie/cache/optimization.py +890 -0
  20. dao_ai/genie/core.py +235 -11
  21. dao_ai/memory/postgres.py +175 -39
  22. dao_ai/middleware/__init__.py +38 -0
  23. dao_ai/middleware/assertions.py +3 -3
  24. dao_ai/middleware/context_editing.py +230 -0
  25. dao_ai/middleware/core.py +4 -4
  26. dao_ai/middleware/guardrails.py +3 -3
  27. dao_ai/middleware/human_in_the_loop.py +3 -2
  28. dao_ai/middleware/message_validation.py +4 -4
  29. dao_ai/middleware/model_call_limit.py +77 -0
  30. dao_ai/middleware/model_retry.py +121 -0
  31. dao_ai/middleware/pii.py +157 -0
  32. dao_ai/middleware/summarization.py +1 -1
  33. dao_ai/middleware/tool_call_limit.py +210 -0
  34. dao_ai/middleware/tool_retry.py +174 -0
  35. dao_ai/middleware/tool_selector.py +129 -0
  36. dao_ai/models.py +327 -370
  37. dao_ai/nodes.py +9 -16
  38. dao_ai/orchestration/core.py +33 -9
  39. dao_ai/orchestration/supervisor.py +29 -13
  40. dao_ai/orchestration/swarm.py +6 -1
  41. dao_ai/{prompts.py → prompts/__init__.py} +12 -61
  42. dao_ai/prompts/instructed_retriever_decomposition.yaml +58 -0
  43. dao_ai/prompts/instruction_reranker.yaml +14 -0
  44. dao_ai/prompts/router.yaml +37 -0
  45. dao_ai/prompts/verifier.yaml +46 -0
  46. dao_ai/providers/base.py +28 -2
  47. dao_ai/providers/databricks.py +363 -33
  48. dao_ai/state.py +1 -0
  49. dao_ai/tools/__init__.py +5 -3
  50. dao_ai/tools/genie.py +103 -26
  51. dao_ai/tools/instructed_retriever.py +366 -0
  52. dao_ai/tools/instruction_reranker.py +202 -0
  53. dao_ai/tools/mcp.py +539 -97
  54. dao_ai/tools/router.py +89 -0
  55. dao_ai/tools/slack.py +13 -2
  56. dao_ai/tools/sql.py +7 -3
  57. dao_ai/tools/unity_catalog.py +32 -10
  58. dao_ai/tools/vector_search.py +493 -160
  59. dao_ai/tools/verifier.py +159 -0
  60. dao_ai/utils.py +182 -2
  61. dao_ai/vector_search.py +46 -1
  62. {dao_ai-0.1.2.dist-info → dao_ai-0.1.20.dist-info}/METADATA +45 -9
  63. dao_ai-0.1.20.dist-info/RECORD +89 -0
  64. dao_ai/agent_as_code.py +0 -22
  65. dao_ai/genie/cache/semantic.py +0 -970
  66. dao_ai-0.1.2.dist-info/RECORD +0 -64
  67. {dao_ai-0.1.2.dist-info → dao_ai-0.1.20.dist-info}/WHEEL +0 -0
  68. {dao_ai-0.1.2.dist-info → dao_ai-0.1.20.dist-info}/entry_points.txt +0 -0
  69. {dao_ai-0.1.2.dist-info → dao_ai-0.1.20.dist-info}/licenses/LICENSE +0 -0
dao_ai/genie/cache/optimization.py
@@ -0,0 +1,890 @@
+"""
+Semantic cache threshold optimization using Optuna Bayesian optimization.
+
+This module provides optimization for Genie semantic cache thresholds using
+Optuna's Tree-structured Parzen Estimator (TPE) algorithm with LLM-as-Judge
+evaluation for semantic match validation.
+
+The optimizer tunes these thresholds:
+- similarity_threshold: Minimum similarity for question matching
+- context_similarity_threshold: Minimum similarity for context matching
+- question_weight: Weight for question similarity in combined score
+
+Usage:
+    from dao_ai.genie.cache.optimization import optimize_semantic_cache_thresholds
+
+    result = optimize_semantic_cache_thresholds(
+        dataset=my_eval_dataset,
+        judge_model="databricks-meta-llama-3-3-70b-instruct",
+        n_trials=50,
+        metric="f1",
+    )
+
+    if result.improved:
+        print(f"Improved by {result.improvement:.1%}")
+        print(f"Best thresholds: {result.optimized_thresholds}")
+"""
+
+import hashlib
+import math
+from dataclasses import dataclass, field
+from datetime import datetime, timezone
+from typing import Any, Callable, Literal, Sequence
+
+import mlflow
+import optuna
+from loguru import logger
+from optuna.samplers import TPESampler
+
+# Optional MLflow integration - requires optuna-integration[mlflow]
+try:
+    from optuna.integration import MLflowCallback
+
+    MLFLOW_CALLBACK_AVAILABLE = True
+except ModuleNotFoundError:
+    MLFLOW_CALLBACK_AVAILABLE = False
+    MLflowCallback = None  # type: ignore
+
+from dao_ai.config import GenieContextAwareCacheParametersModel, LLMModel
+from dao_ai.utils import dao_ai_version
+
+__all__ = [
+    "SemanticCacheEvalEntry",
+    "SemanticCacheEvalDataset",
+    "ThresholdOptimizationResult",
+    "optimize_semantic_cache_thresholds",
+    "generate_eval_dataset_from_cache",
+    "semantic_match_judge",
+]
+
+
+@dataclass
+class SemanticCacheEvalEntry:
+    """Single evaluation entry for threshold optimization.
+
+    Represents a pair of question/context combinations to evaluate
+    whether the cache should return a hit or miss.
+
+    Attributes:
+        question: Current question being asked
+        question_embedding: Pre-computed embedding of the question
+        context: Conversation context string
+        context_embedding: Pre-computed embedding of the context
+        cached_question: Question from the cache entry
+        cached_question_embedding: Embedding of the cached question
+        cached_context: Context from the cache entry
+        cached_context_embedding: Embedding of the cached context
+        expected_match: Whether this pair should be a cache hit (True),
+            miss (False), or use LLM to determine (None)
+    """
+
+    question: str
+    question_embedding: list[float]
+    context: str
+    context_embedding: list[float]
+    cached_question: str
+    cached_question_embedding: list[float]
+    cached_context: str
+    cached_context_embedding: list[float]
+    expected_match: bool | None = None
+
+
+@dataclass
+class SemanticCacheEvalDataset:
+    """Dataset for semantic cache threshold optimization.
+
+    Attributes:
+        name: Name of the dataset for tracking
+        entries: List of evaluation entries
+        description: Optional description of the dataset
+    """
+
+    name: str
+    entries: list[SemanticCacheEvalEntry]
+    description: str = ""
+
+    def __len__(self) -> int:
+        return len(self.entries)
+
+    def __iter__(self):
+        return iter(self.entries)
+
+
+@dataclass
+class ThresholdOptimizationResult:
+    """Result of semantic cache threshold optimization.
+
+    Attributes:
+        optimized_thresholds: Dictionary of optimized threshold values
+        original_thresholds: Dictionary of original threshold values
+        original_score: Score with original thresholds
+        optimized_score: Score with optimized thresholds
+        improvement: Percentage improvement (0.0-1.0)
+        n_trials: Number of optimization trials run
+        best_trial_number: Trial number that produced best result
+        study_name: Name of the Optuna study
+        metadata: Additional optimization metadata
+    """
+
+    optimized_thresholds: dict[str, float]
+    original_thresholds: dict[str, float]
+    original_score: float
+    optimized_score: float
+    improvement: float
+    n_trials: int
+    best_trial_number: int
+    study_name: str
+    metadata: dict[str, Any] = field(default_factory=dict)
+
+    @property
+    def improved(self) -> bool:
+        """Whether the optimization improved the thresholds."""
+        return self.optimized_score > self.original_score
+
+
+# Cache for LLM judge results to avoid redundant calls
+_judge_cache: dict[str, bool] = {}
+
+
+def _compute_cache_key(
+    question1: str, context1: str, question2: str, context2: str
+) -> str:
+    """Compute a cache key for judge results."""
+    content = f"{question1}|{context1}|{question2}|{context2}"
+    return hashlib.sha256(content.encode()).hexdigest()
+
+
+def semantic_match_judge(
+    question1: str,
+    context1: str,
+    question2: str,
+    context2: str,
+    model: LLMModel | str,
+    use_cache: bool = True,
+) -> bool:
+    """
+    Use LLM to determine if two question/context pairs are semantically equivalent.
+
+    This function acts as a judge to determine whether two questions with their
+    respective conversation contexts are asking for the same information and
+    would expect the same SQL query response.
+
+    Args:
+        question1: First question
+        context1: Conversation context for first question
+        question2: Second question
+        context2: Conversation context for second question
+        model: LLM model to use for judging
+        use_cache: Whether to cache results (default True)
+
+    Returns:
+        True if the pairs are semantically equivalent, False otherwise
+    """
+    global _judge_cache
+
+    # Check cache first
+    if use_cache:
+        cache_key = _compute_cache_key(question1, context1, question2, context2)
+        if cache_key in _judge_cache:
+            return _judge_cache[cache_key]
+
+    # Convert model to LLMModel if string
+    llm_model: LLMModel = LLMModel(name=model) if isinstance(model, str) else model
+
+    # Create the chat model
+    chat = llm_model.as_chat_model()
+
+    # Construct the prompt for semantic equivalence judgment
+    prompt = f"""You are an expert at determining semantic equivalence between database queries.
+
+Given two question-context pairs, determine if they are semantically equivalent - meaning they are asking for the same information and would expect the same SQL query result.
+
+Consider:
+1. Are both questions asking for the same data/information?
+2. Do the conversation contexts provide similar filtering or constraints?
+3. Would answering both require the same SQL query?
+
+IMPORTANT: Be strict. Only return "MATCH" if the questions are truly asking for the same thing in the same context. Similar but different questions should return "NO_MATCH".
+
+Question 1: {question1}
+Context 1: {context1 if context1 else "(no context)"}
+
+Question 2: {question2}
+Context 2: {context2 if context2 else "(no context)"}
+
+Respond with ONLY one word: "MATCH" or "NO_MATCH"
+"""
+
+    try:
+        response = chat.invoke(prompt)
+        result_text = response.content.strip().upper()
+        is_match = "MATCH" in result_text and "NO_MATCH" not in result_text
+
+        # Cache the result
+        if use_cache:
+            _judge_cache[cache_key] = is_match
+
+        logger.trace(
+            "LLM judge result",
+            question1=question1[:50],
+            question2=question2[:50],
+            is_match=is_match,
+        )
+
+        return is_match
+
+    except Exception as e:
+        logger.warning("LLM judge failed", error=str(e))
+        # Default to not matching on error (conservative)
+        return False
+
+
+def _compute_l2_similarity(embedding1: list[float], embedding2: list[float]) -> float:
+    """
+    Compute similarity from L2 (Euclidean) distance.
+
+    Uses the same formula as the semantic cache:
+        similarity = 1.0 / (1.0 + L2_distance)
+
+    This gives a value in range [0, 1] where 1 means identical.
+
+    Args:
+        embedding1: First embedding vector
+        embedding2: Second embedding vector
+
+    Returns:
+        Similarity score in range [0, 1]
+    """
+    if len(embedding1) != len(embedding2):
+        raise ValueError(
+            f"Embedding dimensions must match: {len(embedding1)} vs {len(embedding2)}"
+        )
+
+    # Compute L2 distance
+    squared_diff_sum = sum(
+        (a - b) ** 2 for a, b in zip(embedding1, embedding2, strict=True)
+    )
+    l2_distance = math.sqrt(squared_diff_sum)
+
+    # Convert to similarity
+    similarity = 1.0 / (1.0 + l2_distance)
+    return similarity
+
+
+def _evaluate_thresholds(
+    dataset: SemanticCacheEvalDataset,
+    similarity_threshold: float,
+    context_similarity_threshold: float,
+    question_weight: float,
+    judge_model: LLMModel | str | None = None,
+) -> tuple[float, float, float, dict[str, int]]:
+    """
+    Evaluate a set of thresholds against the dataset.
+
+    Args:
+        dataset: Evaluation dataset
+        similarity_threshold: Threshold for question similarity
+        context_similarity_threshold: Threshold for context similarity
+        question_weight: Weight for question in combined score
+        judge_model: Optional LLM model for judging unlabeled entries
+
+    Returns:
+        Tuple of (precision, recall, f1, confusion_matrix_dict)
+    """
+    true_positives = 0
+    false_positives = 0
+    true_negatives = 0
+    false_negatives = 0
+
+    # Note: context_weight = 1.0 - question_weight could be used for weighted scoring
+    # but we currently use independent thresholds for question and context similarity
+
+    for entry in dataset.entries:
+        # Compute similarities
+        question_sim = _compute_l2_similarity(
+            entry.question_embedding, entry.cached_question_embedding
+        )
+        context_sim = _compute_l2_similarity(
+            entry.context_embedding, entry.cached_context_embedding
+        )
+
+        # Apply threshold logic (same as production cache)
+        predicted_match = (
+            question_sim >= similarity_threshold
+            and context_sim >= context_similarity_threshold
+        )
+
+        # Get expected match
+        expected_match = entry.expected_match
+        if expected_match is None:
+            if judge_model is None:
+                # Skip entries without labels if no judge provided
+                continue
+            # Use LLM judge to determine expected match
+            expected_match = semantic_match_judge(
+                entry.question,
+                entry.context,
+                entry.cached_question,
+                entry.cached_context,
+                judge_model,
+            )
+
+        # Update confusion matrix
+        if predicted_match and expected_match:
+            true_positives += 1
+        elif predicted_match and not expected_match:
+            false_positives += 1
+        elif not predicted_match and expected_match:
+            false_negatives += 1
+        else:
+            true_negatives += 1
+
+    # Calculate metrics
+    total = true_positives + false_positives + true_negatives + false_negatives
+
+    precision = (
+        true_positives / (true_positives + false_positives)
+        if (true_positives + false_positives) > 0
+        else 0.0
+    )
+    recall = (
+        true_positives / (true_positives + false_negatives)
+        if (true_positives + false_negatives) > 0
+        else 0.0
+    )
+    f1 = (
+        2 * precision * recall / (precision + recall)
+        if (precision + recall) > 0
+        else 0.0
+    )
+
+    confusion = {
+        "true_positives": true_positives,
+        "false_positives": false_positives,
+        "true_negatives": true_negatives,
+        "false_negatives": false_negatives,
+        "total": total,
+    }
+
+    return precision, recall, f1, confusion
+
+
+def _create_objective(
+    dataset: SemanticCacheEvalDataset,
+    judge_model: LLMModel | str | None,
+    metric: Literal["f1", "precision", "recall", "fbeta"],
+    beta: float = 1.0,
+) -> Callable[[optuna.Trial], float]:
+    """Create the Optuna objective function."""
+
+    def objective(trial: optuna.Trial) -> float:
+        # Sample parameters
+        similarity_threshold = trial.suggest_float(
+            "similarity_threshold", 0.5, 0.99, log=False
+        )
+        context_similarity_threshold = trial.suggest_float(
+            "context_similarity_threshold", 0.5, 0.99, log=False
+        )
+        question_weight = trial.suggest_float("question_weight", 0.1, 0.9, log=False)
+
+        # Evaluate
+        precision, recall, f1, confusion = _evaluate_thresholds(
+            dataset=dataset,
+            similarity_threshold=similarity_threshold,
+            context_similarity_threshold=context_similarity_threshold,
+            question_weight=question_weight,
+            judge_model=judge_model,
+        )
+
+        # Log intermediate results
+        trial.set_user_attr("precision", precision)
+        trial.set_user_attr("recall", recall)
+        trial.set_user_attr("f1", f1)
+        trial.set_user_attr("confusion", confusion)
+
+        # Return selected metric
+        if metric == "f1":
+            return f1
+        elif metric == "precision":
+            return precision
+        elif metric == "recall":
+            return recall
+        elif metric == "fbeta":
+            # F-beta score: (1 + beta^2) * (precision * recall) / (beta^2 * precision + recall)
+            if precision + recall == 0:
+                return 0.0
+            fbeta = (
+                (1 + beta**2) * (precision * recall) / (beta**2 * precision + recall)
+            )
+            return fbeta
+        else:
+            raise ValueError(f"Unknown metric: {metric}")
+
+    return objective
+
+
+def optimize_semantic_cache_thresholds(
+    dataset: SemanticCacheEvalDataset,
+    original_thresholds: dict[str, float]
+    | GenieContextAwareCacheParametersModel
+    | None = None,
+    judge_model: LLMModel | str = "databricks-meta-llama-3-3-70b-instruct",
+    n_trials: int = 50,
+    metric: Literal["f1", "precision", "recall", "fbeta"] = "f1",
+    beta: float = 1.0,
+    register_if_improved: bool = True,
+    study_name: str | None = None,
+    seed: int | None = None,
+    show_progress_bar: bool = True,
+) -> ThresholdOptimizationResult:
+    """
+    Optimize semantic cache thresholds using Bayesian optimization.
+
+    Uses Optuna's Tree-structured Parzen Estimator (TPE) to efficiently
+    search the parameter space and find optimal threshold values.
+
+    Args:
+        dataset: Evaluation dataset with question/context pairs
+        original_thresholds: Original thresholds to compare against.
+            Can be a dict or GenieContextAwareCacheParametersModel.
+            If None, uses default values.
+        judge_model: LLM model for semantic match judging (for unlabeled entries)
+        n_trials: Number of optimization trials to run
+        metric: Optimization metric ("f1", "precision", "recall", "fbeta")
+        beta: Beta value for fbeta metric (higher = favor recall)
+        register_if_improved: Log results to MLflow if improved
+        study_name: Optional name for the Optuna study
+        seed: Random seed for reproducibility
+        show_progress_bar: Whether to show progress bar during optimization
+
+    Returns:
+        ThresholdOptimizationResult with optimized thresholds and metrics
+
+    Example:
+        from dao_ai.genie.cache.optimization import (
+            optimize_semantic_cache_thresholds,
+            SemanticCacheEvalDataset,
+        )
+
+        result = optimize_semantic_cache_thresholds(
+            dataset=my_dataset,
+            judge_model="databricks-meta-llama-3-3-70b-instruct",
+            n_trials=50,
+            metric="f1",
+        )
+
+        if result.improved:
+            print(f"New thresholds: {result.optimized_thresholds}")
+    """
+    logger.info(
+        "Starting semantic cache threshold optimization",
+        dataset_name=dataset.name,
+        dataset_size=len(dataset),
+        n_trials=n_trials,
+        metric=metric,
+    )
+
+    # Parse original thresholds
+    if original_thresholds is None:
+        orig_thresholds = {
+            "similarity_threshold": 0.85,
+            "context_similarity_threshold": 0.80,
+            "question_weight": 0.6,
+        }
+    elif isinstance(original_thresholds, GenieContextAwareCacheParametersModel):
+        orig_thresholds = {
+            "similarity_threshold": original_thresholds.similarity_threshold,
+            "context_similarity_threshold": original_thresholds.context_similarity_threshold,
+            "question_weight": original_thresholds.question_weight or 0.6,
+        }
+    else:
+        orig_thresholds = original_thresholds.copy()
+
+    # Evaluate original thresholds
+    orig_precision, orig_recall, orig_f1, orig_confusion = _evaluate_thresholds(
+        dataset=dataset,
+        similarity_threshold=orig_thresholds["similarity_threshold"],
+        context_similarity_threshold=orig_thresholds["context_similarity_threshold"],
+        question_weight=orig_thresholds["question_weight"],
+        judge_model=judge_model,
+    )
+
+    # Determine original score based on metric
+    if metric == "f1":
+        original_score = orig_f1
+    elif metric == "precision":
+        original_score = orig_precision
+    elif metric == "recall":
+        original_score = orig_recall
+    elif metric == "fbeta":
+        if orig_precision + orig_recall == 0:
+            original_score = 0.0
+        else:
+            original_score = (
+                (1 + beta**2)
+                * (orig_precision * orig_recall)
+                / (beta**2 * orig_precision + orig_recall)
+            )
+    else:
+        raise ValueError(f"Unknown metric: {metric}")
+
+    logger.info(
+        "Evaluated original thresholds",
+        precision=f"{orig_precision:.4f}",
+        recall=f"{orig_recall:.4f}",
+        f1=f"{orig_f1:.4f}",
+        original_score=f"{original_score:.4f}",
+    )
+
+    # Create study name if not provided
+    if study_name is None:
+        timestamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
+        study_name = f"semantic_cache_threshold_optimization_{timestamp}"
+
+    # Create Optuna study
+    sampler = TPESampler(seed=seed)
+    study = optuna.create_study(
+        study_name=study_name,
+        direction="maximize",
+        sampler=sampler,
+    )
+
+    # Add original thresholds as first trial for comparison
+    study.enqueue_trial(orig_thresholds)
+
+    # Create objective function
+    objective = _create_objective(
+        dataset=dataset,
+        judge_model=judge_model,
+        metric=metric,
+        beta=beta,
+    )
+
+    # Set up MLflow callback if available
+    callbacks = []
+    if MLFLOW_CALLBACK_AVAILABLE and MLflowCallback is not None:
+        try:
+            mlflow_callback = MLflowCallback(
+                tracking_uri=mlflow.get_tracking_uri(),
+                metric_name=metric,
+                create_experiment=False,
+            )
+            callbacks.append(mlflow_callback)
+        except Exception as e:
+            logger.debug("MLflow callback not available", error=str(e))
+
+    # Run optimization
+    study.optimize(
+        objective,
+        n_trials=n_trials,
+        show_progress_bar=show_progress_bar,
+        callbacks=callbacks if callbacks else None,
+    )
+
+    # Get best trial
+    best_trial = study.best_trial
+    best_thresholds = {
+        "similarity_threshold": best_trial.params["similarity_threshold"],
+        "context_similarity_threshold": best_trial.params[
+            "context_similarity_threshold"
+        ],
+        "question_weight": best_trial.params["question_weight"],
+    }
+
+    # Get metrics from best trial
+    best_precision = best_trial.user_attrs.get("precision", 0.0)
+    best_recall = best_trial.user_attrs.get("recall", 0.0)
+    best_f1 = best_trial.user_attrs.get("f1", 0.0)
+    best_confusion = best_trial.user_attrs.get("confusion", {})
+    optimized_score = best_trial.value
+
+    # Calculate improvement
+    improvement = (
+        (optimized_score - original_score) / original_score
+        if original_score > 0
+        else 0.0
+    )
+
+    logger.success(
+        "Optimization complete",
+        best_trial_number=best_trial.number,
+        original_score=f"{original_score:.4f}",
+        optimized_score=f"{optimized_score:.4f}",
+        improvement=f"{improvement:.1%}",
+        best_thresholds=best_thresholds,
+    )
+
+    # Log to MLflow if improved
+    if register_if_improved and improvement > 0:
+        try:
+            _log_optimization_to_mlflow(
+                study_name=study_name,
+                dataset_name=dataset.name,
+                dataset_size=len(dataset),
+                original_thresholds=orig_thresholds,
+                optimized_thresholds=best_thresholds,
+                original_score=original_score,
+                optimized_score=optimized_score,
+                improvement=improvement,
+                metric=metric,
+                n_trials=n_trials,
+                best_precision=best_precision,
+                best_recall=best_recall,
+                best_f1=best_f1,
+                best_confusion=best_confusion,
+                judge_model=judge_model,
+            )
+        except Exception as e:
+            logger.warning("Failed to log optimization to MLflow", error=str(e))
+
+    # Build result
+    result = ThresholdOptimizationResult(
+        optimized_thresholds=best_thresholds,
+        original_thresholds=orig_thresholds,
+        original_score=original_score,
+        optimized_score=optimized_score,
+        improvement=improvement,
+        n_trials=n_trials,
+        best_trial_number=best_trial.number,
+        study_name=study_name,
+        metadata={
+            "metric": metric,
+            "beta": beta if metric == "fbeta" else None,
+            "judge_model": str(judge_model),
+            "dataset_name": dataset.name,
+            "dataset_size": len(dataset),
+            "original_precision": orig_precision,
+            "original_recall": orig_recall,
+            "original_f1": orig_f1,
+            "optimized_precision": best_precision,
+            "optimized_recall": best_recall,
+            "optimized_f1": best_f1,
+            "confusion_matrix": best_confusion,
+        },
+    )
+
+    return result
+
+
+def _log_optimization_to_mlflow(
+    study_name: str,
+    dataset_name: str,
+    dataset_size: int,
+    original_thresholds: dict[str, float],
+    optimized_thresholds: dict[str, float],
+    original_score: float,
+    optimized_score: float,
+    improvement: float,
+    metric: str,
+    n_trials: int,
+    best_precision: float,
+    best_recall: float,
+    best_f1: float,
+    best_confusion: dict[str, int],
+    judge_model: LLMModel | str,
+) -> None:
+    """Log optimization results to MLflow."""
+    with mlflow.start_run(run_name=study_name):
+        # Log parameters
+        mlflow.log_params(
+            {
+                "optimizer": "optuna_tpe",
+                "metric": metric,
+                "n_trials": n_trials,
+                "dataset_name": dataset_name,
+                "dataset_size": dataset_size,
+                "judge_model": str(judge_model),
+                "dao_ai_version": dao_ai_version(),
+            }
+        )
+
+        # Log original thresholds
+        for key, value in original_thresholds.items():
+            mlflow.log_param(f"original_{key}", value)
+
+        # Log optimized thresholds
+        for key, value in optimized_thresholds.items():
+            mlflow.log_param(f"optimized_{key}", value)
+
+        # Log metrics
+        mlflow.log_metrics(
+            {
+                "original_score": original_score,
+                "optimized_score": optimized_score,
+                "improvement": improvement,
+                "precision": best_precision,
+                "recall": best_recall,
+                "f1": best_f1,
+                **{f"confusion_{k}": v for k, v in best_confusion.items()},
+            }
+        )
+
+        # Log thresholds as artifact
+        thresholds_artifact = {
+            "study_name": study_name,
+            "original": original_thresholds,
+            "optimized": optimized_thresholds,
+            "improvement": improvement,
+            "metric": metric,
+        }
+        mlflow.log_dict(thresholds_artifact, "optimized_thresholds.json")
+
+        logger.info(
+            "Logged optimization results to MLflow",
+            study_name=study_name,
+            improvement=f"{improvement:.1%}",
+        )
+
+
+def generate_eval_dataset_from_cache(
+    cache_entries: Sequence[dict[str, Any]],
+    embedding_model: LLMModel | str = "databricks-gte-large-en",
+    num_positive_pairs: int = 50,
+    num_negative_pairs: int = 50,
+    paraphrase_model: LLMModel | str | None = None,
+    dataset_name: str = "generated_eval_dataset",
+) -> SemanticCacheEvalDataset:
+    """
+    Generate an evaluation dataset from existing cache entries.
+
+    Creates positive pairs (semantically equivalent questions) using LLM paraphrasing
+    and negative pairs (different questions) from random cache entry pairs.
+
+    Args:
+        cache_entries: List of cache entries with 'question', 'conversation_context',
+            'question_embedding', and 'context_embedding' keys
+        embedding_model: Model for generating embeddings for paraphrased questions
+        num_positive_pairs: Number of positive (matching) pairs to generate
+        num_negative_pairs: Number of negative (non-matching) pairs to generate
+        paraphrase_model: LLM for generating paraphrases (defaults to embedding_model)
+        dataset_name: Name for the generated dataset
+
+    Returns:
+        SemanticCacheEvalDataset with generated entries
+    """
+    import random
+
+    if len(cache_entries) < 2:
+        raise ValueError("Need at least 2 cache entries to generate dataset")
+
+    # Convert embedding model
+    emb_model: LLMModel = (
+        LLMModel(name=embedding_model)
+        if isinstance(embedding_model, str)
+        else embedding_model
+    )
+    embeddings = emb_model.as_embeddings_model()
+
+    # Use paraphrase model or default to a capable LLM
+    para_model: LLMModel = (
+        LLMModel(name=paraphrase_model)
+        if isinstance(paraphrase_model, str)
+        else (
+            paraphrase_model
+            if paraphrase_model
+            else LLMModel(name="databricks-meta-llama-3-3-70b-instruct")
+        )
+    )
+    chat = para_model.as_chat_model()
+
+    entries: list[SemanticCacheEvalEntry] = []
+
+    # Generate positive pairs (paraphrases)
+    logger.info(
+        "Generating positive pairs using paraphrasing", count=num_positive_pairs
+    )
+
+    for i in range(min(num_positive_pairs, len(cache_entries))):
+        entry = cache_entries[i % len(cache_entries)]
+        original_question = entry.get("question", "")
+        original_context = entry.get("conversation_context", "")
+        original_q_emb = entry.get("question_embedding", [])
+        original_c_emb = entry.get("context_embedding", [])
+
+        if not original_question or not original_q_emb:
+            continue
+
+        # Generate paraphrase
+        try:
+            paraphrase_prompt = f"""Rephrase the following question to ask the same thing but using different words.
+Keep the same meaning and intent. Only output the rephrased question, nothing else.
+
+Original question: {original_question}
+
+Rephrased question:"""
+            response = chat.invoke(paraphrase_prompt)
+            paraphrased_question = response.content.strip()
+
+            # Generate embedding for paraphrase
+            para_q_emb = embeddings.embed_query(paraphrased_question)
+            para_c_emb = (
+                embeddings.embed_query(original_context)
+                if original_context
+                else original_c_emb
+            )
+
+            entries.append(
+                SemanticCacheEvalEntry(
+                    question=paraphrased_question,
+                    question_embedding=para_q_emb,
+                    context=original_context,
+                    context_embedding=para_c_emb,
+                    cached_question=original_question,
+                    cached_question_embedding=original_q_emb,
+                    cached_context=original_context,
+                    cached_context_embedding=original_c_emb,
+                    expected_match=True,
+                )
+            )
+        except Exception as e:
+            logger.warning("Failed to generate paraphrase", error=str(e))
+
+    # Generate negative pairs (random different questions)
+    logger.info(
+        "Generating negative pairs from different cache entries",
+        count=num_negative_pairs,
+    )
+
+    for _ in range(num_negative_pairs):
+        # Pick two different random entries
+        if len(cache_entries) < 2:
+            break
+        idx1, idx2 = random.sample(range(len(cache_entries)), 2)
+        entry1 = cache_entries[idx1]
+        entry2 = cache_entries[idx2]
+
+        # Use entry1 as the "question" and entry2 as the "cached" entry
+        entries.append(
+            SemanticCacheEvalEntry(
+                question=entry1.get("question", ""),
+                question_embedding=entry1.get("question_embedding", []),
+                context=entry1.get("conversation_context", ""),
+                context_embedding=entry1.get("context_embedding", []),
+                cached_question=entry2.get("question", ""),
+                cached_question_embedding=entry2.get("question_embedding", []),
+                cached_context=entry2.get("conversation_context", ""),
+                cached_context_embedding=entry2.get("context_embedding", []),
+                expected_match=False,
+            )
+        )
+
+    logger.info(
+        "Generated evaluation dataset",
+        name=dataset_name,
+        total_entries=len(entries),
+        positive_pairs=sum(1 for e in entries if e.expected_match is True),
+        negative_pairs=sum(1 for e in entries if e.expected_match is False),
+    )
+
+    return SemanticCacheEvalDataset(
+        name=dataset_name,
+        entries=entries,
+        description=f"Generated from {len(cache_entries)} cache entries",
+    )
+
+
+def clear_judge_cache() -> None:
+    """Clear the LLM judge result cache."""
+    global _judge_cache
+    _judge_cache.clear()
+    logger.debug("Cleared judge cache")
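
A minimal end-to-end sketch of how the new module is wired together, based only on the signatures shown in the hunk above: export cache entries, generate a labeled dataset, then search for better thresholds. The two sample entries and their toy three-dimensional embeddings are hypothetical placeholders; the endpoint names are simply the defaults from the source and may need to be changed for a given workspace.

from dao_ai.genie.cache.optimization import (
    generate_eval_dataset_from_cache,
    optimize_semantic_cache_thresholds,
)

# Hypothetical rows exported from a context-aware Genie cache backend; real entries
# would carry full-width embedding vectors rather than these toy 3-dim examples.
cache_entries = [
    {
        "question": "What were total sales last month?",
        "conversation_context": "Region: EMEA",
        "question_embedding": [0.12, 0.08, 0.33],
        "context_embedding": [0.05, 0.41, 0.27],
    },
    {
        "question": "How many orders were returned in Q3?",
        "conversation_context": "Region: EMEA",
        "question_embedding": [0.61, 0.02, 0.19],
        "context_embedding": [0.07, 0.39, 0.30],
    },
]

# Positive pairs come from LLM paraphrases of existing questions,
# negative pairs from unrelated cache entries.
dataset = generate_eval_dataset_from_cache(
    cache_entries=cache_entries,
    embedding_model="databricks-gte-large-en",
    num_positive_pairs=10,
    num_negative_pairs=10,
)

# TPE search over similarity_threshold, context_similarity_threshold and
# question_weight; "f1" trades off false cache hits (precision) against
# missed reuse (recall).
result = optimize_semantic_cache_thresholds(
    dataset=dataset,
    judge_model="databricks-meta-llama-3-3-70b-instruct",
    n_trials=50,
    metric="f1",
    seed=42,
)

if result.improved:
    print(f"Improved f1 by {result.improvement:.1%}")
    print(f"New thresholds: {result.optimized_thresholds}")

The three optimized values correspond to the similarity_threshold, context_similarity_threshold, and question_weight fields that optimize_semantic_cache_thresholds also reads from GenieContextAwareCacheParametersModel when that model is passed as original_thresholds.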