dao-ai 0.1.2__py3-none-any.whl → 0.1.20__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dao_ai/apps/__init__.py +24 -0
- dao_ai/apps/handlers.py +105 -0
- dao_ai/apps/model_serving.py +29 -0
- dao_ai/apps/resources.py +1122 -0
- dao_ai/apps/server.py +39 -0
- dao_ai/cli.py +546 -37
- dao_ai/config.py +1179 -139
- dao_ai/evaluation.py +543 -0
- dao_ai/genie/__init__.py +55 -7
- dao_ai/genie/cache/__init__.py +34 -7
- dao_ai/genie/cache/base.py +143 -2
- dao_ai/genie/cache/context_aware/__init__.py +31 -0
- dao_ai/genie/cache/context_aware/base.py +1151 -0
- dao_ai/genie/cache/context_aware/in_memory.py +609 -0
- dao_ai/genie/cache/context_aware/persistent.py +802 -0
- dao_ai/genie/cache/context_aware/postgres.py +1166 -0
- dao_ai/genie/cache/core.py +1 -1
- dao_ai/genie/cache/lru.py +257 -75
- dao_ai/genie/cache/optimization.py +890 -0
- dao_ai/genie/core.py +235 -11
- dao_ai/memory/postgres.py +175 -39
- dao_ai/middleware/__init__.py +38 -0
- dao_ai/middleware/assertions.py +3 -3
- dao_ai/middleware/context_editing.py +230 -0
- dao_ai/middleware/core.py +4 -4
- dao_ai/middleware/guardrails.py +3 -3
- dao_ai/middleware/human_in_the_loop.py +3 -2
- dao_ai/middleware/message_validation.py +4 -4
- dao_ai/middleware/model_call_limit.py +77 -0
- dao_ai/middleware/model_retry.py +121 -0
- dao_ai/middleware/pii.py +157 -0
- dao_ai/middleware/summarization.py +1 -1
- dao_ai/middleware/tool_call_limit.py +210 -0
- dao_ai/middleware/tool_retry.py +174 -0
- dao_ai/middleware/tool_selector.py +129 -0
- dao_ai/models.py +327 -370
- dao_ai/nodes.py +9 -16
- dao_ai/orchestration/core.py +33 -9
- dao_ai/orchestration/supervisor.py +29 -13
- dao_ai/orchestration/swarm.py +6 -1
- dao_ai/{prompts.py → prompts/__init__.py} +12 -61
- dao_ai/prompts/instructed_retriever_decomposition.yaml +58 -0
- dao_ai/prompts/instruction_reranker.yaml +14 -0
- dao_ai/prompts/router.yaml +37 -0
- dao_ai/prompts/verifier.yaml +46 -0
- dao_ai/providers/base.py +28 -2
- dao_ai/providers/databricks.py +363 -33
- dao_ai/state.py +1 -0
- dao_ai/tools/__init__.py +5 -3
- dao_ai/tools/genie.py +103 -26
- dao_ai/tools/instructed_retriever.py +366 -0
- dao_ai/tools/instruction_reranker.py +202 -0
- dao_ai/tools/mcp.py +539 -97
- dao_ai/tools/router.py +89 -0
- dao_ai/tools/slack.py +13 -2
- dao_ai/tools/sql.py +7 -3
- dao_ai/tools/unity_catalog.py +32 -10
- dao_ai/tools/vector_search.py +493 -160
- dao_ai/tools/verifier.py +159 -0
- dao_ai/utils.py +182 -2
- dao_ai/vector_search.py +46 -1
- {dao_ai-0.1.2.dist-info → dao_ai-0.1.20.dist-info}/METADATA +45 -9
- dao_ai-0.1.20.dist-info/RECORD +89 -0
- dao_ai/agent_as_code.py +0 -22
- dao_ai/genie/cache/semantic.py +0 -970
- dao_ai-0.1.2.dist-info/RECORD +0 -64
- {dao_ai-0.1.2.dist-info → dao_ai-0.1.20.dist-info}/WHEEL +0 -0
- {dao_ai-0.1.2.dist-info → dao_ai-0.1.20.dist-info}/entry_points.txt +0 -0
- {dao_ai-0.1.2.dist-info → dao_ai-0.1.20.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,609 @@
|
|
|
1
|
+
"""
|
|
2
|
+
In-memory context-aware Genie cache implementation.
|
|
3
|
+
|
|
4
|
+
This module provides a context-aware cache that stores embeddings and cache entries
|
|
5
|
+
entirely in memory, without requiring external database dependencies like PostgreSQL
|
|
6
|
+
or Databricks Lakebase. It uses L2 distance for similarity search and supports
|
|
7
|
+
dual embedding matching (question + conversation context).
|
|
8
|
+
|
|
9
|
+
Use this when:
|
|
10
|
+
- No external database access is available
|
|
11
|
+
- Single-instance deployments (cache not shared across instances)
|
|
12
|
+
- Cache persistence across restarts is not required
|
|
13
|
+
- Cache sizes are moderate (hundreds to low thousands of entries)
|
|
14
|
+
|
|
15
|
+
For multi-instance deployments or large cache sizes, use PostgresContextAwareGenieService
|
|
16
|
+
with PostgreSQL backend instead.
|
|
17
|
+
"""
|
|
18
|
+
|
|
19
|
+
from __future__ import annotations
|
|
20
|
+
|
|
21
|
+
from dataclasses import dataclass
|
|
22
|
+
from datetime import datetime, timedelta
|
|
23
|
+
from threading import Lock
|
|
24
|
+
from typing import Any
|
|
25
|
+
|
|
26
|
+
import mlflow
|
|
27
|
+
import numpy as np
|
|
28
|
+
from databricks.sdk import WorkspaceClient
|
|
29
|
+
from databricks_ai_bridge.genie import GenieResponse
|
|
30
|
+
from loguru import logger
|
|
31
|
+
|
|
32
|
+
from dao_ai.config import (
|
|
33
|
+
GenieInMemorySemanticCacheParametersModel,
|
|
34
|
+
WarehouseModel,
|
|
35
|
+
)
|
|
36
|
+
from dao_ai.genie.cache.base import (
|
|
37
|
+
GenieServiceBase,
|
|
38
|
+
SQLCacheEntry,
|
|
39
|
+
)
|
|
40
|
+
from dao_ai.genie.cache.context_aware.base import ContextAwareGenieService
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
@dataclass
class InMemoryCacheEntry:
    """
    In-memory cache entry storing embeddings and SQL query metadata.

    This dataclass represents a single cache entry stored in memory, including
    dual embeddings (question + context) for high-precision semantic matching.

    Uses LRU (Least Recently Used) eviction strategy when capacity is reached.

    Attributes:
        genie_space_id: The Genie space ID this entry belongs to
        question: The original question text
        conversation_context: Previous conversation context for embedding
        question_embedding: Embedding vector for the question
        context_embedding: Embedding vector for the conversation context
        sql_query: The SQL query to re-execute on cache hit
        description: Description of the query
        conversation_id: The conversation ID where this query originated
        created_at: When the entry was created (naive local time; TTL math
            compares it against ``datetime.now()``, also naive)
        last_accessed_at: Last access time for LRU eviction
        message_id: The original Genie message ID (for feedback on cache hits)
    """

    genie_space_id: str
    question: str
    conversation_context: str
    question_embedding: list[float]
    context_embedding: list[float]
    sql_query: str
    description: str
    conversation_id: str
    created_at: datetime
    last_accessed_at: datetime  # Track last access time for LRU eviction
    message_id: str | None = None  # Original Genie message ID for feedback
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def l2_distance(a: list[float], b: list[float]) -> float:
    """
    Compute the L2 (Euclidean) distance between two embedding vectors.

    The metric mirrors PostgreSQL pg_vector's distance operator so that the
    in-memory cache ranks candidates exactly the same way the PostgreSQL
    cache does.

    Args:
        a: First embedding vector.
        b: Second embedding vector.

    Returns:
        The L2 distance; 0.0 means identical vectors, larger values mean the
        vectors are more different.
    """
    # sqrt(diff . diff) is algebraically identical to np.linalg.norm(diff).
    diff = np.asarray(a, dtype=float) - np.asarray(b, dtype=float)
    return float(np.sqrt(np.dot(diff, diff)))
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def distance_to_similarity(distance: float) -> float:
    """
    Map an L2 distance onto a similarity score in the range (0, 1].

    Applies ``similarity = 1.0 / (1.0 + distance)``, the same conversion the
    PostgreSQL semantic cache uses, so thresholds are interchangeable between
    backends.

    Args:
        distance: L2 distance value (expected to be non-negative).

    Returns:
        Similarity score: 1.0 for a perfect match, approaching 0.0 as the
        vectors diverge.
    """
    denominator = 1.0 + distance
    return 1.0 / denominator
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
class InMemoryContextAwareGenieService(ContextAwareGenieService):
    """
    In-memory context-aware caching decorator using dual embeddings for similarity lookup.

    This service caches the SQL query generated by Genie along with dual embeddings
    (question + conversation context) for high-precision semantic matching. On
    subsequent queries, it performs similarity search to find cached queries that
    match both the question intent AND conversation context.

    Cache entries are partitioned by genie_space_id to ensure queries from different
    Genie spaces don't return incorrect cache hits.

    On cache hit, it re-executes the cached SQL using the provided warehouse
    to return fresh data while avoiding the Genie NL-to-SQL translation cost.

    Example:
        from dao_ai.config import GenieInMemorySemanticCacheParametersModel
        from dao_ai.genie.cache.context_aware import InMemoryContextAwareGenieService

        cache_params = GenieInMemorySemanticCacheParametersModel(
            warehouse=warehouse_model,
            embedding_model="databricks-gte-large-en",
            time_to_live_seconds=86400,  # 24 hours
            similarity_threshold=0.85,
            capacity=1000,  # Limit to 1000 entries
        )
        genie = InMemoryContextAwareGenieService(
            impl=GenieService(Genie(space_id="my-space")),
            parameters=cache_params,
            workspace_client=workspace_client,
        ).initialize()

    Thread-safe: Uses a lock to protect cache operations.
    """

    impl: GenieServiceBase
    parameters: GenieInMemorySemanticCacheParametersModel
    _workspace_client: WorkspaceClient | None
    name: str
    _embeddings: Any  # DatabricksEmbeddings
    _cache: list[InMemoryCacheEntry]
    _lock: Lock
    _embedding_dims: int | None
    _setup_complete: bool

    def __init__(
        self,
        impl: GenieServiceBase,
        parameters: GenieInMemorySemanticCacheParametersModel,
        workspace_client: WorkspaceClient | None = None,
        name: str | None = None,
    ) -> None:
        """
        Initialize the in-memory context-aware cache service.

        Args:
            impl: The underlying GenieServiceBase to delegate to on cache miss.
                The space_id will be obtained from impl.space_id.
            parameters: Cache configuration including warehouse, embedding model, and thresholds
            workspace_client: Optional WorkspaceClient for retrieving conversation history.
                If None, conversation context will not be used.
            name: Name for this cache layer (for logging). Defaults to class name.
        """
        self.impl = impl
        self.parameters = parameters
        self._workspace_client = workspace_client
        self.name = name if name is not None else self.__class__.__name__
        # Embeddings model is created lazily in _setup(), not here.
        self._embeddings = None
        self._cache = []
        self._lock = Lock()
        self._embedding_dims = None
        self._setup_complete = False

    def _setup(self) -> None:
        """Initialize embeddings model lazily (idempotent; safe to call repeatedly)."""
        if self._setup_complete:
            return

        # Initialize embeddings using base class helper
        self._initialize_embeddings(
            self.parameters.embedding_model,
            self.parameters.embedding_dims,
        )

        self._setup_complete = True
        logger.debug(
            "In-memory context-aware cache initialized",
            layer=self.name,
            space_id=self.space_id,
            dims=self._embedding_dims,
            capacity=self.parameters.capacity,
        )

    # Property implementations
    @property
    def warehouse(self) -> WarehouseModel:
        """The warehouse used for executing cached SQL queries."""
        return self.parameters.warehouse

    @property
    def time_to_live(self) -> timedelta | None:
        """Time-to-live for cache entries. None means never expires."""
        ttl = self.parameters.time_to_live_seconds
        # A negative TTL is treated the same as None: entries never expire.
        if ttl is None or ttl < 0:
            return None
        return timedelta(seconds=ttl)

    @property
    def similarity_threshold(self) -> float:
        """Minimum similarity for cache hit (using L2 distance converted to similarity)."""
        return self.parameters.similarity_threshold

    @property
    def context_similarity_threshold(self) -> float:
        """Minimum similarity for context matching."""
        return self.parameters.context_similarity_threshold

    @property
    def question_weight(self) -> float:
        """Weight for question similarity in combined score."""
        return self.parameters.question_weight

    @property
    def context_weight(self) -> float:
        """Weight for context similarity in combined score."""
        return self.parameters.context_weight

    def _embed_question(
        self, question: str, conversation_id: str | None = None
    ) -> tuple[list[float], list[float], str]:
        """
        Generate dual embeddings using Genie API for conversation history.

        Delegates to the base-class helper, forwarding the configured context
        window size and token budget.

        Args:
            question: The question to embed
            conversation_id: Optional conversation ID for retrieving context

        Returns:
            Tuple of (question_embedding, context_embedding, conversation_context_string)
        """
        return self._embed_question_with_genie_history(
            question,
            conversation_id,
            self.parameters.context_window_size,
            self.parameters.max_context_tokens,
        )

    @mlflow.trace(name="semantic_search_in_memory")
    def _find_similar(
        self,
        question: str,
        conversation_context: str,
        question_embedding: list[float],
        context_embedding: list[float],
        conversation_id: str | None = None,
    ) -> tuple[SQLCacheEntry, float] | None:
        """
        Find a semantically similar cached entry using dual embedding matching.

        Performs linear scan through all cache entries, filtering by space_id and
        calculating L2 distances for similarity matching. Expired entries found
        during the scan are deleted opportunistically.

        Args:
            question: The original question (for logging)
            conversation_context: The conversation context string
            question_embedding: The embedding vector of just the question
            context_embedding: The embedding vector of the conversation context
            conversation_id: Optional conversation ID (for logging)

        Returns:
            Tuple of (SQLCacheEntry, combined_similarity_score) if found, None otherwise
        """
        ttl_seconds = self.parameters.time_to_live_seconds
        ttl_disabled = ttl_seconds is None or ttl_seconds < 0

        question_weight = self.question_weight
        context_weight = self.context_weight

        best_entry: InMemoryCacheEntry | None = None
        best_question_sim: float = 0.0
        best_context_sim: float = 0.0
        best_combined_sim: float = 0.0

        # Linear scan through all entries
        with self._lock:
            entries_to_delete: list[int] = []

            for idx, entry in enumerate(self._cache):
                # Filter by space_id (partition)
                if entry.genie_space_id != self.space_id:
                    continue

                # Check TTL (naive datetimes compared against naive now())
                is_valid = True
                if not ttl_disabled:
                    age = datetime.now() - entry.created_at
                    is_valid = age.total_seconds() <= ttl_seconds

                if not is_valid:
                    entries_to_delete.append(idx)
                    continue

                # Calculate L2 distances and convert to similarities
                question_distance = l2_distance(
                    question_embedding, entry.question_embedding
                )
                context_distance = l2_distance(
                    context_embedding, entry.context_embedding
                )

                question_sim = distance_to_similarity(question_distance)
                context_sim = distance_to_similarity(context_distance)

                # Calculate weighted combined similarity
                combined_sim = (question_weight * question_sim) + (
                    context_weight * context_sim
                )

                # Track best match
                if combined_sim > best_combined_sim:
                    best_entry = entry
                    best_question_sim = question_sim
                    best_context_sim = context_sim
                    best_combined_sim = combined_sim

            # Delete expired entries (reverse order keeps earlier indices stable)
            for idx in reversed(entries_to_delete):
                del self._cache[idx]
                logger.trace("Deleted expired entry", layer=self.name, index=idx)

        # No entries found
        if best_entry is None:
            logger.info(
                "Cache MISS (no entries)",
                layer=self.name,
                question=question[:50],
                space=self.space_id,
                delegating_to=type(self.impl).__name__,
            )
            return None

        # Log best match info
        logger.debug(
            "Best match found",
            layer=self.name,
            question_sim=f"{best_question_sim:.4f}",
            context_sim=f"{best_context_sim:.4f}",
            combined_sim=f"{best_combined_sim:.4f}",
            cached_question=best_entry.question[:50],
        )

        # Check BOTH similarity thresholds
        if best_question_sim < self.similarity_threshold:
            logger.info(
                "Cache MISS (question similarity too low)",
                layer=self.name,
                question_sim=f"{best_question_sim:.4f}",
                threshold=self.similarity_threshold,
            )
            return None

        if best_context_sim < self.context_similarity_threshold:
            logger.info(
                "Cache MISS (context similarity too low)",
                layer=self.name,
                context_sim=f"{best_context_sim:.4f}",
                threshold=self.context_similarity_threshold,
            )
            return None

        # Cache HIT - Update last accessed time
        # NOTE(review): the lock was released after the scan above, so a
        # concurrent writer may have evicted best_entry by now. Mutating the
        # (possibly detached) entry is harmless; the cached SQL below is read
        # from the local reference either way.
        with self._lock:
            best_entry.last_accessed_at = datetime.now()

        cache_age_seconds = (datetime.now() - best_entry.created_at).total_seconds()
        logger.info(
            "Cache HIT",
            layer=self.name,
            question=question[:80],
            conversation_id=conversation_id,
            matched_question=best_entry.question[:80],
            cache_age_seconds=round(cache_age_seconds, 1),
            question_similarity=f"{best_question_sim:.4f}",
            context_similarity=f"{best_context_sim:.4f}",
            combined_similarity=f"{best_combined_sim:.4f}",
        )

        cache_entry = SQLCacheEntry(
            query=best_entry.sql_query,
            description=best_entry.description,
            conversation_id=best_entry.conversation_id,
            created_at=best_entry.created_at,
            message_id=best_entry.message_id,
            # In-memory caches don't have database row IDs
            cache_entry_id=None,
        )
        return cache_entry, best_combined_sim

    def _store_entry(
        self,
        question: str,
        conversation_context: str,
        question_embedding: list[float],
        context_embedding: list[float],
        response: GenieResponse,
        message_id: str | None = None,
    ) -> None:
        """
        Store a new cache entry with dual embeddings and message_id.

        If capacity is set and reached, evicts least recently used entry (LRU).
        Capacity is enforced per Genie space, not across the whole cache list.
        """
        now = datetime.now()
        new_entry = InMemoryCacheEntry(
            genie_space_id=self.space_id,
            question=question,
            conversation_context=conversation_context,
            question_embedding=question_embedding,
            context_embedding=context_embedding,
            sql_query=response.query,
            description=response.description,
            conversation_id=response.conversation_id,
            created_at=now,
            last_accessed_at=now,
            message_id=message_id,
        )

        with self._lock:
            # Enforce capacity limit (LRU eviction)
            if self.parameters.capacity is not None:
                space_entries = [
                    e for e in self._cache if e.genie_space_id == self.space_id
                ]

                while len(space_entries) >= self.parameters.capacity:
                    # Find and remove least recently used entry
                    lru_idx = None
                    lru_time = None

                    for idx, entry in enumerate(self._cache):
                        if entry.genie_space_id == self.space_id:
                            if lru_time is None or entry.last_accessed_at < lru_time:
                                lru_time = entry.last_accessed_at
                                lru_idx = idx

                    if lru_idx is not None:
                        evicted = self._cache.pop(lru_idx)
                        logger.trace(
                            "Evicted LRU cache entry",
                            layer=self.name,
                            question=evicted.question[:50],
                            capacity=self.parameters.capacity,
                        )
                        # Recompute the per-space view after each eviction.
                        space_entries = [
                            e for e in self._cache if e.genie_space_id == self.space_id
                        ]
                    else:
                        break

            self._cache.append(new_entry)
            logger.debug(
                "Stored cache entry",
                layer=self.name,
                question=question[:50],
                space=self.space_id,
                cache_size=len(
                    [e for e in self._cache if e.genie_space_id == self.space_id]
                ),
                capacity=self.parameters.capacity,
            )

    def _on_stale_cache_entry(self, question: str) -> None:
        """Remove stale cache entry from memory (first exact question match only)."""
        with self._lock:
            for idx, entry in enumerate(self._cache):
                if entry.genie_space_id == self.space_id and entry.question == question:
                    del self._cache[idx]
                    logger.info(
                        "Deleted stale cache entry from memory",
                        layer=self.name,
                        question=question[:50],
                    )
                    break

    def _invalidate_by_question(self, question: str) -> bool:
        """
        Invalidate cache entries matching a specific question.

        This method is called when negative feedback is received to remove
        the corresponding cache entry from the in-memory cache. Only the
        first exact-text match within this Genie space is removed.

        Args:
            question: The question text to match and invalidate

        Returns:
            True if an entry was found and invalidated, False otherwise
        """
        with self._lock:
            for idx, entry in enumerate(self._cache):
                if entry.genie_space_id == self.space_id and entry.question == question:
                    del self._cache[idx]
                    logger.info(
                        "Invalidated cache entry by question",
                        layer=self.name,
                        question=question[:50],
                        space_id=self.space_id,
                    )
                    return True
        return False

    # Note: ask_question_with_cache_info is inherited from ContextAwareGenieService
    # using the Template Method pattern. InMemoryContextAwareGenieService uses the
    # default empty hook implementations since it doesn't track prompt history.

    # Template Method implementations for invalidate_expired() and clear()

    def _delete_expired_entries(self, ttl_seconds: int) -> int:
        """Delete expired entries for this space; returns the number removed."""
        deleted = 0
        with self._lock:
            indices_to_delete: list[int] = []
            now = datetime.now()

            for idx, entry in enumerate(self._cache):
                if entry.genie_space_id != self.space_id:
                    continue

                age = now - entry.created_at
                if age.total_seconds() > ttl_seconds:
                    indices_to_delete.append(idx)

            # Delete in reverse order
            for idx in reversed(indices_to_delete):
                del self._cache[idx]
                deleted += 1

            logger.trace(
                "Deleted expired entries",
                layer=self.name,
                deleted_count=deleted,
            )

        return deleted

    def _delete_all_entries(self) -> int:
        """Delete all entries for this Genie space; returns the number removed."""
        deleted = 0

        with self._lock:
            indices_to_delete: list[int] = []
            for idx, entry in enumerate(self._cache):
                if entry.genie_space_id == self.space_id:
                    indices_to_delete.append(idx)

            # Delete in reverse order
            for idx in reversed(indices_to_delete):
                del self._cache[idx]
                deleted += 1

            logger.debug(
                "Cleared cache entries", layer=self.name, deleted_count=deleted
            )

        return deleted

    @property
    def size(self) -> int:
        """Current number of entries in the cache for this Genie space."""
        self._setup()
        with self._lock:
            return len([e for e in self._cache if e.genie_space_id == self.space_id])

    # Template Method implementations for stats()

    def _count_all_entries(self) -> int:
        """Count all cache entries for this Genie space."""
        with self._lock:
            return len([e for e in self._cache if e.genie_space_id == self.space_id])

    def _count_entries_with_ttl(self, ttl_seconds: int) -> tuple[int, int]:
        """Count (total, expired) entries for this Genie space given a TTL."""
        now = datetime.now()
        with self._lock:
            space_entries = [
                e for e in self._cache if e.genie_space_id == self.space_id
            ]
            total = len(space_entries)
            expired = 0
            for entry in space_entries:
                age = now - entry.created_at
                if age.total_seconds() > ttl_seconds:
                    expired += 1
            return total, expired

    def _get_additional_stats(self) -> dict[str, Any]:
        """Add capacity info to stats."""
        return {"capacity": self.parameters.capacity}
|