dao-ai 0.1.19__py3-none-any.whl → 0.1.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dao_ai/cli.py +37 -7
- dao_ai/config.py +265 -10
- dao_ai/genie/__init__.py +55 -7
- dao_ai/genie/cache/__init__.py +36 -9
- dao_ai/genie/cache/base.py +143 -2
- dao_ai/genie/cache/context_aware/__init__.py +52 -0
- dao_ai/genie/cache/context_aware/base.py +1204 -0
- dao_ai/genie/cache/{in_memory_semantic.py → context_aware/in_memory.py} +233 -383
- dao_ai/genie/cache/context_aware/optimization.py +930 -0
- dao_ai/genie/cache/context_aware/persistent.py +802 -0
- dao_ai/genie/cache/context_aware/postgres.py +1343 -0
- dao_ai/genie/cache/lru.py +248 -70
- dao_ai/genie/core.py +235 -11
- dao_ai/middleware/__init__.py +8 -1
- dao_ai/middleware/tool_call_observability.py +227 -0
- dao_ai/nodes.py +4 -4
- dao_ai/tools/__init__.py +2 -2
- dao_ai/tools/genie.py +10 -10
- dao_ai/utils.py +7 -3
- {dao_ai-0.1.19.dist-info → dao_ai-0.1.21.dist-info}/METADATA +1 -1
- {dao_ai-0.1.19.dist-info → dao_ai-0.1.21.dist-info}/RECORD +24 -19
- dao_ai/genie/cache/semantic.py +0 -1004
- {dao_ai-0.1.19.dist-info → dao_ai-0.1.21.dist-info}/WHEEL +0 -0
- {dao_ai-0.1.19.dist-info → dao_ai-0.1.21.dist-info}/entry_points.txt +0 -0
- {dao_ai-0.1.19.dist-info → dao_ai-0.1.21.dist-info}/licenses/LICENSE +0 -0
dao_ai/genie/cache/{in_memory_semantic.py → context_aware/in_memory.py}

@@ -1,22 +1,18 @@
 """
-In-memory
+In-memory context-aware Genie cache implementation.
 
-This module provides a
+This module provides a context-aware cache that stores embeddings and cache entries
 entirely in memory, without requiring external database dependencies like PostgreSQL
 or Databricks Lakebase. It uses L2 distance for similarity search and supports
 dual embedding matching (question + conversation context).
 
-The cache supports conversation-aware embedding using a rolling window approach
-to capture context from recent conversation turns, improving accuracy for
-multi-turn conversations with anaphoric references.
-
 Use this when:
 - No external database access is available
 - Single-instance deployments (cache not shared across instances)
 - Cache persistence across restarts is not required
 - Cache sizes are moderate (hundreds to low thousands of entries)
 
-For multi-instance deployments or large cache sizes, use
+For multi-instance deployments or large cache sizes, use PostgresContextAwareGenieService
 with PostgreSQL backend instead.
 """
 
@@ -29,25 +25,19 @@ from typing import Any
 
 import mlflow
 import numpy as np
-import pandas as pd
 from databricks.sdk import WorkspaceClient
-from databricks.sdk.service.sql import StatementResponse, StatementState
 from databricks_ai_bridge.genie import GenieResponse
 from loguru import logger
 
 from dao_ai.config import (
     GenieInMemorySemanticCacheParametersModel,
-    LLMModel,
     WarehouseModel,
 )
 from dao_ai.genie.cache.base import (
-    CacheResult,
     GenieServiceBase,
     SQLCacheEntry,
 )
-from dao_ai.genie.cache.
-    get_conversation_history,
-)
+from dao_ai.genie.cache.context_aware.base import ContextAwareGenieService
 
 
 @dataclass
@@ -59,6 +49,19 @@ class InMemoryCacheEntry:
     dual embeddings (question + context) for high-precision semantic matching.
 
     Uses LRU (Least Recently Used) eviction strategy when capacity is reached.
+
+    Attributes:
+        genie_space_id: The Genie space ID this entry belongs to
+        question: The original question text
+        conversation_context: Previous conversation context for embedding
+        question_embedding: Embedding vector for the question
+        context_embedding: Embedding vector for the conversation context
+        sql_query: The SQL query to re-execute on cache hit
+        description: Description of the query
+        conversation_id: The conversation ID where this query originated
+        created_at: When the entry was created
+        last_accessed_at: Last access time for LRU eviction
+        message_id: The original Genie message ID (for feedback on cache hits)
     """
 
     genie_space_id: str
@@ -71,6 +74,7 @@ class InMemoryCacheEntry:
     conversation_id: str
     created_at: datetime
     last_accessed_at: datetime  # Track last access time for LRU eviction
+    message_id: str | None = None  # Original Genie message ID for feedback
 
 
 def l2_distance(a: list[float], b: list[float]) -> float:
@@ -106,9 +110,9 @@ def distance_to_similarity(distance: float) -> float:
     return 1.0 / (1.0 + distance)
 
 
-class InMemorySemanticCacheService(GenieServiceBase):
+class InMemoryContextAwareGenieService(ContextAwareGenieService):
     """
-    In-memory
+    In-memory context-aware caching decorator using dual embeddings for similarity lookup.
 
     This service caches the SQL query generated by Genie along with dual embeddings
     (question + conversation context) for high-precision semantic matching. On
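
For orientation: the distance-to-similarity mapping above is the only helper body visible in this hunk. A minimal sketch of both module-level helpers, assuming l2_distance is a plain Euclidean norm (its body is not shown in this diff):

    import numpy as np

    def l2_distance(a: list[float], b: list[float]) -> float:
        # Euclidean (L2) distance between two embedding vectors.
        return float(np.linalg.norm(np.asarray(a) - np.asarray(b)))

    def distance_to_similarity(distance: float) -> float:
        # Maps distance in [0, inf) to similarity in (0, 1]; identical vectors score 1.0.
        return 1.0 / (1.0 + distance)

    # Identical embeddings give distance 0.0 and therefore similarity 1.0.
    assert distance_to_similarity(l2_distance([1.0, 0.0], [1.0, 0.0])) == 1.0
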
@@ -123,7 +127,7 @@ class InMemorySemanticCacheService(GenieServiceBase):
 
     Example:
         from dao_ai.config import GenieInMemorySemanticCacheParametersModel
-        from dao_ai.genie.cache import
+        from dao_ai.genie.cache.context_aware import InMemoryContextAwareGenieService
 
         cache_params = GenieInMemorySemanticCacheParametersModel(
             warehouse=warehouse_model,
@@ -132,7 +136,7 @@ class InMemorySemanticCacheService(GenieServiceBase):
             similarity_threshold=0.85,
             capacity=1000,  # Limit to 1000 entries
         )
-        genie =
+        genie = InMemoryContextAwareGenieService(
             impl=GenieService(Genie(space_id="my-space")),
             parameters=cache_params,
             workspace_client=workspace_client,
@@ -143,7 +147,7 @@ class InMemorySemanticCacheService(GenieServiceBase):
 
     impl: GenieServiceBase
     parameters: GenieInMemorySemanticCacheParametersModel
-
+    _workspace_client: WorkspaceClient | None
     name: str
     _embeddings: Any  # DatabricksEmbeddings
     _cache: list[InMemoryCacheEntry]
@@ -159,7 +163,7 @@ class InMemorySemanticCacheService(GenieServiceBase):
         name: str | None = None,
     ) -> None:
         """
-        Initialize the in-memory
+        Initialize the in-memory context-aware cache service.
 
         Args:
             impl: The underlying GenieServiceBase to delegate to on cache miss.
@@ -171,7 +175,7 @@ class InMemorySemanticCacheService(GenieServiceBase):
         """
         self.impl = impl
         self.parameters = parameters
-        self.
+        self._workspace_client = workspace_client
         self.name = name if name is not None else self.__class__.__name__
         self._embeddings = None
         self._cache = []
@@ -179,56 +183,27 @@ class InMemorySemanticCacheService(GenieServiceBase):
         self._embedding_dims = None
         self._setup_complete = False
 
-    def initialize(self) -> "InMemorySemanticCacheService":
-        """
-        Eagerly initialize the cache service.
-
-        Call this during tool creation to:
-        - Validate configuration early (fail fast)
-        - Initialize embeddings model before any requests
-        - Avoid first-request latency from lazy initialization
-
-        Returns:
-            self for method chaining
-        """
-        self._setup()
-        return self
-
     def _setup(self) -> None:
         """Initialize embeddings model lazily."""
         if self._setup_complete:
             return
 
-        # Initialize embeddings
-
-
-
-            if isinstance(self.parameters.embedding_model, str)
-            else self.parameters.embedding_model
+        # Initialize embeddings using base class helper
+        self._initialize_embeddings(
+            self.parameters.embedding_model,
+            self.parameters.embedding_dims,
         )
-        self._embeddings = embedding_model.as_embeddings_model()
-
-        # Auto-detect embedding dimensions if not provided
-        if self.parameters.embedding_dims is None:
-            sample_embedding: list[float] = self._embeddings.embed_query("test")
-            self._embedding_dims = len(sample_embedding)
-            logger.debug(
-                "Auto-detected embedding dimensions",
-                layer=self.name,
-                dims=self._embedding_dims,
-            )
-        else:
-            self._embedding_dims = self.parameters.embedding_dims
 
         self._setup_complete = True
         logger.debug(
-            "In-memory
+            "In-memory context-aware cache initialized",
             layer=self.name,
             space_id=self.space_id,
             dims=self._embedding_dims,
             capacity=self.parameters.capacity,
         )
 
+    # Property implementations
     @property
     def warehouse(self) -> WarehouseModel:
         """The warehouse used for executing cached SQL queries."""
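
The removed inline logic above shows what the new shared helper must cover: building the embeddings client and resolving dimensions. A minimal sketch of _initialize_embeddings reconstructed from that removed code; the real base-class implementation in context_aware/base.py is not part of this file's diff and may differ:

    def _initialize_embeddings(self, embedding_model, embedding_dims: int | None) -> None:
        # Build the embeddings client from the configured model.
        self._embeddings = embedding_model.as_embeddings_model()
        if embedding_dims is None:
            # Auto-detect dimensions by embedding a probe string, as the
            # removed inline code did with embed_query("test").
            self._embedding_dims = len(self._embeddings.embed_query("test"))
        else:
            self._embedding_dims = embedding_dims
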
@@ -248,23 +223,25 @@ class InMemorySemanticCacheService(GenieServiceBase):
         return self.parameters.similarity_threshold
 
     @property
-    def
-        """
-
-
-
-
-
+    def context_similarity_threshold(self) -> float:
+        """Minimum similarity for context matching."""
+        return self.parameters.context_similarity_threshold
+
+    @property
+    def question_weight(self) -> float:
+        """Weight for question similarity in combined score."""
+        return self.parameters.question_weight
+
+    @property
+    def context_weight(self) -> float:
+        """Weight for context similarity in combined score."""
+        return self.parameters.context_weight
 
     def _embed_question(
         self, question: str, conversation_id: str | None = None
     ) -> tuple[list[float], list[float], str]:
         """
-        Generate dual embeddings
-
-        This enables separate matching of question similarity vs context similarity,
-        improving precision by ensuring both the question AND the conversation context
-        are semantically similar before returning a cached result.
+        Generate dual embeddings using Genie API for conversation history.
 
         Args:
             question: The question to embed
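
The three new properties feed the dual-embedding match. A weighted sum is the natural reading of question_weight and context_weight (the exact combination formula is not visible in these hunks), and a candidate must clear both thresholds, as the two "Cache MISS (... similarity too low)" branches later in this diff show. A sketch under those assumptions:

    def combined_similarity(question_sim: float, context_sim: float,
                            question_weight: float, context_weight: float) -> float:
        # Blend used to rank candidate cache entries.
        return question_weight * question_sim + context_weight * context_sim

    def passes_thresholds(question_sim: float, context_sim: float,
                          similarity_threshold: float,
                          context_similarity_threshold: float) -> bool:
        # BOTH gates must pass before a cached SQL query is served.
        return (question_sim >= similarity_threshold
                and context_sim >= context_similarity_threshold)
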
@@ -272,84 +249,13 @@ class InMemorySemanticCacheService(GenieServiceBase):
 
         Returns:
             Tuple of (question_embedding, context_embedding, conversation_context_string)
-            - question_embedding: Vector embedding of just the question
-            - context_embedding: Vector embedding of the conversation context (or zero vector if no context)
-            - conversation_context_string: The conversation context string (empty if no context)
         """
-
-
-
-
-            self.
-
-            and self.parameters.context_window_size > 0
-        ):
-            try:
-                # Retrieve conversation history
-                conversation_messages = get_conversation_history(
-                    workspace_client=self.workspace_client,
-                    space_id=self.space_id,
-                    conversation_id=conversation_id,
-                    max_messages=self.parameters.context_window_size
-                    * 2,  # Get extra for safety
-                )
-
-                # Build context string (just the "Previous:" messages, not the current question)
-                if conversation_messages:
-                    recent_messages = (
-                        conversation_messages[-self.parameters.context_window_size :]
-                        if len(conversation_messages)
-                        > self.parameters.context_window_size
-                        else conversation_messages
-                    )
-
-                    context_parts: list[str] = []
-                    for msg in recent_messages:
-                        if msg.content:
-                            content: str = msg.content
-                            if len(content) > 500:
-                                content = content[:500] + "..."
-                            context_parts.append(f"Previous: {content}")
-
-                    conversation_context = "\n".join(context_parts)
-
-                    # Truncate if too long
-                    estimated_tokens = len(conversation_context) / 4
-                    if estimated_tokens > self.parameters.max_context_tokens:
-                        target_chars = self.parameters.max_context_tokens * 4
-                        conversation_context = (
-                            conversation_context[:target_chars] + "..."
-                        )
-
-                logger.trace(
-                    "Using conversation context",
-                    layer=self.name,
-                    messages_count=len(conversation_messages),
-                    window_size=self.parameters.context_window_size,
-                )
-            except Exception as e:
-                logger.warning(
-                    "Failed to build conversation context, using question only",
-                    layer=self.name,
-                    error=str(e),
-                )
-                conversation_context = ""
-
-        # Generate dual embeddings
-        if conversation_context:
-            # Embed both question and context
-            embeddings: list[list[float]] = self._embeddings.embed_documents(
-                [question, conversation_context]
-            )
-            question_embedding = embeddings[0]
-            context_embedding = embeddings[1]
-        else:
-            # Only embed question, use zero vector for context
-            embeddings = self._embeddings.embed_documents([question])
-            question_embedding = embeddings[0]
-            context_embedding = [0.0] * len(question_embedding)  # Zero vector
-
-        return question_embedding, context_embedding, conversation_context
+        return self._embed_question_with_genie_history(
+            question,
+            conversation_id,
+            self.parameters.context_window_size,
+            self.parameters.max_context_tokens,
+        )
 
     @mlflow.trace(name="semantic_search_in_memory")
     def _find_similar(
@@ -363,9 +269,6 @@ class InMemorySemanticCacheService(GenieServiceBase):
         """
         Find a semantically similar cached entry using dual embedding matching.
 
-        This method matches BOTH the question AND the conversation context separately,
-        ensuring high precision by requiring both to be semantically similar.
-
         Performs linear scan through all cache entries, filtering by space_id and
         calculating L2 distances for similarity matching.
 
@@ -382,8 +285,8 @@ class InMemorySemanticCacheService(GenieServiceBase):
         ttl_seconds = self.parameters.time_to_live_seconds
         ttl_disabled = ttl_seconds is None or ttl_seconds < 0
 
-        question_weight
-        context_weight
+        question_weight = self.question_weight
+        context_weight = self.context_weight
 
         best_entry: InMemoryCacheEntry | None = None
         best_question_sim: float = 0.0
@@ -406,7 +309,6 @@ class InMemorySemanticCacheService(GenieServiceBase):
                 is_valid = age.total_seconds() <= ttl_seconds
 
             if not is_valid:
-                # Mark for deletion
                 entries_to_delete.append(idx)
                 continue
 
@@ -436,11 +338,7 @@ class InMemorySemanticCacheService(GenieServiceBase):
         # Delete expired entries
         for idx in reversed(entries_to_delete):
             del self._cache[idx]
-            logger.trace(
-                "Deleted expired entry",
-                layer=self.name,
-                index=idx,
-            )
+            logger.trace("Deleted expired entry", layer=self.name, index=idx)
 
         # No entries found
         if best_entry is None:
@@ -461,32 +359,28 @@ class InMemorySemanticCacheService(GenieServiceBase):
             context_sim=f"{best_context_sim:.4f}",
             combined_sim=f"{best_combined_sim:.4f}",
             cached_question=best_entry.question[:50],
-            cached_context=best_entry.conversation_context[:80],
         )
 
-        # Check BOTH similarity thresholds
-        if best_question_sim < self.
+        # Check BOTH similarity thresholds
+        if best_question_sim < self.similarity_threshold:
             logger.info(
                 "Cache MISS (question similarity too low)",
                 layer=self.name,
                 question_sim=f"{best_question_sim:.4f}",
-                threshold=self.
-                delegating_to=type(self.impl).__name__,
+                threshold=self.similarity_threshold,
             )
             return None
 
-        if best_context_sim < self.
+        if best_context_sim < self.context_similarity_threshold:
             logger.info(
                 "Cache MISS (context similarity too low)",
                 layer=self.name,
                 context_sim=f"{best_context_sim:.4f}",
-                threshold=self.
-                delegating_to=type(self.impl).__name__,
+                threshold=self.context_similarity_threshold,
             )
             return None
 
-        # Cache HIT
-        # Update last accessed time for LRU eviction
+        # Cache HIT - Update last accessed time
         with self._lock:
            best_entry.last_accessed_at = datetime.now()
 
@@ -501,8 +395,6 @@ class InMemorySemanticCacheService(GenieServiceBase):
             question_similarity=f"{best_question_sim:.4f}",
             context_similarity=f"{best_context_sim:.4f}",
             combined_similarity=f"{best_combined_sim:.4f}",
-            cached_sql=best_entry.sql_query[:80] if best_entry.sql_query else None,
-            ttl_seconds=self.parameters.time_to_live_seconds,
         )
 
         cache_entry = SQLCacheEntry(
@@ -510,6 +402,9 @@ class InMemorySemanticCacheService(GenieServiceBase):
             description=best_entry.description,
             conversation_id=best_entry.conversation_id,
             created_at=best_entry.created_at,
+            message_id=best_entry.message_id,
+            # In-memory caches don't have database row IDs
+            cache_entry_id=None,
         )
         return cache_entry, best_combined_sim
 
@@ -520,9 +415,10 @@ class InMemorySemanticCacheService(GenieServiceBase):
         question_embedding: list[float],
         context_embedding: list[float],
         response: GenieResponse,
+        message_id: str | None = None,
     ) -> None:
         """
-        Store a new cache entry with dual embeddings
+        Store a new cache entry with dual embeddings and message_id.
 
         If capacity is set and reached, evicts least recently used entry (LRU).
         """
@@ -537,19 +433,19 @@ class InMemorySemanticCacheService(GenieServiceBase):
             description=response.description,
             conversation_id=response.conversation_id,
             created_at=now,
-            last_accessed_at=now,
+            last_accessed_at=now,
+            message_id=message_id,
         )
 
         with self._lock:
             # Enforce capacity limit (LRU eviction)
             if self.parameters.capacity is not None:
-                # Count entries for this space_id
                 space_entries = [
                     e for e in self._cache if e.genie_space_id == self.space_id
                 ]
 
                 while len(space_entries) >= self.parameters.capacity:
-                    # Find and remove least recently used entry
+                    # Find and remove least recently used entry
                     lru_idx = None
                     lru_time = None
 
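
The eviction loop above implements per-space LRU over a flat list. The same policy in isolation, using hypothetical entry objects that carry the genie_space_id and last_accessed_at fields of InMemoryCacheEntry:

    def evict_lru(cache: list, space_id: str, capacity: int) -> None:
        # While this space is at capacity, drop the least recently used entry.
        space_entries = [e for e in cache if e.genie_space_id == space_id]
        while len(space_entries) >= capacity:
            lru = min(space_entries, key=lambda e: e.last_accessed_at)
            cache.remove(lru)
            space_entries.remove(lru)
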
@@ -578,8 +474,6 @@ class InMemorySemanticCacheService(GenieServiceBase):
             "Stored cache entry",
             layer=self.name,
             question=question[:50],
-            context=conversation_context[:80],
-            sql=response.query[:50] if response.query else None,
             space=self.space_id,
             cache_size=len(
                 [e for e in self._cache if e.genie_space_id == self.space_id]
@@ -587,189 +481,53 @@ class InMemorySemanticCacheService(GenieServiceBase):
             capacity=self.parameters.capacity,
         )
 
-
-
-
-
-
-
-
-
-
-
-
-
-        )
-
-        if (
-            statement_response.status is not None
-            and statement_response.status.state != StatementState.SUCCEEDED
-        ):
-            error_msg: str = (
-                f"SQL execution failed: {statement_response.status.error.message}"
-                if statement_response.status.error is not None
-                else f"SQL execution failed with state: {statement_response.status.state}"
-            )
-            logger.error("SQL execution failed", layer=self.name, error=error_msg)
-            return error_msg
-
-        if statement_response.result and statement_response.result.data_array:
-            columns: list[str] = []
-            if (
-                statement_response.manifest
-                and statement_response.manifest.schema
-                and statement_response.manifest.schema.columns
-            ):
-                columns = [
-                    col.name
-                    for col in statement_response.manifest.schema.columns
-                    if col.name is not None
-                ]
-
-            data: list[list[Any]] = statement_response.result.data_array
-            if columns:
-                return pd.DataFrame(data, columns=columns)
-            else:
-                return pd.DataFrame(data)
-
-        return pd.DataFrame()
-
-    def ask_question(
-        self, question: str, conversation_id: str | None = None
-    ) -> CacheResult:
-        """
-        Ask a question, using semantic cache if a similar query exists.
+    def _on_stale_cache_entry(self, question: str) -> None:
+        """Remove stale cache entry from memory."""
+        with self._lock:
+            for idx, entry in enumerate(self._cache):
+                if entry.genie_space_id == self.space_id and entry.question == question:
+                    del self._cache[idx]
+                    logger.info(
+                        "Deleted stale cache entry from memory",
+                        layer=self.name,
+                        question=question[:50],
+                    )
+                    break
 
-
-        Returns CacheResult with cache metadata.
+    def _invalidate_by_question(self, question: str) -> bool:
         """
-
+        Invalidate cache entries matching a specific question.
 
-
-
-        self,
-        question: str,
-        conversation_id: str | None = None,
-    ) -> CacheResult:
-        """
-        Ask a question with detailed cache hit information.
-
-        On cache hit, the cached SQL is re-executed to return fresh data, but the
-        conversation_id returned is the current conversation_id (not the cached one).
+        This method is called when negative feedback is received to remove
+        the corresponding cache entry from the in-memory cache.
 
         Args:
-            question: The question to
-            conversation_id: Optional conversation ID for context and continuation
+            question: The question text to match and invalidate
 
         Returns:
-
-        """
-        # Ensure initialization (lazy init if initialize() wasn't called)
-        self._setup()
-
-        # Generate dual embeddings for the question and conversation context
-        question_embedding: list[float]
-        context_embedding: list[float]
-        conversation_context: str
-        question_embedding, context_embedding, conversation_context = (
-            self._embed_question(question, conversation_id)
-        )
-
-        # Check cache using dual embedding similarity
-        cache_result: tuple[SQLCacheEntry, float] | None = self._find_similar(
-            question,
-            conversation_context,
-            question_embedding,
-            context_embedding,
-            conversation_id,
-        )
-
-        if cache_result is not None:
-            cached, combined_similarity = cache_result
-            logger.debug(
-                "In-memory semantic cache hit",
-                layer=self.name,
-                combined_similarity=f"{combined_similarity:.3f}",
-                question=question[:50],
-                conversation_id=conversation_id,
-            )
-
-            # Re-execute the cached SQL to get fresh data
-            result: pd.DataFrame | str = self._execute_sql(cached.query)
-
-            # IMPORTANT: Use the current conversation_id (from the request), not the cached one
-            # This ensures the conversation continues properly
-            response: GenieResponse = GenieResponse(
-                result=result,
-                query=cached.query,
-                description=cached.description,
-                conversation_id=conversation_id
-                if conversation_id
-                else cached.conversation_id,
-            )
-
-            return CacheResult(response=response, cache_hit=True, served_by=self.name)
-
-        # Cache miss - delegate to wrapped service
-        logger.info(
-            "Cache MISS",
-            layer=self.name,
-            question=question[:80],
-            conversation_id=conversation_id,
-            space_id=self.space_id,
-            similarity_threshold=self.similarity_threshold,
-            delegating_to=type(self.impl).__name__,
-        )
-
-        result: CacheResult = self.impl.ask_question(question, conversation_id)
-
-        # Store in cache if we got a SQL query
-        if result.response.query:
-            logger.debug(
-                "Storing new cache entry",
-                layer=self.name,
-                question=question[:50],
-                conversation_id=conversation_id,
-                space=self.space_id,
-            )
-            self._store_entry(
-                question,
-                conversation_context,
-                question_embedding,
-                context_embedding,
-                result.response,
-            )
-        elif not result.response.query:
-            logger.warning(
-                "Not caching: response has no SQL query",
-                layer=self.name,
-                question=question[:50],
-            )
-
-        return CacheResult(response=result.response, cache_hit=False, served_by=None)
-
-    @property
-    def space_id(self) -> str:
-        return self.impl.space_id
-
-    def invalidate_expired(self) -> int:
+            True if an entry was found and invalidated, False otherwise
         """
-
+        with self._lock:
+            for idx, entry in enumerate(self._cache):
+                if entry.genie_space_id == self.space_id and entry.question == question:
+                    del self._cache[idx]
+                    logger.info(
+                        "Invalidated cache entry by question",
+                        layer=self.name,
+                        question=question[:50],
+                        space_id=self.space_id,
+                    )
+                    return True
+            return False
 
-
-
-
-        ttl_seconds = self.parameters.time_to_live_seconds
+    # Note: ask_question_with_cache_info is inherited from ContextAwareGenieService
+    # using the Template Method pattern. InMemoryContextAwareGenieService uses the
+    # default empty hook implementations since it doesn't track prompt history.
 
-
-        if ttl_seconds is None or ttl_seconds < 0:
-            logger.trace(
-                "TTL disabled, no entries to expire",
-                layer=self.name,
-                space=self.space_id,
-            )
-            return 0
+    # Template Method implementations for invalidate_expired() and clear()
 
+    def _delete_expired_entries(self, ttl_seconds: int) -> int:
+        """Delete expired entries from the cache."""
         deleted = 0
         with self._lock:
             indices_to_delete: list[int] = []
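
The "Template Method" note above means the base class owns the shared algorithm and calls storage-specific hooks such as _delete_expired_entries and _delete_all_entries. An illustrative skeleton only; the real ContextAwareGenieService lives in dao_ai/genie/cache/context_aware/base.py and is not part of this file's diff:

    from abc import ABC, abstractmethod
    from typing import Any

    class ContextAwareServiceSketch(ABC):
        parameters: Any  # carries time_to_live_seconds, capacity, etc.

        def invalidate_expired(self) -> int:
            # Shared TTL handling; storage-specific deletion is delegated.
            ttl = self.parameters.time_to_live_seconds
            if ttl is None or ttl < 0:
                return 0  # TTL disabled: nothing can expire
            return self._delete_expired_entries(ttl)

        def clear(self) -> int:
            return self._delete_all_entries()

        @abstractmethod
        def _delete_expired_entries(self, ttl_seconds: int) -> int: ...

        @abstractmethod
        def _delete_all_entries(self) -> int: ...
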
@@ -783,7 +541,7 @@ class InMemorySemanticCacheService(GenieServiceBase):
                 if age.total_seconds() > ttl_seconds:
                     indices_to_delete.append(idx)
 
-            # Delete in reverse order
+            # Delete in reverse order
             for idx in reversed(indices_to_delete):
                 del self._cache[idx]
                 deleted += 1
@@ -792,18 +550,15 @@ class InMemorySemanticCacheService(GenieServiceBase):
                 "Deleted expired entries",
                 layer=self.name,
                 deleted_count=deleted,
-                space=self.space_id,
             )
 
         return deleted
 
-    def
-        """
-        self._setup()
+    def _delete_all_entries(self) -> int:
+        """Delete all entries for this Genie space."""
         deleted = 0
 
         with self._lock:
-            # Find indices for this space
             indices_to_delete: list[int] = []
             for idx, entry in enumerate(self._cache):
                 if entry.genie_space_id == self.space_id:
@@ -815,10 +570,7 @@ class InMemorySemanticCacheService(GenieServiceBase):
                     deleted += 1
 
         logger.debug(
-            "Cleared cache entries",
-            layer=self.name,
-            deleted_count=deleted,
-            space=self.space_id,
+            "Cleared cache entries", layer=self.name, deleted_count=deleted
         )
 
         return deleted
@@ -830,42 +582,140 @@ class InMemorySemanticCacheService(GenieServiceBase):
         with self._lock:
             return len([e for e in self._cache if e.genie_space_id == self.space_id])
 
-
-
-
-
-
+    # Template Method implementations for stats()
+
+    def _count_all_entries(self) -> int:
+        """Count all cache entries for this Genie space."""
+        with self._lock:
+            return len([e for e in self._cache if e.genie_space_id == self.space_id])
 
+    def _count_entries_with_ttl(self, ttl_seconds: int) -> tuple[int, int]:
+        """Count total and expired entries for this Genie space."""
+        now = datetime.now()
         with self._lock:
             space_entries = [
                 e for e in self._cache if e.genie_space_id == self.space_id
             ]
             total = len(space_entries)
-
-            # If TTL is disabled, all entries are valid
-            if ttl_seconds is None or ttl_seconds < 0:
-                return {
-                    "size": total,
-                    "capacity": self.parameters.capacity,
-                    "ttl_seconds": None,
-                    "similarity_threshold": self.similarity_threshold,
-                    "expired_entries": 0,
-                    "valid_entries": total,
-                }
-
-            # Count expired entries
-            now = datetime.now()
             expired = 0
             for entry in space_entries:
                 age = now - entry.created_at
                 if age.total_seconds() > ttl_seconds:
                     expired += 1
+            return total, expired
+
+    def _get_additional_stats(self) -> dict[str, Any]:
+        """Add capacity info to stats."""
+        return {"capacity": self.parameters.capacity}
+
+    def get_entries(
+        self,
+        limit: int | None = None,
+        offset: int | None = None,
+        include_embeddings: bool = False,
+        conversation_id: str | None = None,
+        created_after: datetime | None = None,
+        created_before: datetime | None = None,
+        question_contains: str | None = None,
+    ) -> list[dict[str, Any]]:
+        """
+        Get cache entries with optional filtering.
+
+        This method retrieves cache entries for inspection, debugging, or
+        generating evaluation datasets for threshold optimization.
+
+        Args:
+            limit: Maximum number of entries to return (None = no limit)
+            offset: Number of entries to skip for pagination (None = 0)
+            include_embeddings: Whether to include embedding vectors in results.
+                Embeddings are large, so set False for general inspection.
+            conversation_id: Filter by conversation ID (None = all conversations)
+            created_after: Only entries created after this time (None = no filter)
+            created_before: Only entries created before this time (None = no filter)
+            question_contains: Case-insensitive text search on question field
+
+        Returns:
+            List of cache entry dicts. See base class for full key documentation.
+
+        Example:
+            # Get entries with embeddings for evaluation dataset generation
+            entries = cache.get_entries(include_embeddings=True, limit=100)
+            eval_dataset = generate_eval_dataset_from_cache(entries)
+        """
+        self._setup()
+
+        with self._lock:
+            # Filter entries for this space
+            filtered_entries: list[InMemoryCacheEntry] = []
+
+            for entry in self._cache:
+                # Filter by space_id
+                if entry.genie_space_id != self.space_id:
+                    continue
+
+                # Filter by conversation_id
+                if (
+                    conversation_id is not None
+                    and entry.conversation_id != conversation_id
+                ):
+                    continue
+
+                # Filter by created_after
+                if created_after is not None and entry.created_at <= created_after:
+                    continue
+
+                # Filter by created_before
+                if created_before is not None and entry.created_at >= created_before:
+                    continue
+
+                # Filter by question_contains (case-insensitive)
+                if question_contains is not None:
+                    if question_contains.lower() not in entry.question.lower():
+                        continue
+
+                filtered_entries.append(entry)
+
+            # Sort by created_at descending (most recent first)
+            filtered_entries.sort(key=lambda e: e.created_at, reverse=True)
+
+            # Apply offset
+            if offset is not None and offset > 0:
+                filtered_entries = filtered_entries[offset:]
+
+            # Apply limit
+            if limit is not None:
+                filtered_entries = filtered_entries[:limit]
+
+            # Convert to dicts
+            entries: list[dict[str, Any]] = []
+            for entry in filtered_entries:
+                result: dict[str, Any] = {
+                    "id": None,  # In-memory caches don't have database IDs
+                    "question": entry.question,
+                    "conversation_context": entry.conversation_context,
+                    "sql_query": entry.sql_query,
+                    "description": entry.description,
+                    "conversation_id": entry.conversation_id,
+                    "created_at": entry.created_at,
+                }
+
+                if include_embeddings:
+                    result["question_embedding"] = entry.question_embedding
+                    result["context_embedding"] = entry.context_embedding
+
+                entries.append(result)
+
+        logger.debug(
+            "Retrieved cache entries",
+            layer=self.name,
+            count=len(entries),
+            include_embeddings=include_embeddings,
+            filters={
+                "conversation_id": conversation_id,
+                "created_after": str(created_after) if created_after else None,
+                "created_before": str(created_before) if created_before else None,
+                "question_contains": question_contains,
+            },
+        )
 
-        return
-            "size": total,
-            "capacity": self.parameters.capacity,
-            "ttl_seconds": ttl.total_seconds() if ttl else None,
-            "similarity_threshold": self.similarity_threshold,
-            "expired_entries": expired,
-            "valid_entries": total - expired,
-        }
+        return entries
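
Given the get_entries signature added above, a short hypothetical inspection session (the service instance and question text are illustrative):

    # Page through recent entries without the heavy embedding vectors.
    recent = genie.get_entries(limit=20, question_contains="revenue")
    for entry in recent:
        print(entry["created_at"], entry["question"][:60])

    # Pull embeddings only when exporting an evaluation dataset.
    eval_rows = genie.get_entries(include_embeddings=True, limit=100)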