dao-ai 0.0.35__py3-none-any.whl → 0.0.36__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dao_ai/config.py +29 -0
- dao_ai/genie/__init__.py +59 -0
- dao_ai/genie/cache/__init__.py +44 -0
- dao_ai/genie/cache/base.py +122 -0
- dao_ai/genie/cache/lru.py +306 -0
- dao_ai/genie/cache/semantic.py +638 -0
- dao_ai/tools/__init__.py +3 -0
- dao_ai/tools/genie/__init__.py +236 -0
- dao_ai/tools/genie.py +65 -15
- dao_ai-0.0.36.dist-info/METADATA +951 -0
- {dao_ai-0.0.35.dist-info → dao_ai-0.0.36.dist-info}/RECORD +14 -8
- dao_ai-0.0.35.dist-info/METADATA +0 -1169
- {dao_ai-0.0.35.dist-info → dao_ai-0.0.36.dist-info}/WHEEL +0 -0
- {dao_ai-0.0.35.dist-info → dao_ai-0.0.36.dist-info}/entry_points.txt +0 -0
- {dao_ai-0.0.35.dist-info → dao_ai-0.0.36.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,638 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Semantic cache implementation for Genie SQL queries using PostgreSQL pg_vector.
|
|
3
|
+
|
|
4
|
+
This module provides a semantic cache that uses embeddings and similarity search
|
|
5
|
+
to find cached queries that match the intent of new questions. Cache entries are
|
|
6
|
+
partitioned by genie_space_id to ensure proper isolation between Genie spaces.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from datetime import timedelta
|
|
10
|
+
from typing import Any
|
|
11
|
+
|
|
12
|
+
import mlflow
|
|
13
|
+
import pandas as pd
|
|
14
|
+
from databricks.sdk import WorkspaceClient
|
|
15
|
+
from databricks.sdk.service.sql import StatementResponse, StatementState
|
|
16
|
+
from databricks_ai_bridge.genie import GenieResponse
|
|
17
|
+
from loguru import logger
|
|
18
|
+
from mlflow.entities import SpanType
|
|
19
|
+
|
|
20
|
+
from dao_ai.config import (
|
|
21
|
+
DatabaseModel,
|
|
22
|
+
GenieSemanticCacheParametersModel,
|
|
23
|
+
WarehouseModel,
|
|
24
|
+
)
|
|
25
|
+
from dao_ai.genie.cache.base import (
|
|
26
|
+
CacheResult,
|
|
27
|
+
GenieServiceBase,
|
|
28
|
+
SQLCacheEntry,
|
|
29
|
+
)
|
|
30
|
+
|
|
31
|
+
# Type alias for a database row fetched from the cache table.
# Rows are accessed as dicts (row["col"] / row.get("col")) throughout this module,
# which relies on the pool's connections using psycopg's row_factory=dict_row
# (configured outside this module — NOTE(review): confirm in PostgresPoolManager).
DbRow = dict[str, Any]
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class SemanticCacheService(GenieServiceBase):
    """
    Semantic caching decorator that uses PostgreSQL pg_vector for similarity lookup.

    This service caches the SQL query generated by Genie along with an embedding
    of the original question. On subsequent queries, it performs a semantic similarity
    search to find cached queries that match the intent of the new question.

    Cache entries are partitioned by genie_space_id to ensure queries from different
    Genie spaces don't return incorrect cache hits.

    On cache hit, it re-executes the cached SQL using the provided warehouse
    to return fresh data while avoiding the Genie NL-to-SQL translation cost.

    Example:
        from dao_ai.config import GenieSemanticCacheParametersModel, DatabaseModel
        from dao_ai.genie.cache import SemanticCacheService

        cache_params = GenieSemanticCacheParametersModel(
            database=database_model,
            warehouse=warehouse_model,
            embedding_model="databricks-gte-large-en",
            time_to_live_seconds=86400,  # 24 hours
            similarity_threshold=0.85
        )
        genie = SemanticCacheService(
            impl=GenieService(Genie(space_id="my-space")),
            parameters=cache_params,
            genie_space_id="my-space"
        )

    Concurrency: every database operation checks out its own connection from
    the psycopg_pool connection pool. The lazy initialization in ``_setup`` is
    not lock-guarded, however, so call ``initialize()`` before sharing an
    instance across threads.
    """

    impl: GenieServiceBase
    parameters: GenieSemanticCacheParametersModel
    genie_space_id: str
    name: str
    _embeddings: Any  # DatabricksEmbeddings
    _pool: Any  # ConnectionPool
    _embedding_dims: int | None
    _setup_complete: bool

    def __init__(
        self,
        impl: GenieServiceBase,
        parameters: GenieSemanticCacheParametersModel,
        genie_space_id: str,
        name: str | None = None,
    ) -> None:
        """
        Initialize the semantic cache service.

        No I/O happens here; embeddings and the connection pool are created
        by ``_setup`` (via ``initialize()`` or lazily on first use).

        Args:
            impl: The underlying GenieServiceBase to delegate to on cache miss
            parameters: Cache configuration including database, warehouse, embedding model
            genie_space_id: The Genie space ID for partitioning cache entries
            name: Name for this cache layer (for logging). Defaults to class name.
        """
        self.impl = impl
        self.parameters = parameters
        self.genie_space_id = genie_space_id
        self.name = name if name is not None else self.__class__.__name__
        self._embeddings = None
        self._pool = None
        self._embedding_dims = None
        self._setup_complete = False

    def initialize(self) -> "SemanticCacheService":
        """
        Eagerly initialize the cache service.

        Call this during tool creation to:
        - Validate configuration early (fail fast)
        - Create the database table before any requests
        - Avoid first-request latency from lazy initialization

        Returns:
            self for method chaining
        """
        self._setup()
        return self

    def _setup(self) -> None:
        """Initialize embeddings and database connection pool lazily.

        Idempotent: returns immediately once setup has completed.
        """
        if self._setup_complete:
            return

        from databricks_langchain import DatabricksEmbeddings

        from dao_ai.memory.postgres import PostgresPoolManager

        # Initialize embeddings; the model may be configured as a plain endpoint
        # name or as an object exposing `.name`.
        embedding_model: str = (
            self.parameters.embedding_model
            if isinstance(self.parameters.embedding_model, str)
            else self.parameters.embedding_model.name
        )
        self._embeddings = DatabricksEmbeddings(endpoint=embedding_model)

        # Auto-detect embedding dimensions if not provided (costs one embedding call)
        if self.parameters.embedding_dims is None:
            sample_embedding: list[float] = self._embeddings.embed_query("test")
            self._embedding_dims = len(sample_embedding)
            logger.debug(
                f"[{self.name}] Auto-detected embedding dimensions: {self._embedding_dims}"
            )
        else:
            self._embedding_dims = self.parameters.embedding_dims

        # Get connection pool
        self._pool = PostgresPoolManager.get_pool(self.parameters.database)

        # Ensure table exists
        self._create_table_if_not_exists()

        self._setup_complete = True
        logger.debug(
            f"[{self.name}] Semantic cache initialized for space '{self.genie_space_id}' "
            f"with table '{self.table_name}' (dims={self._embedding_dims})"
        )

    @property
    def database(self) -> DatabaseModel:
        """The database used for storing cache entries."""
        return self.parameters.database

    @property
    def warehouse(self) -> WarehouseModel:
        """The warehouse used for executing cached SQL queries."""
        return self.parameters.warehouse

    @property
    def time_to_live(self) -> timedelta | None:
        """Time-to-live for cache entries. None means never expires."""
        ttl = self.parameters.time_to_live_seconds
        if ttl is None or ttl < 0:
            return None
        return timedelta(seconds=ttl)

    @property
    def similarity_threshold(self) -> float:
        """Minimum similarity for cache hit (using L2 distance converted to similarity)."""
        return self.parameters.similarity_threshold

    @property
    def embedding_dims(self) -> int:
        """Dimension size for embeddings (auto-detected if not configured).

        Raises:
            RuntimeError: if accessed before ``_setup`` has determined the dimensions.
        """
        if self._embedding_dims is None:
            raise RuntimeError(
                "Embedding dimensions not yet initialized. Call _setup() first."
            )
        return self._embedding_dims

    @property
    def table_name(self) -> str:
        """Name of the cache table."""
        return self.parameters.table_name

    def _active_ttl_seconds(self) -> Any:
        """Return the configured TTL in seconds, or None when TTL is disabled.

        A TTL of ``None`` or any negative value disables expiry.
        """
        ttl_seconds = self.parameters.time_to_live_seconds
        if ttl_seconds is None or ttl_seconds < 0:
            return None
        return ttl_seconds

    @staticmethod
    def _to_vector_literal(embedding: list[float]) -> str:
        """Format an embedding as pg_vector's text input form, e.g. ``[0.1,0.2]``."""
        return f"[{','.join(str(x) for x in embedding)}]"

    def _create_table_if_not_exists(self) -> None:
        """Create the cache table with pg_vector extension if it doesn't exist.

        If the table exists but its ``question_embedding`` column has a different
        embedding dimension, the table is dropped and recreated with the new
        dimension size. The drop removes cached entries for every Genie space
        stored in the table, not just this one.
        """
        create_extension_sql: str = "CREATE EXTENSION IF NOT EXISTS vector"

        # Read the current embedding dimension. to_regclass() yields NULL for a
        # missing table (whereas a ::regclass cast raises), so "table does not
        # exist" simply produces no row. This avoids using an exception for
        # control flow, which could leave a non-autocommit transaction aborted.
        check_dims_sql: str = """
            SELECT atttypmod
            FROM pg_attribute
            WHERE attrelid = to_regclass(%s)
              AND attname = 'question_embedding'
        """

        create_table_sql: str = f"""
            CREATE TABLE IF NOT EXISTS {self.table_name} (
                id SERIAL PRIMARY KEY,
                genie_space_id TEXT NOT NULL,
                question TEXT NOT NULL,
                question_embedding vector({self.embedding_dims}),
                sql_query TEXT NOT NULL,
                description TEXT,
                conversation_id TEXT,
                created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP
            )
        """
        # Index for efficient similarity search partitioned by genie_space_id
        # Use L2 (Euclidean) distance - optimal for Databricks GTE embeddings
        create_embedding_index_sql: str = f"""
            CREATE INDEX IF NOT EXISTS {self.table_name}_embedding_idx
            ON {self.table_name}
            USING ivfflat (question_embedding vector_l2_ops)
            WITH (lists = 100)
        """
        # Index for filtering by genie_space_id
        create_space_index_sql: str = f"""
            CREATE INDEX IF NOT EXISTS {self.table_name}_space_idx
            ON {self.table_name} (genie_space_id)
        """

        with self._pool.connection() as conn:
            with conn.cursor() as cur:
                cur.execute(create_extension_sql)

                # Verify embedding dimensions if the table already exists
                cur.execute(check_dims_sql, (self.table_name,))
                row: DbRow | None = cur.fetchone()
                if row is not None:
                    # atttypmod for vector type contains the dimension
                    current_dims = row.get("atttypmod", 0)
                    if current_dims != self.embedding_dims:
                        logger.warning(
                            f"[{self.name}] Embedding dimension mismatch: "
                            f"table has {current_dims}, expected {self.embedding_dims}. "
                            f"Dropping and recreating table '{self.table_name}'."
                        )
                        cur.execute(f"DROP TABLE {self.table_name}")

                cur.execute(create_table_sql)
                cur.execute(create_space_index_sql)
                cur.execute(create_embedding_index_sql)

    def _embed_question(self, question: str) -> list[float]:
        """Generate embedding for a question."""
        embeddings: list[list[float]] = self._embeddings.embed_documents([question])
        return embeddings[0]

    @mlflow.trace(name="semantic_search")
    def _find_similar(
        self, question: str, embedding: list[float]
    ) -> tuple[SQLCacheEntry, float] | None:
        """
        Find a semantically similar cached entry for this Genie space.

        If the nearest entry meets the similarity threshold but has outlived the
        TTL, it is deleted and a miss is returned so the caller refreshes it.

        Args:
            question: The question to search for
            embedding: The embedding vector of the question

        Returns:
            Tuple of (SQLCacheEntry, similarity_score) if found, None otherwise
        """
        # Use L2 (Euclidean) distance - optimal for Databricks GTE embeddings
        # pg_vector's <-> operator returns L2 distance (0 = identical)
        # Convert to similarity: 1 / (1 + distance) gives range (0, 1]
        #
        # Refresh-on-hit strategy:
        # 1. Search without TTL filter to find best semantic match
        # 2. If match is within TTL (or TTL disabled) → cache hit
        # 3. If match is expired → delete it, return miss (triggers refresh)
        ttl_seconds = self.parameters.time_to_live_seconds

        # When TTL is disabled, all entries are always valid. ttl_seconds is a
        # numeric config value, so inlining it into the SQL text is safe here.
        if self._active_ttl_seconds() is None:
            is_valid_expr = "TRUE"
        else:
            is_valid_expr = f"created_at > NOW() - INTERVAL '{ttl_seconds} seconds'"

        search_sql: str = f"""
            SELECT
                id,
                question,
                sql_query,
                description,
                conversation_id,
                created_at,
                1.0 / (1.0 + (question_embedding <-> %s::vector)) as similarity,
                {is_valid_expr} as is_valid
            FROM {self.table_name}
            WHERE genie_space_id = %s
            ORDER BY question_embedding <-> %s::vector
            LIMIT 1
        """

        embedding_str: str = self._to_vector_literal(embedding)

        with self._pool.connection() as conn:
            with conn.cursor() as cur:
                cur.execute(
                    search_sql,
                    (embedding_str, self.genie_space_id, embedding_str),
                )
                row: DbRow | None = cur.fetchone()

                if row is None:
                    logger.info(
                        f"[{self.name}] MISS (no entries): "
                        f"question='{question[:50]}...' space='{self.genie_space_id}'"
                    )
                    return None

                # Extract values from dict row
                entry_id = row.get("id")
                cached_question = row.get("question", "")
                sql_query = row["sql_query"]
                description = row.get("description", "")
                conversation_id = row.get("conversation_id", "")
                created_at = row["created_at"]
                similarity = row["similarity"]
                is_valid = row.get("is_valid", False)

                # Log best match info (L2 distance can be computed from similarity: d = 1/s - 1)
                l2_distance = (
                    (1.0 / similarity) - 1.0 if similarity > 0 else float("inf")
                )
                logger.info(
                    f"[{self.name}] Best match: l2_distance={l2_distance:.4f}, similarity={similarity:.4f}, "
                    f"is_valid={is_valid}, question='{cached_question[:50]}...'"
                )

                # Check similarity threshold
                if similarity < self.similarity_threshold:
                    logger.info(
                        f"[{self.name}] MISS (below threshold): similarity={similarity:.4f} < threshold={self.similarity_threshold} "
                        f"(cached_question='{cached_question[:50]}...')"
                    )
                    return None

                # Check TTL - refresh on hit strategy
                if not is_valid:
                    # Entry is expired - delete it and return miss to trigger refresh
                    delete_sql = f"DELETE FROM {self.table_name} WHERE id = %s"
                    cur.execute(delete_sql, (entry_id,))
                    logger.info(
                        f"[{self.name}] MISS (expired, deleted for refresh): similarity={similarity:.4f}, "
                        f"ttl={ttl_seconds}s, question='{cached_question[:50]}...'"
                    )
                    return None

                logger.info(
                    f"[{self.name}] HIT: similarity={similarity:.4f} >= threshold={self.similarity_threshold} "
                    f"(cached_question='{cached_question[:50]}...')"
                )

                entry = SQLCacheEntry(
                    query=sql_query,
                    description=description,
                    conversation_id=conversation_id,
                    created_at=created_at,
                )
                return entry, similarity

    def _store_entry(
        self, question: str, embedding: list[float], response: GenieResponse
    ) -> None:
        """Store a new cache entry for this Genie space.

        Args:
            question: The original natural-language question.
            embedding: Embedding vector of ``question``.
            response: Genie response whose ``query`` (expected non-empty),
                ``description`` and ``conversation_id`` are persisted.
        """
        insert_sql: str = f"""
            INSERT INTO {self.table_name}
            (genie_space_id, question, question_embedding, sql_query, description, conversation_id)
            VALUES (%s, %s, %s::vector, %s, %s, %s)
        """
        embedding_str: str = self._to_vector_literal(embedding)

        with self._pool.connection() as conn:
            with conn.cursor() as cur:
                cur.execute(
                    insert_sql,
                    (
                        self.genie_space_id,
                        question,
                        embedding_str,
                        response.query,
                        response.description,
                        response.conversation_id,
                    ),
                )
                logger.info(
                    f"[{self.name}] Stored cache entry: question='{question[:50]}...' "
                    f"sql='{response.query[:50]}...' (space={self.genie_space_id}, table={self.table_name})"
                )

    @mlflow.trace(name="execute_cached_sql_semantic")
    def _execute_sql(self, sql: str) -> pd.DataFrame | str:
        """Execute SQL using the warehouse and return results.

        Returns:
            A DataFrame of the first result chunk (empty if no rows), or an
            error message string if the statement did not reach SUCCEEDED
            within the 30s wait timeout.
        """
        client: WorkspaceClient = self.warehouse.workspace_client
        warehouse_id: str = self.warehouse.warehouse_id

        statement_response: StatementResponse = (
            client.statement_execution.execute_statement(
                warehouse_id=warehouse_id,
                statement=sql,
                wait_timeout="30s",
            )
        )

        # NOTE: a statement still PENDING/RUNNING after the wait timeout is also
        # reported through this branch.
        status = statement_response.status
        state = status.state if status is not None else None
        if state != StatementState.SUCCEEDED:
            error = status.error if status is not None else None
            error_msg: str = (
                f"SQL execution failed: {error.message}"
                if error
                else f"SQL execution failed with state: {state}"
            )
            logger.error(f"[{self.name}] {error_msg}")
            return error_msg

        if statement_response.result and statement_response.result.data_array:
            columns: list[str] = []
            manifest = statement_response.manifest
            if manifest and manifest.schema and manifest.schema.columns:
                columns = [col.name for col in manifest.schema.columns]
            else:
                # Fallback: some responses may carry the schema on the result;
                # guard against it being absent or None.
                result_schema = getattr(statement_response.result, "schema", None)
                if result_schema is not None and getattr(
                    result_schema, "columns", None
                ):
                    columns = [col.name for col in result_schema.columns]

            data: list[list[Any]] = statement_response.result.data_array
            if columns:
                return pd.DataFrame(data, columns=columns)
            return pd.DataFrame(data)

        return pd.DataFrame()

    def ask_question(
        self, question: str, conversation_id: str | None = None
    ) -> GenieResponse:
        """
        Ask a question, using semantic cache if a similar query exists.

        On cache hit, re-executes the cached SQL to get fresh data.
        Implements GenieServiceBase for seamless chaining.
        """
        result: CacheResult = self.ask_question_with_cache_info(
            question, conversation_id
        )
        return result.response

    @mlflow.trace(name="genie_semantic_cache_lookup", span_type=SpanType.TOOL)
    def ask_question_with_cache_info(
        self,
        question: str,
        conversation_id: str | None = None,
    ) -> CacheResult:
        """
        Ask a question with detailed cache hit information.

        On cache hit, the cached SQL is re-executed to return fresh data.
        On miss, the wrapped ``impl`` is asked and its response is cached
        when it contains a SQL query.

        Args:
            question: The question to ask
            conversation_id: Optional conversation ID

        Returns:
            CacheResult with fresh response and cache metadata
        """
        # Ensure initialization (lazy init if initialize() wasn't called)
        self._setup()

        # Generate embedding for the question
        embedding: list[float] = self._embed_question(question)

        # Check cache
        cache_result: tuple[SQLCacheEntry, float] | None = self._find_similar(
            question, embedding
        )

        if cache_result is not None:
            cached, similarity = cache_result
            logger.debug(
                f"[{self.name}] Semantic cache hit (similarity={similarity:.3f}): {question[:50]}..."
            )

            # Re-execute the cached SQL to get fresh data
            result: pd.DataFrame | str = self._execute_sql(cached.query)

            response: GenieResponse = GenieResponse(
                result=result,
                query=cached.query,
                description=cached.description,
                conversation_id=cached.conversation_id,
            )

            return CacheResult(response=response, cache_hit=True, served_by=self.name)

        # Cache miss - delegate to wrapped service
        logger.debug(f"[{self.name}] Miss: {question[:50]}...")

        response = self.impl.ask_question(question, conversation_id)

        # Store in cache only if we got a SQL query
        if response.query:
            logger.info(
                f"[{self.name}] Storing new cache entry for question: '{question[:50]}...' "
                f"(space={self.genie_space_id})"
            )
            self._store_entry(question, embedding, response)
        else:
            logger.warning(
                f"[{self.name}] Not caching: response has no SQL query "
                f"(question='{question[:50]}...')"
            )

        return CacheResult(response=response, cache_hit=False, served_by=None)

    def invalidate_expired(self) -> int:
        """Remove expired entries from the cache for this Genie space.

        Returns:
            Number of rows deleted; 0 if TTL is disabled (entries never expire).
        """
        self._setup()
        ttl_seconds = self._active_ttl_seconds()

        # If TTL is disabled, nothing can expire
        if ttl_seconds is None:
            logger.debug(
                f"[{self.name}] TTL disabled, no entries to expire for space {self.genie_space_id}"
            )
            return 0

        # Build the interval from a bound parameter. A placeholder inside a SQL
        # string literal (INTERVAL '%s seconds') is not a real parameter under
        # psycopg 3's server-side binding.
        delete_sql: str = f"""
            DELETE FROM {self.table_name}
            WHERE genie_space_id = %s
              AND created_at < NOW() - make_interval(secs => %s)
        """

        with self._pool.connection() as conn:
            with conn.cursor() as cur:
                cur.execute(delete_sql, (self.genie_space_id, ttl_seconds))
                deleted: int = cur.rowcount
                logger.debug(
                    f"[{self.name}] Deleted {deleted} expired entries for space {self.genie_space_id}"
                )
                return deleted

    def clear(self) -> int:
        """Clear all entries from the cache for this Genie space.

        Returns:
            Number of rows deleted.
        """
        self._setup()
        delete_sql: str = f"DELETE FROM {self.table_name} WHERE genie_space_id = %s"

        with self._pool.connection() as conn:
            with conn.cursor() as cur:
                cur.execute(delete_sql, (self.genie_space_id,))
                deleted: int = cur.rowcount
                logger.debug(
                    f"[{self.name}] Cleared {deleted} entries for space {self.genie_space_id}"
                )
                return deleted

    @property
    def size(self) -> int:
        """Current number of entries in the cache for this Genie space."""
        self._setup()
        count_sql: str = (
            f"SELECT COUNT(*) as count FROM {self.table_name} WHERE genie_space_id = %s"
        )

        with self._pool.connection() as conn:
            with conn.cursor() as cur:
                cur.execute(count_sql, (self.genie_space_id,))
                row: DbRow | None = cur.fetchone()
                return row.get("count", 0) if row else 0

    def stats(self) -> dict[str, int | float | None]:
        """Return cache statistics for this Genie space.

        Keys: ``size``, ``ttl_seconds`` (None when TTL is disabled),
        ``similarity_threshold``, ``expired_entries``, ``valid_entries``.
        """
        self._setup()
        ttl_seconds = self._active_ttl_seconds()
        ttl = self.time_to_live

        # If TTL is disabled, all entries are valid
        if ttl_seconds is None:
            count_sql: str = f"""
                SELECT COUNT(*) as total FROM {self.table_name}
                WHERE genie_space_id = %s
            """
            with self._pool.connection() as conn:
                with conn.cursor() as cur:
                    cur.execute(count_sql, (self.genie_space_id,))
                    row: DbRow | None = cur.fetchone()
                    total = row.get("total", 0) if row else 0
                    return {
                        "size": total,
                        "ttl_seconds": None,
                        "similarity_threshold": self.similarity_threshold,
                        "expired_entries": 0,
                        "valid_entries": total,
                    }

        # Bound-parameter intervals (see invalidate_expired for why not '%s seconds').
        stats_sql: str = f"""
            SELECT
                COUNT(*) as total,
                COUNT(*) FILTER (WHERE created_at > NOW() - make_interval(secs => %s)) as valid,
                COUNT(*) FILTER (WHERE created_at <= NOW() - make_interval(secs => %s)) as expired
            FROM {self.table_name}
            WHERE genie_space_id = %s
        """

        with self._pool.connection() as conn:
            with conn.cursor() as cur:
                cur.execute(stats_sql, (ttl_seconds, ttl_seconds, self.genie_space_id))
                row: DbRow | None = cur.fetchone()
                return {
                    "size": row.get("total", 0) if row else 0,
                    # A TTL of 0 is valid (immediate expiry); timedelta(0) is falsy,
                    # so compare against None explicitly.
                    "ttl_seconds": ttl.total_seconds() if ttl is not None else None,
                    "similarity_threshold": self.similarity_threshold,
                    "expired_entries": row.get("expired", 0) if row else 0,
                    "valid_entries": row.get("valid", 0) if row else 0,
                }
|
dao_ai/tools/__init__.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
from dao_ai.genie.cache import LRUCacheService, SemanticCacheService
|
|
1
2
|
from dao_ai.hooks.core import create_hooks
|
|
2
3
|
from dao_ai.tools.agent import create_agent_endpoint_tool
|
|
3
4
|
from dao_ai.tools.core import (
|
|
@@ -35,7 +36,9 @@ __all__ = [
|
|
|
35
36
|
"current_time_tool",
|
|
36
37
|
"format_time_tool",
|
|
37
38
|
"is_business_hours_tool",
|
|
39
|
+
"LRUCacheService",
|
|
38
40
|
"search_tool",
|
|
41
|
+
"SemanticCacheService",
|
|
39
42
|
"time_difference_tool",
|
|
40
43
|
"time_in_timezone_tool",
|
|
41
44
|
"time_until_tool",
|