dao-ai 0.1.19__py3-none-any.whl → 0.1.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dao_ai/cli.py +37 -7
- dao_ai/config.py +265 -10
- dao_ai/genie/__init__.py +55 -7
- dao_ai/genie/cache/__init__.py +36 -9
- dao_ai/genie/cache/base.py +143 -2
- dao_ai/genie/cache/context_aware/__init__.py +52 -0
- dao_ai/genie/cache/context_aware/base.py +1204 -0
- dao_ai/genie/cache/{in_memory_semantic.py → context_aware/in_memory.py} +233 -383
- dao_ai/genie/cache/context_aware/optimization.py +930 -0
- dao_ai/genie/cache/context_aware/persistent.py +802 -0
- dao_ai/genie/cache/context_aware/postgres.py +1343 -0
- dao_ai/genie/cache/lru.py +248 -70
- dao_ai/genie/core.py +235 -11
- dao_ai/middleware/__init__.py +8 -1
- dao_ai/middleware/tool_call_observability.py +227 -0
- dao_ai/nodes.py +4 -4
- dao_ai/tools/__init__.py +2 -2
- dao_ai/tools/genie.py +10 -10
- dao_ai/utils.py +7 -3
- {dao_ai-0.1.19.dist-info → dao_ai-0.1.21.dist-info}/METADATA +1 -1
- {dao_ai-0.1.19.dist-info → dao_ai-0.1.21.dist-info}/RECORD +24 -19
- dao_ai/genie/cache/semantic.py +0 -1004
- {dao_ai-0.1.19.dist-info → dao_ai-0.1.21.dist-info}/WHEEL +0 -0
- {dao_ai-0.1.19.dist-info → dao_ai-0.1.21.dist-info}/entry_points.txt +0 -0
- {dao_ai-0.1.19.dist-info → dao_ai-0.1.21.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,930 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Context-aware semantic cache threshold optimization using Optuna Bayesian optimization.
|
|
3
|
+
|
|
4
|
+
This module provides optimization for context-aware Genie cache thresholds using
|
|
5
|
+
Optuna's Tree-structured Parzen Estimator (TPE) algorithm with LLM-as-Judge
|
|
6
|
+
evaluation for semantic match validation.
|
|
7
|
+
|
|
8
|
+
The optimizer tunes these thresholds:
|
|
9
|
+
- similarity_threshold: Minimum similarity for question matching
|
|
10
|
+
- context_similarity_threshold: Minimum similarity for context matching
|
|
11
|
+
- question_weight: Weight for question similarity in combined score
|
|
12
|
+
|
|
13
|
+
Usage:
|
|
14
|
+
from dao_ai.genie.cache.context_aware.optimization import (
|
|
15
|
+
optimize_context_aware_cache_thresholds,
|
|
16
|
+
generate_eval_dataset_from_cache,
|
|
17
|
+
)
|
|
18
|
+
|
|
19
|
+
# Get entries from your cache
|
|
20
|
+
entries = cache_service.get_entries(include_embeddings=True, limit=100)
|
|
21
|
+
|
|
22
|
+
# Generate evaluation dataset
|
|
23
|
+
eval_dataset = generate_eval_dataset_from_cache(
|
|
24
|
+
cache_entries=entries,
|
|
25
|
+
dataset_name="my_cache_eval",
|
|
26
|
+
)
|
|
27
|
+
|
|
28
|
+
# Optimize thresholds
|
|
29
|
+
result = optimize_context_aware_cache_thresholds(
|
|
30
|
+
dataset=eval_dataset,
|
|
31
|
+
judge_model="databricks-meta-llama-3-3-70b-instruct",
|
|
32
|
+
n_trials=50,
|
|
33
|
+
metric="f1",
|
|
34
|
+
)
|
|
35
|
+
|
|
36
|
+
if result.improved:
|
|
37
|
+
print(f"Improved by {result.improvement:.1%}")
|
|
38
|
+
print(f"Best thresholds: {result.optimized_thresholds}")
|
|
39
|
+
"""
|
|
40
|
+
|
|
41
|
+
import hashlib
|
|
42
|
+
import math
|
|
43
|
+
from dataclasses import dataclass, field
|
|
44
|
+
from datetime import datetime, timezone
|
|
45
|
+
from typing import TYPE_CHECKING, Any, Callable, Iterator, Literal, Sequence
|
|
46
|
+
|
|
47
|
+
import mlflow
|
|
48
|
+
from loguru import logger
|
|
49
|
+
|
|
50
|
+
from dao_ai.config import GenieContextAwareCacheParametersModel, LLMModel
|
|
51
|
+
from dao_ai.utils import dao_ai_version
|
|
52
|
+
|
|
53
|
+
# Type-only import for optuna.Trial to support type hints without runtime dependency
|
|
54
|
+
if TYPE_CHECKING:
|
|
55
|
+
import optuna
|
|
56
|
+
|
|
57
|
+
# Public API of this module; every other name is an implementation detail.
__all__ = [
    # Data containers
    "ContextAwareCacheEvalEntry",
    "ContextAwareCacheEvalDataset",
    "ThresholdOptimizationResult",
    # Entry points
    "optimize_context_aware_cache_thresholds",
    "generate_eval_dataset_from_cache",
    # LLM judge helpers
    "semantic_match_judge",
    "clear_judge_cache",
]
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
@dataclass
|
|
69
|
+
class ContextAwareCacheEvalEntry:
|
|
70
|
+
"""Single evaluation entry for threshold optimization.
|
|
71
|
+
|
|
72
|
+
Represents a pair of question/context combinations to evaluate
|
|
73
|
+
whether the cache should return a hit or miss.
|
|
74
|
+
|
|
75
|
+
Attributes:
|
|
76
|
+
question: Current question being asked
|
|
77
|
+
question_embedding: Pre-computed embedding of the question
|
|
78
|
+
context: Conversation context string
|
|
79
|
+
context_embedding: Pre-computed embedding of the context
|
|
80
|
+
cached_question: Question from the cache entry
|
|
81
|
+
cached_question_embedding: Embedding of the cached question
|
|
82
|
+
cached_context: Context from the cache entry
|
|
83
|
+
cached_context_embedding: Embedding of the cached context
|
|
84
|
+
expected_match: Whether this pair should be a cache hit (True),
|
|
85
|
+
miss (False), or use LLM to determine (None)
|
|
86
|
+
"""
|
|
87
|
+
|
|
88
|
+
question: str
|
|
89
|
+
question_embedding: list[float]
|
|
90
|
+
context: str
|
|
91
|
+
context_embedding: list[float]
|
|
92
|
+
cached_question: str
|
|
93
|
+
cached_question_embedding: list[float]
|
|
94
|
+
cached_context: str
|
|
95
|
+
cached_context_embedding: list[float]
|
|
96
|
+
expected_match: bool | None = None
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
@dataclass
class ContextAwareCacheEvalDataset:
    """A named collection of evaluation entries used for threshold tuning.

    ``len()`` and iteration both delegate to ``entries``, so the dataset can
    be used wherever a simple sequence of entries is expected.

    Attributes:
        name: Identifier for the dataset (used in logs and MLflow tracking).
        entries: The evaluation entries to score.
        description: Free-form description; empty by default.
    """

    name: str
    entries: list[ContextAwareCacheEvalEntry]
    description: str = field(default="")

    def __len__(self) -> int:
        """Return how many entries the dataset holds."""
        entry_count = len(self.entries)
        return entry_count

    def __iter__(self) -> Iterator[ContextAwareCacheEvalEntry]:
        """Return an iterator over the entries in their stored order."""
        entries_iterator = iter(self.entries)
        return entries_iterator
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
@dataclass
class ThresholdOptimizationResult:
    """Outcome of a cache-threshold optimization run.

    Attributes:
        optimized_thresholds: Threshold values selected by the optimizer.
        original_thresholds: Baseline threshold values the run compared against.
        original_score: Objective metric achieved by ``original_thresholds``.
        optimized_score: Objective metric achieved by ``optimized_thresholds``.
        improvement: Gain of ``optimized_score`` over ``original_score``
            expressed as a fraction (0.25 == 25%); it is not capped at 1.0.
        n_trials: Number of optimization trials requested.
        best_trial_number: Index of the trial that produced the best score.
        study_name: Name of the Optuna study.
        metadata: Extra details about the run (metric, per-metric scores, ...).
    """

    # Threshold dictionaries.
    optimized_thresholds: dict[str, float]
    original_thresholds: dict[str, float]

    # Scores.
    original_score: float
    optimized_score: float
    improvement: float

    # Study bookkeeping.
    n_trials: int
    best_trial_number: int
    study_name: str
    metadata: dict[str, Any] = field(default_factory=dict)

    @property
    def improved(self) -> bool:
        """True only when the optimized score strictly beats the original."""
        return self.original_score < self.optimized_score
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
# Module-level memo of LLM judge verdicts, keyed by a hash of the two
# question/context pairs, so repeated comparisons skip the LLM call.
# No size bound is applied to it.
_judge_cache: dict[str, bool] = dict()
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
def _compute_cache_key(
    question1: str, context1: str, question2: str, context2: str
) -> str:
    """Compute a stable SHA-256 cache key for an LLM judge comparison.

    The four fields are serialized as a JSON array before hashing so field
    boundaries are unambiguous. Joining them with a plain delimiter such as
    ``"|"`` would let ``("a|b", "c", ...)`` and ``("a", "b|c", ...)`` produce
    the same key, returning a cached verdict for the wrong pair.

    Args:
        question1: First question.
        context1: Conversation context for the first question.
        question2: Second question.
        context2: Conversation context for the second question.

    Returns:
        64-character hexadecimal SHA-256 digest.
    """
    import json

    # json.dumps escapes embedded quotes/separators and (with the default
    # ensure_ascii=True) yields pure ASCII, so encoding cannot fail.
    payload = json.dumps([question1, context1, question2, context2])
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
def _judge_response_text(content: Any) -> str:
    """Flatten a chat response's ``content`` into plain text.

    ``content`` is normally a string, but chat-model integrations may return a
    list of content parts (strings, or dicts carrying a ``"text"`` key); those
    parts are joined with spaces. Any other value is stringified.
    """
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        parts: list[str] = []
        for part in content:
            if isinstance(part, str):
                parts.append(part)
            elif isinstance(part, dict):
                parts.append(str(part.get("text", "")))
        return " ".join(parts)
    return str(content)


def _parse_judge_verdict(text: str) -> bool:
    """Interpret the judge's reply as MATCH (True) or NO_MATCH (False).

    The first verdict phrase in the reply wins. Negative spellings such as
    ``NO_MATCH``, ``NO MATCH``, ``NO-MATCH`` and ``NOT A MATCH`` count as
    no-match, and ``MATCH`` only counts as a standalone word (so ``MISMATCH``
    is not a match). A reply with no recognizable verdict is treated as
    no-match, the conservative default.
    """
    import re

    verdict = re.search(
        r"\b(NO[\s_-]*MATCH|NOT\s+A\s+MATCH|MATCH)\b", text.upper()
    )
    return verdict is not None and verdict.group(1) == "MATCH"


def semantic_match_judge(
    question1: str,
    context1: str,
    question2: str,
    context2: str,
    model: LLMModel | str,
    use_cache: bool = True,
) -> bool:
    """
    Use an LLM to decide whether two question/context pairs are semantically equivalent.

    The judge is asked whether both questions, in their respective
    conversation contexts, request the same information and would be answered
    by the same SQL query. It is instructed to reply ``MATCH`` or
    ``NO_MATCH``; the reply is parsed on word boundaries so that variants such
    as "NO MATCH" or "NOT A MATCH" are not mistaken for a match.

    Args:
        question1: First question
        context1: Conversation context for first question
        question2: Second question
        context2: Conversation context for second question
        model: LLM model to use for judging (a model name string is wrapped
            in ``LLMModel``)
        use_cache: When True, successful verdicts are stored in the
            module-level ``_judge_cache`` and reused on later calls.

    Returns:
        True if the judge answered MATCH, False otherwise. Also returns False
        (without caching, so a later call can retry) if invoking the model or
        reading its response raises. Errors while constructing the chat model
        itself are not caught.
    """
    cache_key = (
        _compute_cache_key(question1, context1, question2, context2)
        if use_cache
        else None
    )
    if cache_key is not None and cache_key in _judge_cache:
        return _judge_cache[cache_key]

    # Convert model to LLMModel if string
    llm_model: LLMModel = LLMModel(name=model) if isinstance(model, str) else model
    chat = llm_model.as_chat_model()

    # Construct the prompt for semantic equivalence judgment
    prompt = f"""You are an expert at determining semantic equivalence between database queries.

Given two question-context pairs, determine if they are semantically equivalent - meaning they are asking for the same information and would expect the same SQL query result.

Consider:
1. Are both questions asking for the same data/information?
2. Do the conversation contexts provide similar filtering or constraints?
3. Would answering both require the same SQL query?

IMPORTANT: Be strict. Only return "MATCH" if the questions are truly asking for the same thing in the same context. Similar but different questions should return "NO_MATCH".

Question 1: {question1}
Context 1: {context1 if context1 else "(no context)"}

Question 2: {question2}
Context 2: {context2 if context2 else "(no context)"}

Respond with ONLY one word: "MATCH" or "NO_MATCH"
"""

    try:
        response = chat.invoke(prompt)
        is_match = _parse_judge_verdict(_judge_response_text(response.content))
    except Exception as e:
        logger.warning("LLM judge failed", error=str(e))
        # Default to not matching on error (conservative); deliberately not
        # cached so a transient failure does not become a permanent verdict.
        return False

    if cache_key is not None:
        _judge_cache[cache_key] = is_match

    logger.trace(
        "LLM judge result",
        question1=question1[:50],
        question2=question2[:50],
        is_match=is_match,
    )

    return is_match
|
|
247
|
+
|
|
248
|
+
|
|
249
|
+
def _compute_l2_similarity(embedding1: list[float], embedding2: list[float]) -> float:
    """
    Compute similarity from L2 (Euclidean) distance.

    Uses the same formula as the semantic cache:
        similarity = 1.0 / (1.0 + L2_distance)

    The result lies in (0, 1], where 1.0 means the vectors are identical
    (two empty vectors are treated as identical).

    Args:
        embedding1: First embedding vector
        embedding2: Second embedding vector

    Returns:
        Similarity score in range (0, 1]

    Raises:
        ValueError: If the embeddings have different dimensions.
    """
    if len(embedding1) != len(embedding2):
        raise ValueError(
            f"Embedding dimensions must match: {len(embedding1)} vs {len(embedding2)}"
        )

    # math.dist computes sqrt(sum((a - b) ** 2)) in C; this function runs for
    # every dataset entry on every optimization trial, so the speedup matters.
    l2_distance = math.dist(embedding1, embedding2)

    return 1.0 / (1.0 + l2_distance)
|
|
279
|
+
|
|
280
|
+
|
|
281
|
+
def _evaluate_thresholds(
    dataset: ContextAwareCacheEvalDataset,
    similarity_threshold: float,
    context_similarity_threshold: float,
    question_weight: float,
    judge_model: LLMModel | str | None = None,
) -> tuple[float, float, float, dict[str, int]]:
    """
    Score one threshold configuration against every entry in ``dataset``.

    An entry is predicted to be a cache hit when its question similarity is
    at least ``similarity_threshold`` AND its context similarity is at least
    ``context_similarity_threshold``. Each prediction is compared with the
    entry's ``expected_match`` label. Unlabelled entries are labelled via
    ``semantic_match_judge`` when ``judge_model`` is given, otherwise they are
    skipped and do not count toward any metric.

    Args:
        dataset: Evaluation dataset.
        similarity_threshold: Minimum question similarity for a predicted hit.
        context_similarity_threshold: Minimum context similarity for a
            predicted hit.
        question_weight: Accepted but currently unused. The two thresholds
            are applied independently instead of through a weighted combined
            score (where the context weight would be ``1.0 - question_weight``).
        judge_model: Optional LLM used to label entries whose
            ``expected_match`` is None.

    Returns:
        ``(precision, recall, f1, confusion)``, where ``confusion`` has the
        keys ``true_positives``, ``false_positives``, ``true_negatives``,
        ``false_negatives`` and ``total`` (number of scored entries).
    """
    # Tally of (predicted_hit, actual_hit) outcomes.
    outcomes = {
        (True, True): 0,
        (True, False): 0,
        (False, True): 0,
        (False, False): 0,
    }

    for item in dataset.entries:
        # Similarities are computed before the label check, so a dimension
        # mismatch raises even for entries that would otherwise be skipped.
        q_similarity = _compute_l2_similarity(
            item.question_embedding, item.cached_question_embedding
        )
        c_similarity = _compute_l2_similarity(
            item.context_embedding, item.cached_context_embedding
        )
        predicted_hit = bool(
            q_similarity >= similarity_threshold
            and c_similarity >= context_similarity_threshold
        )

        actual_hit = item.expected_match
        if actual_hit is None:
            if judge_model is None:
                continue  # unlabelled and no judge: leave out of the metrics
            actual_hit = semantic_match_judge(
                item.question,
                item.context,
                item.cached_question,
                item.cached_context,
                judge_model,
            )

        outcomes[(predicted_hit, bool(actual_hit))] += 1

    tp = outcomes[(True, True)]
    fp = outcomes[(True, False)]
    fn = outcomes[(False, True)]
    tn = outcomes[(False, False)]

    predicted_positives = tp + fp
    actual_positives = tp + fn
    precision = tp / predicted_positives if predicted_positives > 0 else 0.0
    recall = tp / actual_positives if actual_positives > 0 else 0.0
    pr_sum = precision + recall
    f1 = 2 * precision * recall / pr_sum if pr_sum > 0 else 0.0

    return (
        precision,
        recall,
        f1,
        {
            "true_positives": tp,
            "false_positives": fp,
            "true_negatives": tn,
            "false_negatives": fn,
            "total": tp + fp + tn + fn,
        },
    )
|
|
377
|
+
|
|
378
|
+
|
|
379
|
+
def _create_objective(
    dataset: ContextAwareCacheEvalDataset,
    judge_model: LLMModel | str | None,
    metric: Literal["f1", "precision", "recall", "fbeta"],
    beta: float = 1.0,
) -> Callable[["optuna.Trial"], float]:
    """Create the Optuna objective function.

    The returned objective samples ``similarity_threshold`` and
    ``context_similarity_threshold`` from [0.5, 0.99] and ``question_weight``
    from [0.1, 0.9], scores them with ``_evaluate_thresholds``, records
    precision / recall / F1 / confusion counts as trial user attributes and
    returns the selected metric (to be maximized).

    Args:
        dataset: Evaluation dataset.
        judge_model: Optional LLM used to label unlabelled entries.
        metric: Which score the objective returns.
        beta: Beta for the ``"fbeta"`` metric (higher favors recall).

    Returns:
        A callable taking an ``optuna.Trial`` and returning a float score.

    Raises:
        ValueError: If ``metric`` is not supported. Raised when the objective
            is built rather than on every trial.
    """
    if metric not in ("f1", "precision", "recall", "fbeta"):
        raise ValueError(f"Unknown metric: {metric}")

    def objective(trial: "optuna.Trial") -> float:
        # Sample parameters. question_weight is sampled so it is present in
        # the trial params, but _evaluate_thresholds does not currently use it.
        similarity_threshold = trial.suggest_float(
            "similarity_threshold", 0.5, 0.99, log=False
        )
        context_similarity_threshold = trial.suggest_float(
            "context_similarity_threshold", 0.5, 0.99, log=False
        )
        question_weight = trial.suggest_float("question_weight", 0.1, 0.9, log=False)

        precision, recall, f1, confusion = _evaluate_thresholds(
            dataset=dataset,
            similarity_threshold=similarity_threshold,
            context_similarity_threshold=context_similarity_threshold,
            question_weight=question_weight,
            judge_model=judge_model,
        )

        # Log intermediate results
        trial.set_user_attr("precision", precision)
        trial.set_user_attr("recall", recall)
        trial.set_user_attr("f1", f1)
        trial.set_user_attr("confusion", confusion)

        if metric == "f1":
            return f1
        if metric == "precision":
            return precision
        if metric == "recall":
            return recall

        # metric == "fbeta" (validated above):
        #   (1 + beta^2) * (P * R) / (beta^2 * P + R)
        # Guard the actual denominator: with beta == 0 and recall == 0 it is
        # zero even when precision > 0, which would raise ZeroDivisionError.
        denominator = beta**2 * precision + recall
        if denominator <= 0:
            return 0.0
        return (1 + beta**2) * (precision * recall) / denominator

    return objective
|
|
431
|
+
|
|
432
|
+
|
|
433
|
+
def _baseline_thresholds(
    original_thresholds: dict[str, float]
    | GenieContextAwareCacheParametersModel
    | None,
) -> dict[str, float]:
    """Normalize the caller-supplied baseline into a complete threshold dict.

    ``None`` yields the built-in defaults. A ``GenieContextAwareCacheParametersModel``
    is read field by field (``question_weight`` falls back to 0.6 when falsy).
    A plain dict is layered over the defaults so a partial dict does not raise
    ``KeyError`` later on.
    """
    defaults = {
        "similarity_threshold": 0.85,
        "context_similarity_threshold": 0.80,
        "question_weight": 0.6,
    }
    if original_thresholds is None:
        return defaults
    if isinstance(original_thresholds, GenieContextAwareCacheParametersModel):
        return {
            "similarity_threshold": original_thresholds.similarity_threshold,
            "context_similarity_threshold": original_thresholds.context_similarity_threshold,
            "question_weight": original_thresholds.question_weight or 0.6,
        }
    return {**defaults, **original_thresholds}


def _label_with_judge(
    dataset: ContextAwareCacheEvalDataset,
    judge_model: LLMModel | str | None,
) -> ContextAwareCacheEvalDataset:
    """Return a copy of ``dataset`` whose unlabelled entries carry judge verdicts.

    Entries that already have ``expected_match`` set are reused unchanged; with
    no ``judge_model`` the dataset is returned as-is. Labels are resolved once,
    before any trial runs, so the baseline and every trial are scored against
    identical labels. Judging inside each trial is problematic because a failed
    judge call returns False without being cached: it could flip a label between
    trials and repeat the failing LLM call on every trial.
    """
    if judge_model is None:
        return dataset

    from dataclasses import replace

    labelled_entries = [
        entry
        if entry.expected_match is not None
        else replace(
            entry,
            expected_match=semantic_match_judge(
                entry.question,
                entry.context,
                entry.cached_question,
                entry.cached_context,
                judge_model,
            ),
        )
        for entry in dataset.entries
    ]
    return ContextAwareCacheEvalDataset(
        name=dataset.name,
        entries=labelled_entries,
        description=dataset.description,
    )


def _metric_score(
    metric: str, precision: float, recall: float, f1: float, beta: float
) -> float:
    """Select (or derive) the optimization metric from precision/recall/F1.

    Raises:
        ValueError: If ``metric`` is not supported.
    """
    if metric == "f1":
        return f1
    if metric == "precision":
        return precision
    if metric == "recall":
        return recall
    if metric == "fbeta":
        # F-beta = (1 + b^2) * (P * R) / (b^2 * P + R); guard the real
        # denominator, which is zero when beta == 0 and recall == 0.
        denominator = beta**2 * precision + recall
        if denominator <= 0:
            return 0.0
        return (1 + beta**2) * (precision * recall) / denominator
    raise ValueError(f"Unknown metric: {metric}")


def optimize_context_aware_cache_thresholds(
    dataset: ContextAwareCacheEvalDataset,
    original_thresholds: dict[str, float]
    | GenieContextAwareCacheParametersModel
    | None = None,
    judge_model: LLMModel | str = "databricks-meta-llama-3-3-70b-instruct",
    n_trials: int = 50,
    metric: Literal["f1", "precision", "recall", "fbeta"] = "f1",
    beta: float = 1.0,
    register_if_improved: bool = True,
    study_name: str | None = None,
    seed: int | None = None,
    show_progress_bar: bool = True,
) -> ThresholdOptimizationResult:
    """
    Optimize semantic cache thresholds using Bayesian optimization.

    Uses Optuna's Tree-structured Parzen Estimator (TPE) to search the
    threshold space. The baseline thresholds are enqueued as the first trial
    for comparison. Unlabelled entries (``expected_match=None``) are labelled
    once by ``judge_model`` before optimization starts, and every trial is
    scored against those same labels.

    Note:
        The evaluator applies independent question/context thresholds and
        does not use ``question_weight``. The sampled weight therefore cannot
        affect the score, so the baseline ``question_weight`` is reported
        unchanged in ``optimized_thresholds``.

    Args:
        dataset: Evaluation dataset with question/context pairs
        original_thresholds: Baseline thresholds to compare against.
            Can be a dict (missing keys fall back to the defaults) or a
            GenieContextAwareCacheParametersModel.
            If None, uses defaults (0.85 / 0.80 / 0.6).
        judge_model: LLM model for semantic match judging (for unlabeled entries)
        n_trials: Number of optimization trials to run (including the
            enqueued baseline trial)
        metric: Optimization metric ("f1", "precision", "recall", "fbeta")
        beta: Beta value for fbeta metric (higher = favor recall)
        register_if_improved: Log results to MLflow if the optimized score
            beats the baseline score
        study_name: Optional name for the Optuna study
        seed: Random seed for reproducibility
        show_progress_bar: Whether to show progress bar during optimization

    Returns:
        ThresholdOptimizationResult with optimized thresholds and metrics.
        ``improvement`` is the gain relative to the baseline score, or the
        absolute gain when the baseline score is 0.

    Raises:
        ImportError: If optuna is not installed.
        ValueError: If ``metric`` is not supported.

    Example:
        from dao_ai.genie.cache.context_aware.optimization import (
            optimize_context_aware_cache_thresholds,
            ContextAwareCacheEvalDataset,
        )

        result = optimize_context_aware_cache_thresholds(
            dataset=my_dataset,
            judge_model="databricks-meta-llama-3-3-70b-instruct",
            n_trials=50,
            metric="f1",
        )

        if result.improved:
            print(f"New thresholds: {result.optimized_thresholds}")
    """
    # Lazy import optuna - only loaded when optimization is actually called
    # This allows the cache module to be imported without optuna installed
    try:
        import optuna
        from optuna.samplers import TPESampler
    except ImportError as e:
        raise ImportError(
            "optuna is required for cache threshold optimization. "
            "Install it with: pip install optuna"
        ) from e

    # Optional MLflow integration - requires optuna-integration[mlflow].
    # ImportError also covers its subclass ModuleNotFoundError.
    try:
        from optuna.integration import MLflowCallback
    except ImportError:
        MLflowCallback = None  # type: ignore

    # Fail fast on a bad metric, before any (potentially slow) LLM judge calls.
    if metric not in ("f1", "precision", "recall", "fbeta"):
        raise ValueError(f"Unknown metric: {metric}")

    logger.info(
        "Starting semantic cache threshold optimization",
        dataset_name=dataset.name,
        dataset_size=len(dataset),
        n_trials=n_trials,
        metric=metric,
    )

    orig_thresholds = _baseline_thresholds(original_thresholds)

    # Resolve judge labels exactly once so every trial shares the same ground truth.
    labelled_dataset = _label_with_judge(dataset, judge_model)

    # Evaluate original thresholds
    orig_precision, orig_recall, orig_f1, orig_confusion = _evaluate_thresholds(
        dataset=labelled_dataset,
        similarity_threshold=orig_thresholds["similarity_threshold"],
        context_similarity_threshold=orig_thresholds["context_similarity_threshold"],
        question_weight=orig_thresholds["question_weight"],
        judge_model=judge_model,
    )
    original_score = _metric_score(metric, orig_precision, orig_recall, orig_f1, beta)

    logger.info(
        "Evaluated original thresholds",
        precision=f"{orig_precision:.4f}",
        recall=f"{orig_recall:.4f}",
        f1=f"{orig_f1:.4f}",
        original_score=f"{original_score:.4f}",
    )

    # Create study name if not provided
    if study_name is None:
        timestamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
        study_name = f"context_aware_cache_threshold_optimization_{timestamp}"

    study = optuna.create_study(
        study_name=study_name,
        direction="maximize",
        sampler=TPESampler(seed=seed),
    )

    # Add original thresholds as first trial for comparison
    study.enqueue_trial(orig_thresholds)

    objective = _create_objective(
        dataset=labelled_dataset,
        judge_model=judge_model,
        metric=metric,
        beta=beta,
    )

    # Set up MLflow callback if available
    callbacks = []
    if MLflowCallback is not None:
        try:
            callbacks.append(
                MLflowCallback(
                    tracking_uri=mlflow.get_tracking_uri(),
                    metric_name=metric,
                    create_experiment=False,
                )
            )
        except Exception as e:
            logger.debug("MLflow callback not available", error=str(e))

    study.optimize(
        objective,
        n_trials=n_trials,
        show_progress_bar=show_progress_bar,
        callbacks=callbacks or None,
    )

    best_trial = study.best_trial
    best_thresholds = {
        "similarity_threshold": best_trial.params["similarity_threshold"],
        "context_similarity_threshold": best_trial.params[
            "context_similarity_threshold"
        ],
        # question_weight does not influence _evaluate_thresholds, so whatever
        # value Optuna sampled for it is arbitrary; keep the baseline weight.
        "question_weight": orig_thresholds["question_weight"],
    }

    # Get metrics from best trial
    best_precision = best_trial.user_attrs.get("precision", 0.0)
    best_recall = best_trial.user_attrs.get("recall", 0.0)
    best_f1 = best_trial.user_attrs.get("f1", 0.0)
    best_confusion = best_trial.user_attrs.get("confusion", {})
    optimized_score = best_trial.value

    score_gain = optimized_score - original_score
    # Relative gain against a positive baseline; with a zero baseline fall
    # back to the absolute gain so a real improvement is not reported as 0%.
    improvement = score_gain / original_score if original_score > 0 else score_gain
    improved = optimized_score > original_score

    logger.success(
        "Optimization complete",
        best_trial_number=best_trial.number,
        original_score=f"{original_score:.4f}",
        optimized_score=f"{optimized_score:.4f}",
        improvement=f"{improvement:.1%}",
        best_thresholds=best_thresholds,
    )

    # Log to MLflow if improved
    if register_if_improved and improved:
        try:
            _log_optimization_to_mlflow(
                study_name=study_name,
                dataset_name=dataset.name,
                dataset_size=len(dataset),
                original_thresholds=orig_thresholds,
                optimized_thresholds=best_thresholds,
                original_score=original_score,
                optimized_score=optimized_score,
                improvement=improvement,
                metric=metric,
                n_trials=n_trials,
                best_precision=best_precision,
                best_recall=best_recall,
                best_f1=best_f1,
                best_confusion=best_confusion,
                judge_model=judge_model,
            )
        except Exception as e:
            logger.warning("Failed to log optimization to MLflow", error=str(e))

    return ThresholdOptimizationResult(
        optimized_thresholds=best_thresholds,
        original_thresholds=orig_thresholds,
        original_score=original_score,
        optimized_score=optimized_score,
        improvement=improvement,
        n_trials=n_trials,
        best_trial_number=best_trial.number,
        study_name=study_name,
        metadata={
            "metric": metric,
            "beta": beta if metric == "fbeta" else None,
            "judge_model": str(judge_model),
            "dataset_name": dataset.name,
            "dataset_size": len(dataset),
            "original_precision": orig_precision,
            "original_recall": orig_recall,
            "original_f1": orig_f1,
            "original_confusion_matrix": orig_confusion,
            "optimized_precision": best_precision,
            "optimized_recall": best_recall,
            "optimized_f1": best_f1,
            "confusion_matrix": best_confusion,
        },
    )
|
|
694
|
+
|
|
695
|
+
|
|
696
|
+
def _log_optimization_to_mlflow(
    study_name: str,
    dataset_name: str,
    dataset_size: int,
    original_thresholds: dict[str, float],
    optimized_thresholds: dict[str, float],
    original_score: float,
    optimized_score: float,
    improvement: float,
    metric: str,
    n_trials: int,
    best_precision: float,
    best_recall: float,
    best_f1: float,
    best_confusion: dict[str, int],
    judge_model: LLMModel | str,
) -> None:
    """Log optimization results to MLflow.

    Creates a run named ``study_name`` and records:

    - params: optimizer settings, plus ``original_*`` and ``optimized_*`` thresholds;
    - metrics: scores, precision/recall/F1 and ``confusion_*`` counts;
    - artifact: ``optimized_thresholds.json``.

    If an MLflow run is already active, the new run is started as a nested
    child run. ``mlflow.start_run`` raises when another run is active and
    ``nested`` is False.

    Args:
        study_name: Optuna study name, used as the MLflow run name.
        dataset_name: Name of the evaluation dataset.
        dataset_size: Number of entries in the evaluation dataset.
        original_thresholds: Baseline thresholds.
        optimized_thresholds: Thresholds chosen by the optimizer.
        original_score: Metric value for the baseline thresholds.
        optimized_score: Metric value for the optimized thresholds.
        improvement: Improvement of optimized over original score.
        metric: Name of the optimized metric.
        n_trials: Number of optimization trials.
        best_precision: Precision of the best trial.
        best_recall: Recall of the best trial.
        best_f1: F1 of the best trial.
        best_confusion: Confusion-matrix counts of the best trial.
        judge_model: LLM judge model (logged via ``str()``).
    """
    nested = mlflow.active_run() is not None
    with mlflow.start_run(run_name=study_name, nested=nested):
        params: dict[str, Any] = {
            "optimizer": "optuna_tpe",
            "metric": metric,
            "n_trials": n_trials,
            "dataset_name": dataset_name,
            "dataset_size": dataset_size,
            "judge_model": str(judge_model),
            "dao_ai_version": dao_ai_version(),
        }
        # Threshold params go into the same batched call instead of one
        # tracking request per key.
        params.update(
            {f"original_{key}": value for key, value in original_thresholds.items()}
        )
        params.update(
            {f"optimized_{key}": value for key, value in optimized_thresholds.items()}
        )
        mlflow.log_params(params)

        mlflow.log_metrics(
            {
                "original_score": original_score,
                "optimized_score": optimized_score,
                "improvement": improvement,
                "precision": best_precision,
                "recall": best_recall,
                "f1": best_f1,
                **{f"confusion_{k}": v for k, v in best_confusion.items()},
            }
        )

        # Log thresholds as artifact
        thresholds_artifact = {
            "study_name": study_name,
            "original": original_thresholds,
            "optimized": optimized_thresholds,
            "improvement": improvement,
            "metric": metric,
        }
        mlflow.log_dict(thresholds_artifact, "optimized_thresholds.json")

    logger.info(
        "Logged optimization results to MLflow",
        study_name=study_name,
        improvement=f"{improvement:.1%}",
    )
|
|
764
|
+
|
|
765
|
+
|
|
766
|
+
def generate_eval_dataset_from_cache(
    cache_entries: Sequence[dict[str, Any]],
    embedding_model: LLMModel | str = "databricks-gte-large-en",
    num_positive_pairs: int = 50,
    num_negative_pairs: int = 50,
    paraphrase_model: LLMModel | str | None = None,
    dataset_name: str = "generated_eval_dataset",
    seed: int | None = None,
) -> ContextAwareCacheEvalDataset:
    """
    Generate an evaluation dataset from existing cache entries.

    Creates positive pairs (semantically equivalent questions) using LLM paraphrasing
    and negative pairs (different questions) from random cache entry pairs.

    Only entries with a non-empty 'question' and a non-empty 'question_embedding'
    are used, for both positive and negative pairs, so every generated pair has
    embeddings that can be compared.

    Args:
        cache_entries: List of cache entries with 'question', 'conversation_context',
            'question_embedding', and 'context_embedding' keys. Use cache.get_entries()
            with include_embeddings=True to retrieve these.
        embedding_model: Model used to embed the paraphrased questions and their
            conversation context.
        num_positive_pairs: Maximum number of positive (matching) pairs to generate.
            Fewer are produced if there are not enough usable entries, or if
            paraphrasing fails for some of them.
        num_negative_pairs: Number of negative (non-matching) pairs to attempt.
            A sampled pair whose two questions are textually identical (ignoring
            case and surrounding whitespace) is skipped, because it cannot honestly
            be labeled as non-matching.
        paraphrase_model: Chat LLM used to generate paraphrases. When None or empty,
            "databricks-meta-llama-3-3-70b-instruct" is used.
        dataset_name: Name for the generated dataset.
        seed: Optional seed for selecting negative pairs, for reproducible datasets.
            None (the default) uses the global ``random`` module, as before.

    Returns:
        ContextAwareCacheEvalDataset with generated entries

    Raises:
        ValueError: If fewer than 2 cache entries are supplied.

    Example:
        # Get entries from cache with embeddings
        entries = cache_service.get_entries(include_embeddings=True, limit=100)

        # Generate evaluation dataset
        eval_dataset = generate_eval_dataset_from_cache(
            cache_entries=entries,
            num_positive_pairs=50,
            num_negative_pairs=50,
            dataset_name="my_cache_eval",
        )
    """
    import random

    if len(cache_entries) < 2:
        raise ValueError("Need at least 2 cache entries to generate dataset")

    # A dedicated RNG only when a seed is requested; otherwise keep using the
    # global random module so the existing (unseeded) behavior is unchanged.
    rng: Any = random.Random(seed) if seed is not None else random

    def _has_question_and_embedding(entry: dict[str, Any]) -> bool:
        """Return True if the entry has a non-empty question and question embedding."""
        embedding = entry.get("question_embedding")
        # len() instead of truthiness so array-like embeddings don't raise.
        return bool(entry.get("question")) and embedding is not None and len(embedding) > 0

    # Convert embedding model
    emb_model: LLMModel = (
        LLMModel(name=embedding_model)
        if isinstance(embedding_model, str)
        else embedding_model
    )
    embeddings = emb_model.as_embeddings_model()

    # Resolve the paraphrase model. None or "" falls back to the default LLM
    # rather than constructing a model with an empty name.
    para_model: LLMModel
    if not paraphrase_model:
        para_model = LLMModel(name="databricks-meta-llama-3-3-70b-instruct")
    elif isinstance(paraphrase_model, str):
        para_model = LLMModel(name=paraphrase_model)
    else:
        para_model = paraphrase_model
    chat = para_model.as_chat_model()

    entries: list[ContextAwareCacheEvalEntry] = []

    # Generate positive pairs (paraphrases)
    logger.info(
        "Generating positive pairs using paraphrasing", count=num_positive_pairs
    )

    positives_created = 0
    for entry in cache_entries:
        # Keep scanning past unusable or failed entries until the target is met.
        if positives_created >= num_positive_pairs:
            break
        if not _has_question_and_embedding(entry):
            continue

        original_question = entry.get("question", "")
        original_context = entry.get("conversation_context", "")
        original_q_emb = entry.get("question_embedding", [])
        original_c_emb = entry.get("context_embedding", [])

        # Generate paraphrase
        try:
            paraphrase_prompt = f"""Rephrase the following question to ask the same thing but using different words.
Keep the same meaning and intent. Only output the rephrased question, nothing else.

Original question: {original_question}

Rephrased question:"""
            response = chat.invoke(paraphrase_prompt)
            paraphrased_question = response.content.strip()
            if not paraphrased_question:
                # An empty paraphrase would produce a meaningless eval entry.
                logger.warning(
                    "Failed to generate paraphrase", error="empty paraphrase"
                )
                continue

            # Generate embedding for paraphrase. The context is re-embedded with
            # the same model as the paraphrase so both embeddings come from one model.
            para_q_emb = embeddings.embed_query(paraphrased_question)
            para_c_emb = (
                embeddings.embed_query(original_context)
                if original_context
                else original_c_emb
            )

            entries.append(
                ContextAwareCacheEvalEntry(
                    question=paraphrased_question,
                    question_embedding=para_q_emb,
                    context=original_context,
                    context_embedding=para_c_emb,
                    cached_question=original_question,
                    cached_question_embedding=original_q_emb,
                    cached_context=original_context,
                    cached_context_embedding=original_c_emb,
                    expected_match=True,
                )
            )
            positives_created += 1
        except Exception as e:
            # Best effort: one failed LLM/embedding call shouldn't abort the dataset.
            logger.warning("Failed to generate paraphrase", error=str(e))

    # Generate negative pairs (random different questions)
    logger.info(
        "Generating negative pairs from different cache entries",
        count=num_negative_pairs,
    )

    # Apply the same validity rule as the positive pairs.
    usable_entries = [e for e in cache_entries if _has_question_and_embedding(e)]

    for _ in range(num_negative_pairs):
        if len(usable_entries) < 2:
            break
        # Pick two different random entries
        entry1, entry2 = rng.sample(usable_entries, 2)

        # Duplicate question texts in the cache are not genuine non-matches.
        question1 = str(entry1.get("question", "")).strip().casefold()
        question2 = str(entry2.get("question", "")).strip().casefold()
        if question1 == question2:
            continue

        # Use entry1 as the "question" and entry2 as the "cached" entry
        entries.append(
            ContextAwareCacheEvalEntry(
                question=entry1.get("question", ""),
                question_embedding=entry1.get("question_embedding", []),
                context=entry1.get("conversation_context", ""),
                context_embedding=entry1.get("context_embedding", []),
                cached_question=entry2.get("question", ""),
                cached_question_embedding=entry2.get("question_embedding", []),
                cached_context=entry2.get("conversation_context", ""),
                cached_context_embedding=entry2.get("context_embedding", []),
                expected_match=False,
            )
        )

    logger.info(
        "Generated evaluation dataset",
        name=dataset_name,
        total_entries=len(entries),
        positive_pairs=sum(1 for e in entries if e.expected_match is True),
        negative_pairs=sum(1 for e in entries if e.expected_match is False),
    )

    return ContextAwareCacheEvalDataset(
        name=dataset_name,
        entries=entries,
        description=f"Generated from {len(cache_entries)} cache entries",
    )
|
|
924
|
+
|
|
925
|
+
|
|
926
|
+
def clear_judge_cache() -> None:
    """
    Empty the in-process cache of LLM-judge results.

    The module-level ``_judge_cache`` object is cleared in place, so any code
    holding a reference to it sees the now-empty cache. A debug message is
    logged afterwards.
    """
    # No ``global`` declaration is needed: the name is only looked up and
    # mutated through ``.clear()``, never rebound.
    _judge_cache.clear()
    logger.debug("Cleared judge cache")
|