dao-ai 0.1.18__py3-none-any.whl → 0.1.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dao_ai/config.py +99 -0
- dao_ai/genie/cache/__init__.py +2 -0
- dao_ai/genie/cache/core.py +1 -1
- dao_ai/genie/cache/in_memory_semantic.py +871 -0
- dao_ai/genie/cache/lru.py +15 -11
- dao_ai/genie/cache/semantic.py +52 -18
- dao_ai/tools/genie.py +28 -3
- {dao_ai-0.1.18.dist-info → dao_ai-0.1.19.dist-info}/METADATA +3 -2
- {dao_ai-0.1.18.dist-info → dao_ai-0.1.19.dist-info}/RECORD +12 -11
- {dao_ai-0.1.18.dist-info → dao_ai-0.1.19.dist-info}/WHEEL +0 -0
- {dao_ai-0.1.18.dist-info → dao_ai-0.1.19.dist-info}/entry_points.txt +0 -0
- {dao_ai-0.1.18.dist-info → dao_ai-0.1.19.dist-info}/licenses/LICENSE +0 -0
dao_ai/genie/cache/in_memory_semantic.py
@@ -0,0 +1,871 @@
+"""
+In-memory semantic cache implementation for Genie SQL queries.
+
+This module provides a semantic cache that stores embeddings and cache entries
+entirely in memory, without requiring external database dependencies like PostgreSQL
+or Databricks Lakebase. It uses L2 distance for similarity search and supports
+dual embedding matching (question + conversation context).
+
+The cache supports conversation-aware embedding using a rolling window approach
+to capture context from recent conversation turns, improving accuracy for
+multi-turn conversations with anaphoric references.
+
+Use this when:
+- No external database access is available
+- Single-instance deployments (cache not shared across instances)
+- Cache persistence across restarts is not required
+- Cache sizes are moderate (hundreds to low thousands of entries)
+
+For multi-instance deployments or large cache sizes, use SemanticCacheService
+with PostgreSQL backend instead.
+"""
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+from datetime import datetime, timedelta
+from threading import Lock
+from typing import Any
+
+import mlflow
+import numpy as np
+import pandas as pd
+from databricks.sdk import WorkspaceClient
+from databricks.sdk.service.sql import StatementResponse, StatementState
+from databricks_ai_bridge.genie import GenieResponse
+from loguru import logger
+
+from dao_ai.config import (
+    GenieInMemorySemanticCacheParametersModel,
+    LLMModel,
+    WarehouseModel,
+)
+from dao_ai.genie.cache.base import (
+    CacheResult,
+    GenieServiceBase,
+    SQLCacheEntry,
+)
+from dao_ai.genie.cache.semantic import (
+    get_conversation_history,
+)
+
+
+@dataclass
+class InMemoryCacheEntry:
+    """
+    In-memory cache entry storing embeddings and SQL query metadata.
+
+    This dataclass represents a single cache entry stored in memory, including
+    dual embeddings (question + context) for high-precision semantic matching.
+
+    Uses LRU (Least Recently Used) eviction strategy when capacity is reached.
+    """
+
+    genie_space_id: str
+    question: str
+    conversation_context: str
+    question_embedding: list[float]
+    context_embedding: list[float]
+    sql_query: str
+    description: str
+    conversation_id: str
+    created_at: datetime
+    last_accessed_at: datetime  # Track last access time for LRU eviction
+
+
+def l2_distance(a: list[float], b: list[float]) -> float:
+    """
+    Calculate L2 (Euclidean) distance between two embedding vectors.
+
+    This uses the same distance metric as PostgreSQL pg_vector to ensure
+    consistent behavior between in-memory and PostgreSQL caches.
+
+    Args:
+        a: First embedding vector
+        b: Second embedding vector
+
+    Returns:
+        L2 distance (0 = identical vectors, larger = more different)
+    """
+    return float(np.linalg.norm(np.array(a) - np.array(b)))
+
+
+def distance_to_similarity(distance: float) -> float:
+    """
+    Convert L2 distance to similarity score in range [0, 1].
+
+    Uses the formula: similarity = 1.0 / (1.0 + distance)
+    This matches the conversion used by PostgreSQL semantic cache.
+
+    Args:
+        distance: L2 distance value
+
+    Returns:
+        Similarity score where 1.0 = perfect match, approaching 0 = very different
+    """
+    return 1.0 / (1.0 + distance)
+
+
+class InMemorySemanticCacheService(GenieServiceBase):
+    """
+    In-memory semantic caching decorator using dual embeddings for similarity lookup.
+
+    This service caches the SQL query generated by Genie along with dual embeddings
+    (question + conversation context) for high-precision semantic matching. On
+    subsequent queries, it performs similarity search to find cached queries that
+    match both the question intent AND conversation context.
+
+    Cache entries are partitioned by genie_space_id to ensure queries from different
+    Genie spaces don't return incorrect cache hits.
+
+    On cache hit, it re-executes the cached SQL using the provided warehouse
+    to return fresh data while avoiding the Genie NL-to-SQL translation cost.
+
+    Example:
+        from dao_ai.config import GenieInMemorySemanticCacheParametersModel
+        from dao_ai.genie.cache import InMemorySemanticCacheService
+
+        cache_params = GenieInMemorySemanticCacheParametersModel(
+            warehouse=warehouse_model,
+            embedding_model="databricks-gte-large-en",
+            time_to_live_seconds=86400,  # 24 hours
+            similarity_threshold=0.85,
+            capacity=1000,  # Limit to 1000 entries
+        )
+        genie = InMemorySemanticCacheService(
+            impl=GenieService(Genie(space_id="my-space")),
+            parameters=cache_params,
+            workspace_client=workspace_client,
+        ).initialize()
+
+    Thread-safe: Uses a lock to protect cache operations.
+    """
+
+    impl: GenieServiceBase
+    parameters: GenieInMemorySemanticCacheParametersModel
+    workspace_client: WorkspaceClient | None
+    name: str
+    _embeddings: Any  # DatabricksEmbeddings
+    _cache: list[InMemoryCacheEntry]
+    _lock: Lock
+    _embedding_dims: int | None
+    _setup_complete: bool
+
+    def __init__(
+        self,
+        impl: GenieServiceBase,
+        parameters: GenieInMemorySemanticCacheParametersModel,
+        workspace_client: WorkspaceClient | None = None,
+        name: str | None = None,
+    ) -> None:
+        """
+        Initialize the in-memory semantic cache service.
+
+        Args:
+            impl: The underlying GenieServiceBase to delegate to on cache miss.
+                The space_id will be obtained from impl.space_id.
+            parameters: Cache configuration including warehouse, embedding model, and thresholds
+            workspace_client: Optional WorkspaceClient for retrieving conversation history.
+                If None, conversation context will not be used.
+            name: Name for this cache layer (for logging). Defaults to class name.
+        """
+        self.impl = impl
+        self.parameters = parameters
+        self.workspace_client = workspace_client
+        self.name = name if name is not None else self.__class__.__name__
+        self._embeddings = None
+        self._cache = []
+        self._lock = Lock()
+        self._embedding_dims = None
+        self._setup_complete = False
+
+    def initialize(self) -> "InMemorySemanticCacheService":
+        """
+        Eagerly initialize the cache service.
+
+        Call this during tool creation to:
+        - Validate configuration early (fail fast)
+        - Initialize embeddings model before any requests
+        - Avoid first-request latency from lazy initialization
+
+        Returns:
+            self for method chaining
+        """
+        self._setup()
+        return self
+
+    def _setup(self) -> None:
+        """Initialize embeddings model lazily."""
+        if self._setup_complete:
+            return
+
+        # Initialize embeddings
+        # Convert embedding_model to LLMModel if it's a string
+        embedding_model: LLMModel = (
+            LLMModel(name=self.parameters.embedding_model)
+            if isinstance(self.parameters.embedding_model, str)
+            else self.parameters.embedding_model
+        )
+        self._embeddings = embedding_model.as_embeddings_model()
+
+        # Auto-detect embedding dimensions if not provided
+        if self.parameters.embedding_dims is None:
+            sample_embedding: list[float] = self._embeddings.embed_query("test")
+            self._embedding_dims = len(sample_embedding)
+            logger.debug(
+                "Auto-detected embedding dimensions",
+                layer=self.name,
+                dims=self._embedding_dims,
+            )
+        else:
+            self._embedding_dims = self.parameters.embedding_dims
+
+        self._setup_complete = True
+        logger.debug(
+            "In-memory semantic cache initialized",
+            layer=self.name,
+            space_id=self.space_id,
+            dims=self._embedding_dims,
+            capacity=self.parameters.capacity,
+        )
+
+    @property
+    def warehouse(self) -> WarehouseModel:
+        """The warehouse used for executing cached SQL queries."""
+        return self.parameters.warehouse
+
+    @property
+    def time_to_live(self) -> timedelta | None:
+        """Time-to-live for cache entries. None means never expires."""
+        ttl = self.parameters.time_to_live_seconds
+        if ttl is None or ttl < 0:
+            return None
+        return timedelta(seconds=ttl)
+
+    @property
+    def similarity_threshold(self) -> float:
+        """Minimum similarity for cache hit (using L2 distance converted to similarity)."""
+        return self.parameters.similarity_threshold
+
+    @property
+    def embedding_dims(self) -> int:
+        """Dimension size for embeddings (auto-detected if not configured)."""
+        if self._embedding_dims is None:
+            raise RuntimeError(
+                "Embedding dimensions not yet initialized. Call _setup() first."
+            )
+        return self._embedding_dims
+
+    def _embed_question(
+        self, question: str, conversation_id: str | None = None
+    ) -> tuple[list[float], list[float], str]:
+        """
+        Generate dual embeddings: one for the question, one for the conversation context.
+
+        This enables separate matching of question similarity vs context similarity,
+        improving precision by ensuring both the question AND the conversation context
+        are semantically similar before returning a cached result.
+
+        Args:
+            question: The question to embed
+            conversation_id: Optional conversation ID for retrieving context
+
+        Returns:
+            Tuple of (question_embedding, context_embedding, conversation_context_string)
+            - question_embedding: Vector embedding of just the question
+            - context_embedding: Vector embedding of the conversation context (or zero vector if no context)
+            - conversation_context_string: The conversation context string (empty if no context)
+        """
+        conversation_context = ""
+
+        # If conversation context is enabled and available
+        if (
+            self.workspace_client is not None
+            and conversation_id is not None
+            and self.parameters.context_window_size > 0
+        ):
+            try:
+                # Retrieve conversation history
+                conversation_messages = get_conversation_history(
+                    workspace_client=self.workspace_client,
+                    space_id=self.space_id,
+                    conversation_id=conversation_id,
+                    max_messages=self.parameters.context_window_size
+                    * 2,  # Get extra for safety
+                )
+
+                # Build context string (just the "Previous:" messages, not the current question)
+                if conversation_messages:
+                    recent_messages = (
+                        conversation_messages[-self.parameters.context_window_size :]
+                        if len(conversation_messages)
+                        > self.parameters.context_window_size
+                        else conversation_messages
+                    )
+
+                    context_parts: list[str] = []
+                    for msg in recent_messages:
+                        if msg.content:
+                            content: str = msg.content
+                            if len(content) > 500:
+                                content = content[:500] + "..."
+                            context_parts.append(f"Previous: {content}")
+
+                    conversation_context = "\n".join(context_parts)
+
+                    # Truncate if too long
+                    estimated_tokens = len(conversation_context) / 4
+                    if estimated_tokens > self.parameters.max_context_tokens:
+                        target_chars = self.parameters.max_context_tokens * 4
+                        conversation_context = (
+                            conversation_context[:target_chars] + "..."
+                        )
+
+                logger.trace(
+                    "Using conversation context",
+                    layer=self.name,
+                    messages_count=len(conversation_messages),
+                    window_size=self.parameters.context_window_size,
+                )
+            except Exception as e:
+                logger.warning(
+                    "Failed to build conversation context, using question only",
+                    layer=self.name,
+                    error=str(e),
+                )
+                conversation_context = ""
+
+        # Generate dual embeddings
+        if conversation_context:
+            # Embed both question and context
+            embeddings: list[list[float]] = self._embeddings.embed_documents(
+                [question, conversation_context]
+            )
+            question_embedding = embeddings[0]
+            context_embedding = embeddings[1]
+        else:
+            # Only embed question, use zero vector for context
+            embeddings = self._embeddings.embed_documents([question])
+            question_embedding = embeddings[0]
+            context_embedding = [0.0] * len(question_embedding)  # Zero vector
+
+        return question_embedding, context_embedding, conversation_context
+
+    @mlflow.trace(name="semantic_search_in_memory")
+    def _find_similar(
+        self,
+        question: str,
+        conversation_context: str,
+        question_embedding: list[float],
+        context_embedding: list[float],
+        conversation_id: str | None = None,
+    ) -> tuple[SQLCacheEntry, float] | None:
+        """
+        Find a semantically similar cached entry using dual embedding matching.
+
+        This method matches BOTH the question AND the conversation context separately,
+        ensuring high precision by requiring both to be semantically similar.
+
+        Performs linear scan through all cache entries, filtering by space_id and
+        calculating L2 distances for similarity matching.
+
+        Args:
+            question: The original question (for logging)
+            conversation_context: The conversation context string
+            question_embedding: The embedding vector of just the question
+            context_embedding: The embedding vector of the conversation context
+            conversation_id: Optional conversation ID (for logging)
+
+        Returns:
+            Tuple of (SQLCacheEntry, combined_similarity_score) if found, None otherwise
+        """
+        ttl_seconds = self.parameters.time_to_live_seconds
+        ttl_disabled = ttl_seconds is None or ttl_seconds < 0
+
+        question_weight: float = self.parameters.question_weight
+        context_weight: float = self.parameters.context_weight
+
+        best_entry: InMemoryCacheEntry | None = None
+        best_question_sim: float = 0.0
+        best_context_sim: float = 0.0
+        best_combined_sim: float = 0.0
+
+        # Linear scan through all entries
+        with self._lock:
+            entries_to_delete: list[int] = []
+
+            for idx, entry in enumerate(self._cache):
+                # Filter by space_id (partition)
+                if entry.genie_space_id != self.space_id:
+                    continue
+
+                # Check TTL
+                is_valid = True
+                if not ttl_disabled:
+                    age = datetime.now() - entry.created_at
+                    is_valid = age.total_seconds() <= ttl_seconds
+
+                if not is_valid:
+                    # Mark for deletion
+                    entries_to_delete.append(idx)
+                    continue
+
+                # Calculate L2 distances and convert to similarities
+                question_distance = l2_distance(
+                    question_embedding, entry.question_embedding
+                )
+                context_distance = l2_distance(
+                    context_embedding, entry.context_embedding
+                )
+
+                question_sim = distance_to_similarity(question_distance)
+                context_sim = distance_to_similarity(context_distance)
+
+                # Calculate weighted combined similarity
+                combined_sim = (question_weight * question_sim) + (
+                    context_weight * context_sim
+                )
+
+                # Track best match
+                if combined_sim > best_combined_sim:
+                    best_entry = entry
+                    best_question_sim = question_sim
+                    best_context_sim = context_sim
+                    best_combined_sim = combined_sim
+
+            # Delete expired entries
+            for idx in reversed(entries_to_delete):
+                del self._cache[idx]
+                logger.trace(
+                    "Deleted expired entry",
+                    layer=self.name,
+                    index=idx,
+                )
+
+        # No entries found
+        if best_entry is None:
+            logger.info(
+                "Cache MISS (no entries)",
+                layer=self.name,
+                question=question[:50],
+                space=self.space_id,
+                delegating_to=type(self.impl).__name__,
+            )
+            return None
+
+        # Log best match info
+        logger.debug(
+            "Best match found",
+            layer=self.name,
+            question_sim=f"{best_question_sim:.4f}",
+            context_sim=f"{best_context_sim:.4f}",
+            combined_sim=f"{best_combined_sim:.4f}",
+            cached_question=best_entry.question[:50],
+            cached_context=best_entry.conversation_context[:80],
+        )
+
+        # Check BOTH similarity thresholds (dual embedding precision check)
+        if best_question_sim < self.parameters.similarity_threshold:
+            logger.info(
+                "Cache MISS (question similarity too low)",
+                layer=self.name,
+                question_sim=f"{best_question_sim:.4f}",
+                threshold=self.parameters.similarity_threshold,
+                delegating_to=type(self.impl).__name__,
+            )
+            return None
+
+        if best_context_sim < self.parameters.context_similarity_threshold:
+            logger.info(
+                "Cache MISS (context similarity too low)",
+                layer=self.name,
+                context_sim=f"{best_context_sim:.4f}",
+                threshold=self.parameters.context_similarity_threshold,
+                delegating_to=type(self.impl).__name__,
+            )
+            return None
+
+        # Cache HIT!
+        # Update last accessed time for LRU eviction
+        with self._lock:
+            best_entry.last_accessed_at = datetime.now()
+
+        cache_age_seconds = (datetime.now() - best_entry.created_at).total_seconds()
+        logger.info(
+            "Cache HIT",
+            layer=self.name,
+            question=question[:80],
+            conversation_id=conversation_id,
+            matched_question=best_entry.question[:80],
+            cache_age_seconds=round(cache_age_seconds, 1),
+            question_similarity=f"{best_question_sim:.4f}",
+            context_similarity=f"{best_context_sim:.4f}",
+            combined_similarity=f"{best_combined_sim:.4f}",
+            cached_sql=best_entry.sql_query[:80] if best_entry.sql_query else None,
+            ttl_seconds=self.parameters.time_to_live_seconds,
+        )
+
+        cache_entry = SQLCacheEntry(
+            query=best_entry.sql_query,
+            description=best_entry.description,
+            conversation_id=best_entry.conversation_id,
+            created_at=best_entry.created_at,
+        )
+        return cache_entry, best_combined_sim
+
+    def _store_entry(
+        self,
+        question: str,
+        conversation_context: str,
+        question_embedding: list[float],
+        context_embedding: list[float],
+        response: GenieResponse,
+    ) -> None:
+        """
+        Store a new cache entry with dual embeddings for this Genie space.
+
+        If capacity is set and reached, evicts least recently used entry (LRU).
+        """
+        now = datetime.now()
+        new_entry = InMemoryCacheEntry(
+            genie_space_id=self.space_id,
+            question=question,
+            conversation_context=conversation_context,
+            question_embedding=question_embedding,
+            context_embedding=context_embedding,
+            sql_query=response.query,
+            description=response.description,
+            conversation_id=response.conversation_id,
+            created_at=now,
+            last_accessed_at=now,  # Initialize to now; updated on cache hits (traditional LRU)
+        )
+
+        with self._lock:
+            # Enforce capacity limit (LRU eviction)
+            if self.parameters.capacity is not None:
+                # Count entries for this space_id
+                space_entries = [
+                    e for e in self._cache if e.genie_space_id == self.space_id
+                ]
+
+                while len(space_entries) >= self.parameters.capacity:
+                    # Find and remove least recently used entry for this space
+                    lru_idx = None
+                    lru_time = None
+
+                    for idx, entry in enumerate(self._cache):
+                        if entry.genie_space_id == self.space_id:
+                            if lru_time is None or entry.last_accessed_at < lru_time:
+                                lru_time = entry.last_accessed_at
+                                lru_idx = idx
+
+                    if lru_idx is not None:
+                        evicted = self._cache.pop(lru_idx)
+                        logger.trace(
+                            "Evicted LRU cache entry",
+                            layer=self.name,
+                            question=evicted.question[:50],
+                            capacity=self.parameters.capacity,
+                        )
+                        space_entries = [
+                            e for e in self._cache if e.genie_space_id == self.space_id
+                        ]
+                    else:
+                        break
+
+            self._cache.append(new_entry)
+            logger.debug(
+                "Stored cache entry",
+                layer=self.name,
+                question=question[:50],
+                context=conversation_context[:80],
+                sql=response.query[:50] if response.query else None,
+                space=self.space_id,
+                cache_size=len(
+                    [e for e in self._cache if e.genie_space_id == self.space_id]
+                ),
+                capacity=self.parameters.capacity,
+            )
+
+    @mlflow.trace(name="execute_cached_sql_in_memory_semantic")
+    def _execute_sql(self, sql: str) -> pd.DataFrame | str:
+        """Execute SQL using the warehouse and return results."""
+        client: WorkspaceClient = self.warehouse.workspace_client
+        warehouse_id: str = str(self.warehouse.warehouse_id)
+
+        statement_response: StatementResponse = (
+            client.statement_execution.execute_statement(
+                warehouse_id=warehouse_id,
+                statement=sql,
+                wait_timeout="30s",
+            )
+        )
+
+        if (
+            statement_response.status is not None
+            and statement_response.status.state != StatementState.SUCCEEDED
+        ):
+            error_msg: str = (
+                f"SQL execution failed: {statement_response.status.error.message}"
+                if statement_response.status.error is not None
+                else f"SQL execution failed with state: {statement_response.status.state}"
+            )
+            logger.error("SQL execution failed", layer=self.name, error=error_msg)
+            return error_msg
+
+        if statement_response.result and statement_response.result.data_array:
+            columns: list[str] = []
+            if (
+                statement_response.manifest
+                and statement_response.manifest.schema
+                and statement_response.manifest.schema.columns
+            ):
+                columns = [
+                    col.name
+                    for col in statement_response.manifest.schema.columns
+                    if col.name is not None
+                ]
+
+            data: list[list[Any]] = statement_response.result.data_array
+            if columns:
+                return pd.DataFrame(data, columns=columns)
+            else:
+                return pd.DataFrame(data)
+
+        return pd.DataFrame()
+
+    def ask_question(
+        self, question: str, conversation_id: str | None = None
+    ) -> CacheResult:
+        """
+        Ask a question, using semantic cache if a similar query exists.
+
+        On cache hit, re-executes the cached SQL to get fresh data.
+        Returns CacheResult with cache metadata.
+        """
+        return self.ask_question_with_cache_info(question, conversation_id)
+
+    @mlflow.trace(name="genie_in_memory_semantic_cache_lookup")
+    def ask_question_with_cache_info(
+        self,
+        question: str,
+        conversation_id: str | None = None,
+    ) -> CacheResult:
+        """
+        Ask a question with detailed cache hit information.
+
+        On cache hit, the cached SQL is re-executed to return fresh data, but the
+        conversation_id returned is the current conversation_id (not the cached one).
+
+        Args:
+            question: The question to ask
+            conversation_id: Optional conversation ID for context and continuation
+
+        Returns:
+            CacheResult with fresh response and cache metadata
+        """
+        # Ensure initialization (lazy init if initialize() wasn't called)
+        self._setup()
+
+        # Generate dual embeddings for the question and conversation context
+        question_embedding: list[float]
+        context_embedding: list[float]
+        conversation_context: str
+        question_embedding, context_embedding, conversation_context = (
+            self._embed_question(question, conversation_id)
+        )
+
+        # Check cache using dual embedding similarity
+        cache_result: tuple[SQLCacheEntry, float] | None = self._find_similar(
+            question,
+            conversation_context,
+            question_embedding,
+            context_embedding,
+            conversation_id,
+        )
+
+        if cache_result is not None:
+            cached, combined_similarity = cache_result
+            logger.debug(
+                "In-memory semantic cache hit",
+                layer=self.name,
+                combined_similarity=f"{combined_similarity:.3f}",
+                question=question[:50],
+                conversation_id=conversation_id,
+            )
+
+            # Re-execute the cached SQL to get fresh data
+            result: pd.DataFrame | str = self._execute_sql(cached.query)
+
+            # IMPORTANT: Use the current conversation_id (from the request), not the cached one
+            # This ensures the conversation continues properly
+            response: GenieResponse = GenieResponse(
+                result=result,
+                query=cached.query,
+                description=cached.description,
+                conversation_id=conversation_id
+                if conversation_id
+                else cached.conversation_id,
+            )
+
+            return CacheResult(response=response, cache_hit=True, served_by=self.name)
+
+        # Cache miss - delegate to wrapped service
+        logger.info(
+            "Cache MISS",
+            layer=self.name,
+            question=question[:80],
+            conversation_id=conversation_id,
+            space_id=self.space_id,
+            similarity_threshold=self.similarity_threshold,
+            delegating_to=type(self.impl).__name__,
+        )
+
+        result: CacheResult = self.impl.ask_question(question, conversation_id)
+
+        # Store in cache if we got a SQL query
+        if result.response.query:
+            logger.debug(
+                "Storing new cache entry",
+                layer=self.name,
+                question=question[:50],
+                conversation_id=conversation_id,
+                space=self.space_id,
+            )
+            self._store_entry(
+                question,
+                conversation_context,
+                question_embedding,
+                context_embedding,
+                result.response,
+            )
+        elif not result.response.query:
+            logger.warning(
+                "Not caching: response has no SQL query",
+                layer=self.name,
+                question=question[:50],
+            )
+
+        return CacheResult(response=result.response, cache_hit=False, served_by=None)
+
+    @property
+    def space_id(self) -> str:
+        return self.impl.space_id
+
+    def invalidate_expired(self) -> int:
+        """
+        Remove expired entries from the cache for this Genie space.
+
+        Returns 0 if TTL is disabled (entries never expire).
+        """
+        self._setup()
+        ttl_seconds = self.parameters.time_to_live_seconds
+
+        # If TTL is disabled, nothing can expire
+        if ttl_seconds is None or ttl_seconds < 0:
+            logger.trace(
+                "TTL disabled, no entries to expire",
+                layer=self.name,
+                space=self.space_id,
+            )
+            return 0
+
+        deleted = 0
+        with self._lock:
+            indices_to_delete: list[int] = []
+            now = datetime.now()
+
+            for idx, entry in enumerate(self._cache):
+                if entry.genie_space_id != self.space_id:
+                    continue
+
+                age = now - entry.created_at
+                if age.total_seconds() > ttl_seconds:
+                    indices_to_delete.append(idx)
+
+            # Delete in reverse order to preserve indices
+            for idx in reversed(indices_to_delete):
+                del self._cache[idx]
+                deleted += 1
+
+        logger.trace(
+            "Deleted expired entries",
+            layer=self.name,
+            deleted_count=deleted,
+            space=self.space_id,
+        )
+
+        return deleted
+
+    def clear(self) -> int:
+        """Clear all entries from the cache for this Genie space."""
+        self._setup()
+        deleted = 0
+
+        with self._lock:
+            # Find indices for this space
+            indices_to_delete: list[int] = []
+            for idx, entry in enumerate(self._cache):
+                if entry.genie_space_id == self.space_id:
+                    indices_to_delete.append(idx)
+
+            # Delete in reverse order
+            for idx in reversed(indices_to_delete):
+                del self._cache[idx]
+                deleted += 1
+
+        logger.debug(
+            "Cleared cache entries",
+            layer=self.name,
+            deleted_count=deleted,
+            space=self.space_id,
+        )
+
+        return deleted
+
+    @property
+    def size(self) -> int:
+        """Current number of entries in the cache for this Genie space."""
+        self._setup()
+        with self._lock:
+            return len([e for e in self._cache if e.genie_space_id == self.space_id])
+
+    def stats(self) -> dict[str, int | float | None]:
+        """Return cache statistics for this Genie space."""
+        self._setup()
+        ttl_seconds = self.parameters.time_to_live_seconds
+        ttl = self.time_to_live
+
+        with self._lock:
+            space_entries = [
+                e for e in self._cache if e.genie_space_id == self.space_id
+            ]
+            total = len(space_entries)
+
+            # If TTL is disabled, all entries are valid
+            if ttl_seconds is None or ttl_seconds < 0:
+                return {
+                    "size": total,
+                    "capacity": self.parameters.capacity,
+                    "ttl_seconds": None,
+                    "similarity_threshold": self.similarity_threshold,
+                    "expired_entries": 0,
+                    "valid_entries": total,
+                }
+
+            # Count expired entries
+            now = datetime.now()
+            expired = 0
+            for entry in space_entries:
+                age = now - entry.created_at
+                if age.total_seconds() > ttl_seconds:
+                    expired += 1
+
+            return {
+                "size": total,
+                "capacity": self.parameters.capacity,
+                "ttl_seconds": ttl.total_seconds() if ttl else None,
+                "similarity_threshold": self.similarity_threshold,
+                "expired_entries": expired,
+                "valid_entries": total - expired,
+            }