dao-ai 0.0.35__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dao_ai/__init__.py +29 -0
- dao_ai/cli.py +195 -30
- dao_ai/config.py +797 -242
- dao_ai/genie/__init__.py +38 -0
- dao_ai/genie/cache/__init__.py +43 -0
- dao_ai/genie/cache/base.py +72 -0
- dao_ai/genie/cache/core.py +75 -0
- dao_ai/genie/cache/lru.py +329 -0
- dao_ai/genie/cache/semantic.py +919 -0
- dao_ai/genie/core.py +35 -0
- dao_ai/graph.py +27 -253
- dao_ai/hooks/__init__.py +9 -6
- dao_ai/hooks/core.py +22 -190
- dao_ai/memory/__init__.py +10 -0
- dao_ai/memory/core.py +23 -5
- dao_ai/memory/databricks.py +389 -0
- dao_ai/memory/postgres.py +2 -2
- dao_ai/messages.py +6 -4
- dao_ai/middleware/__init__.py +125 -0
- dao_ai/middleware/assertions.py +778 -0
- dao_ai/middleware/base.py +50 -0
- dao_ai/middleware/core.py +61 -0
- dao_ai/middleware/guardrails.py +415 -0
- dao_ai/middleware/human_in_the_loop.py +228 -0
- dao_ai/middleware/message_validation.py +554 -0
- dao_ai/middleware/summarization.py +192 -0
- dao_ai/models.py +1177 -108
- dao_ai/nodes.py +118 -161
- dao_ai/optimization.py +664 -0
- dao_ai/orchestration/__init__.py +52 -0
- dao_ai/orchestration/core.py +287 -0
- dao_ai/orchestration/supervisor.py +264 -0
- dao_ai/orchestration/swarm.py +226 -0
- dao_ai/prompts.py +126 -29
- dao_ai/providers/databricks.py +126 -381
- dao_ai/state.py +139 -21
- dao_ai/tools/__init__.py +11 -5
- dao_ai/tools/core.py +57 -4
- dao_ai/tools/email.py +280 -0
- dao_ai/tools/genie.py +108 -35
- dao_ai/tools/mcp.py +4 -3
- dao_ai/tools/memory.py +50 -0
- dao_ai/tools/python.py +4 -12
- dao_ai/tools/search.py +14 -0
- dao_ai/tools/slack.py +1 -1
- dao_ai/tools/unity_catalog.py +8 -6
- dao_ai/tools/vector_search.py +16 -9
- dao_ai/utils.py +72 -8
- dao_ai-0.1.0.dist-info/METADATA +1878 -0
- dao_ai-0.1.0.dist-info/RECORD +62 -0
- dao_ai/chat_models.py +0 -204
- dao_ai/guardrails.py +0 -112
- dao_ai/tools/human_in_the_loop.py +0 -100
- dao_ai-0.0.35.dist-info/METADATA +0 -1169
- dao_ai-0.0.35.dist-info/RECORD +0 -41
- {dao_ai-0.0.35.dist-info → dao_ai-0.1.0.dist-info}/WHEEL +0 -0
- {dao_ai-0.0.35.dist-info → dao_ai-0.1.0.dist-info}/entry_points.txt +0 -0
- {dao_ai-0.0.35.dist-info → dao_ai-0.1.0.dist-info}/licenses/LICENSE +0 -0
dao_ai/genie/cache/semantic.py (new file)
@@ -0,0 +1,919 @@
"""
Semantic cache implementation for Genie SQL queries using PostgreSQL pg_vector.

This module provides a semantic cache that uses embeddings and similarity search
to find cached queries that match the intent of new questions. Cache entries are
partitioned by genie_space_id to ensure proper isolation between Genie spaces.

The cache supports conversation-aware embedding using a rolling window approach
to capture context from recent conversation turns, improving accuracy for
multi-turn conversations with anaphoric references.
"""

from datetime import timedelta
from typing import Any

import mlflow
import pandas as pd
from databricks.sdk import WorkspaceClient
from databricks.sdk.service.dashboards import (
    GenieListConversationMessagesResponse,
    GenieMessage,
)
from databricks.sdk.service.sql import StatementResponse, StatementState
from databricks_ai_bridge.genie import GenieResponse
from loguru import logger

from dao_ai.config import (
    DatabaseModel,
    GenieSemanticCacheParametersModel,
    LLMModel,
    WarehouseModel,
)
from dao_ai.genie.cache.base import (
    CacheResult,
    GenieServiceBase,
    SQLCacheEntry,
)

# Type alias for database row (dict due to row_factory=dict_row)
DbRow = dict[str, Any]

def get_conversation_history(
    workspace_client: WorkspaceClient,
    space_id: str,
    conversation_id: str,
    max_messages: int = 10,
) -> list[GenieMessage]:
    """
    Retrieve conversation history from Genie.

    Args:
        workspace_client: The Databricks workspace client
        space_id: The Genie space ID
        conversation_id: The conversation ID to retrieve
        max_messages: Maximum number of messages to retrieve

    Returns:
        List of GenieMessage objects representing the conversation history
    """
    try:
        # Use the Genie API to retrieve conversation messages
        response: GenieListConversationMessagesResponse = (
            workspace_client.genie.list_conversation_messages(
                space_id=space_id,
                conversation_id=conversation_id,
            )
        )

        # Return the most recent messages up to max_messages
        if response.messages is not None:
            all_messages: list[GenieMessage] = list(response.messages)
            return (
                all_messages[-max_messages:]
                if len(all_messages) > max_messages
                else all_messages
            )
        return []
    except Exception as e:
        logger.warning(
            f"Failed to retrieve conversation history for conversation_id={conversation_id}: {e}"
        )
        return []

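# Illustrative call (the IDs below are placeholders; assumes ambient Databricks
# authentication for WorkspaceClient):
#
#     w = WorkspaceClient()
#     history = get_conversation_history(
#         workspace_client=w,
#         space_id="my-space-id",
#         conversation_id="my-conversation-id",
#         max_messages=6,
#     )
#     # -> up to 6 of the most recent GenieMessage objects, in API order;
#     #    [] if the lookup fails for any reason.
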
def build_context_string(
    question: str,
    conversation_messages: list[GenieMessage],
    window_size: int,
    max_tokens: int = 2000,
) -> str:
    """
    Build a context-aware question string using a rolling window.

    This function creates a concatenated string that includes recent conversation
    turns to provide context for semantic similarity matching.

    Args:
        question: The current question
        conversation_messages: List of previous conversation messages
        window_size: Number of previous turns to include
        max_tokens: Maximum estimated tokens (rough approximation: 4 chars = 1 token)

    Returns:
        Context-aware question string formatted for embedding
    """
    if window_size <= 0 or not conversation_messages:
        return question

    # Take the last window_size messages (most recent)
    recent_messages = (
        conversation_messages[-window_size:]
        if len(conversation_messages) > window_size
        else conversation_messages
    )

    # Build context parts
    context_parts: list[str] = []

    for msg in recent_messages:
        # Only include messages with content from the history
        if msg.content:
            # Limit message length to prevent token overflow
            content: str = msg.content
            if len(content) > 500:  # Truncate very long messages
                content = content[:500] + "..."
            context_parts.append(f"Previous: {content}")

    # Add current question
    context_parts.append(f"Current: {question}")

    # Join with newlines
    context_string = "\n".join(context_parts)

    # Rough token limit check (4 chars ≈ 1 token)
    estimated_tokens = len(context_string) / 4
    if estimated_tokens > max_tokens:
        # Truncate to fit max_tokens; capture the pre-truncation length so the
        # log message reports the real before/after sizes
        original_length = len(context_string)
        target_chars = max_tokens * 4
        context_string = context_string[:target_chars] + "..."
        logger.debug(
            f"Truncated context string from {original_length} to {target_chars} chars "
            f"(estimated {max_tokens} tokens)"
        )

    return context_string

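# Worked example for build_context_string (a pure function; GenieMessage is
# stubbed with SimpleNamespace here since only `.content` is read):
#
#     from types import SimpleNamespace
#     msgs = [
#         SimpleNamespace(content="Show revenue by region"),
#         SimpleNamespace(content="Filter to 2024"),
#     ]
#     build_context_string("What about Q4?", msgs, window_size=2)
#     # -> "Previous: Show revenue by region\nPrevious: Filter to 2024\nCurrent: What about Q4?"
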
class SemanticCacheService(GenieServiceBase):
    """
    Semantic caching decorator that uses PostgreSQL pg_vector for similarity lookup.

    This service caches the SQL query generated by Genie along with an embedding
    of the original question. On subsequent queries, it performs a semantic similarity
    search to find cached queries that match the intent of the new question.

    Cache entries are partitioned by genie_space_id to ensure queries from different
    Genie spaces don't return incorrect cache hits.

    On cache hit, it re-executes the cached SQL using the provided warehouse
    to return fresh data while avoiding the Genie NL-to-SQL translation cost.

    Example:
        from dao_ai.config import GenieSemanticCacheParametersModel, DatabaseModel
        from dao_ai.genie.cache import SemanticCacheService

        cache_params = GenieSemanticCacheParametersModel(
            database=database_model,
            warehouse=warehouse_model,
            embedding_model="databricks-gte-large-en",
            time_to_live_seconds=86400,  # 24 hours
            similarity_threshold=0.85
        )
        genie = SemanticCacheService(
            impl=GenieService(Genie(space_id="my-space")),
            parameters=cache_params
        )

    Thread-safe: Uses connection pooling from psycopg_pool.
    """

    impl: GenieServiceBase
    parameters: GenieSemanticCacheParametersModel
    workspace_client: WorkspaceClient | None
    name: str
    _embeddings: Any  # DatabricksEmbeddings
    _pool: Any  # ConnectionPool
    _embedding_dims: int | None
    _setup_complete: bool

    def __init__(
        self,
        impl: GenieServiceBase,
        parameters: GenieSemanticCacheParametersModel,
        workspace_client: WorkspaceClient | None = None,
        name: str | None = None,
    ) -> None:
        """
        Initialize the semantic cache service.

        Args:
            impl: The underlying GenieServiceBase to delegate to on cache miss.
                The space_id will be obtained from impl.space_id.
            parameters: Cache configuration including database, warehouse, embedding model
            workspace_client: Optional WorkspaceClient for retrieving conversation history.
                If None, conversation context will not be used.
            name: Name for this cache layer (for logging). Defaults to class name.
        """
        self.impl = impl
        self.parameters = parameters
        self.workspace_client = workspace_client
        self.name = name if name is not None else self.__class__.__name__
        self._embeddings = None
        self._pool = None
        self._embedding_dims = None
        self._setup_complete = False

    def initialize(self) -> "SemanticCacheService":
        """
        Eagerly initialize the cache service.

        Call this during tool creation to:
        - Validate configuration early (fail fast)
        - Create the database table before any requests
        - Avoid first-request latency from lazy initialization

        Returns:
            self for method chaining
        """
        self._setup()
        return self

    def _setup(self) -> None:
        """Initialize embeddings and database connection pool lazily."""
        if self._setup_complete:
            return

        from dao_ai.memory.postgres import PostgresPoolManager

        # Initialize embeddings
        # Convert embedding_model to LLMModel if it's a string
        embedding_model: LLMModel = (
            LLMModel(name=self.parameters.embedding_model)
            if isinstance(self.parameters.embedding_model, str)
            else self.parameters.embedding_model
        )
        self._embeddings = embedding_model.as_embeddings_model()

        # Auto-detect embedding dimensions if not provided
        if self.parameters.embedding_dims is None:
            sample_embedding: list[float] = self._embeddings.embed_query("test")
            self._embedding_dims = len(sample_embedding)
            logger.debug(
                f"[{self.name}] Auto-detected embedding dimensions: {self._embedding_dims}"
            )
        else:
            self._embedding_dims = self.parameters.embedding_dims

        # Get connection pool
        self._pool = PostgresPoolManager.get_pool(self.parameters.database)

        # Ensure table exists
        self._create_table_if_not_exists()

        self._setup_complete = True
        logger.debug(
            f"[{self.name}] Semantic cache initialized for space '{self.space_id}' "
            f"with table '{self.table_name}' (dims={self._embedding_dims})"
        )

    @property
    def database(self) -> DatabaseModel:
        """The database used for storing cache entries."""
        return self.parameters.database

    @property
    def warehouse(self) -> WarehouseModel:
        """The warehouse used for executing cached SQL queries."""
        return self.parameters.warehouse

    @property
    def time_to_live(self) -> timedelta | None:
        """Time-to-live for cache entries. None means never expires."""
        ttl = self.parameters.time_to_live_seconds
        if ttl is None or ttl < 0:
            return None
        return timedelta(seconds=ttl)

    @property
    def similarity_threshold(self) -> float:
        """Minimum similarity for cache hit (using L2 distance converted to similarity)."""
        return self.parameters.similarity_threshold

    @property
    def embedding_dims(self) -> int:
        """Dimension size for embeddings (auto-detected if not configured)."""
        if self._embedding_dims is None:
            raise RuntimeError(
                "Embedding dimensions not yet initialized. Call _setup() first."
            )
        return self._embedding_dims

    @property
    def table_name(self) -> str:
        """Name of the cache table."""
        return self.parameters.table_name

    def _create_table_if_not_exists(self) -> None:
        """Create the cache table with pg_vector extension if it doesn't exist.

        If the table exists but has a different embedding dimension, it will be
        dropped and recreated with the new dimension size.
        """
        create_extension_sql: str = "CREATE EXTENSION IF NOT EXISTS vector"

        # Check if the table exists and get its current embedding dimensions.
        # to_regclass() returns NULL for a missing table instead of raising,
        # which would otherwise abort the surrounding transaction and make the
        # CREATE TABLE below fail.
        check_dims_sql: str = """
            SELECT atttypmod
            FROM pg_attribute
            WHERE attrelid = to_regclass(%s)
              AND attname = 'question_embedding'
        """

        create_table_sql: str = f"""
            CREATE TABLE IF NOT EXISTS {self.table_name} (
                id SERIAL PRIMARY KEY,
                genie_space_id TEXT NOT NULL,
                question TEXT NOT NULL,
                conversation_context TEXT,
                context_string TEXT,
                question_embedding vector({self.embedding_dims}),
                context_embedding vector({self.embedding_dims}),
                sql_query TEXT NOT NULL,
                description TEXT,
                conversation_id TEXT,
                created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP
            )
        """
        # Indexes for efficient similarity search partitioned by genie_space_id.
        # Use L2 (Euclidean) distance - optimal for Databricks GTE embeddings.
        create_question_embedding_index_sql: str = f"""
            CREATE INDEX IF NOT EXISTS {self.table_name}_question_embedding_idx
            ON {self.table_name}
            USING ivfflat (question_embedding vector_l2_ops)
            WITH (lists = 100)
        """
        create_context_embedding_index_sql: str = f"""
            CREATE INDEX IF NOT EXISTS {self.table_name}_context_embedding_idx
            ON {self.table_name}
            USING ivfflat (context_embedding vector_l2_ops)
            WITH (lists = 100)
        """
        # Index for filtering by genie_space_id
        create_space_index_sql: str = f"""
            CREATE INDEX IF NOT EXISTS {self.table_name}_space_idx
            ON {self.table_name} (genie_space_id)
        """

        with self._pool.connection() as conn:
            with conn.cursor() as cur:
                cur.execute(create_extension_sql)

                # Check if the table exists and verify embedding dimensions
                cur.execute(check_dims_sql, (self.table_name,))
                row: DbRow | None = cur.fetchone()
                if row is not None:
                    # atttypmod for the vector type contains the dimension
                    current_dims = row.get("atttypmod", 0)
                    if current_dims != self.embedding_dims:
                        logger.warning(
                            f"[{self.name}] Embedding dimension mismatch: "
                            f"table has {current_dims}, expected {self.embedding_dims}. "
                            f"Dropping and recreating table '{self.table_name}'."
                        )
                        cur.execute(f"DROP TABLE {self.table_name}")

                cur.execute(create_table_sql)
                cur.execute(create_space_index_sql)
                cur.execute(create_question_embedding_index_sql)
                cur.execute(create_context_embedding_index_sql)

    def _embed_question(
        self, question: str, conversation_id: str | None = None
    ) -> tuple[list[float], list[float], str]:
        """
        Generate dual embeddings: one for the question, one for the conversation context.

        This enables separate matching of question similarity vs context similarity,
        improving precision by ensuring both the question AND the conversation context
        are semantically similar before returning a cached result.

        Args:
            question: The question to embed
            conversation_id: Optional conversation ID for retrieving context

        Returns:
            Tuple of (question_embedding, context_embedding, conversation_context_string)
            - question_embedding: Vector embedding of just the question
            - context_embedding: Vector embedding of the conversation context (or zero vector if no context)
            - conversation_context_string: The conversation context string (empty if no context)
        """
        conversation_context = ""

        # If conversation context is enabled and available
        if (
            self.workspace_client is not None
            and conversation_id is not None
            and self.parameters.context_window_size > 0
        ):
            try:
                # Retrieve conversation history
                conversation_messages = get_conversation_history(
                    workspace_client=self.workspace_client,
                    space_id=self.space_id,
                    conversation_id=conversation_id,
                    max_messages=self.parameters.context_window_size
                    * 2,  # Get extra for safety
                )

                # Build context string (just the "Previous:" messages, not the current question)
                if conversation_messages:
                    recent_messages = (
                        conversation_messages[-self.parameters.context_window_size :]
                        if len(conversation_messages)
                        > self.parameters.context_window_size
                        else conversation_messages
                    )

                    context_parts: list[str] = []
                    for msg in recent_messages:
                        if msg.content:
                            content: str = msg.content
                            if len(content) > 500:
                                content = content[:500] + "..."
                            context_parts.append(f"Previous: {content}")

                    conversation_context = "\n".join(context_parts)

                    # Truncate if too long
                    estimated_tokens = len(conversation_context) / 4
                    if estimated_tokens > self.parameters.max_context_tokens:
                        target_chars = self.parameters.max_context_tokens * 4
                        conversation_context = (
                            conversation_context[:target_chars] + "..."
                        )

                    logger.debug(
                        f"[{self.name}] Using conversation context: {len(conversation_messages)} messages "
                        f"(window_size={self.parameters.context_window_size})"
                    )
            except Exception as e:
                logger.warning(
                    f"[{self.name}] Failed to build conversation context, using question only: {e}"
                )
                conversation_context = ""

        # Generate dual embeddings
        if conversation_context:
            # Embed both question and context
            embeddings: list[list[float]] = self._embeddings.embed_documents(
                [question, conversation_context]
            )
            question_embedding = embeddings[0]
            context_embedding = embeddings[1]
        else:
            # Only embed question, use zero vector for context
            embeddings = self._embeddings.embed_documents([question])
            question_embedding = embeddings[0]
            context_embedding = [0.0] * len(question_embedding)  # Zero vector

        return question_embedding, context_embedding, conversation_context

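    # Sketch of the dual-embedding contract (illustrative; vector length depends
    # on the configured embedding model):
    #
    #     q_emb, ctx_emb, ctx = cache._embed_question("total sales?", None)
    #     # -> q_emb: embedding of the question text
    #     #    ctx_emb: all-zero vector of the same length (no conversation context)
    #     #    ctx: ""
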
    @mlflow.trace(name="semantic_search")
    def _find_similar(
        self,
        question: str,
        conversation_context: str,
        question_embedding: list[float],
        context_embedding: list[float],
    ) -> tuple[SQLCacheEntry, float] | None:
        """
        Find a semantically similar cached entry using dual embedding matching.

        This method matches BOTH the question AND the conversation context separately,
        ensuring high precision by requiring both to be semantically similar.

        Args:
            question: The original question (for logging)
            conversation_context: The conversation context string
            question_embedding: The embedding vector of just the question
            context_embedding: The embedding vector of the conversation context

        Returns:
            Tuple of (SQLCacheEntry, combined_similarity_score) if found, None otherwise
        """
        # Use L2 (Euclidean) distance - optimal for Databricks GTE embeddings.
        # pg_vector's <-> operator returns L2 distance (0 = identical).
        # Convert to similarity: 1 / (1 + distance) gives range (0, 1].
        #
        # Dual embedding strategy:
        # 1. Calculate separate similarities for question and context
        # 2. BOTH must exceed their respective thresholds
        # 3. Combined score is a weighted average
        # 4. Refresh-on-hit: check TTL after the similarity check
        ttl_seconds = self.parameters.time_to_live_seconds
        ttl_disabled = ttl_seconds is None or ttl_seconds < 0

        # When TTL is disabled, all entries are always valid
        if ttl_disabled:
            is_valid_expr = "TRUE"
        else:
            is_valid_expr = f"created_at > NOW() - INTERVAL '{ttl_seconds} seconds'"

        # Weighted combined similarity for ordering
        question_weight: float = self.parameters.question_weight
        context_weight: float = self.parameters.context_weight

        search_sql: str = f"""
            SELECT
                id,
                question,
                conversation_context,
                sql_query,
                description,
                conversation_id,
                created_at,
                1.0 / (1.0 + (question_embedding <-> %s::vector)) as question_similarity,
                1.0 / (1.0 + (context_embedding <-> %s::vector)) as context_similarity,
                ({question_weight} * (1.0 / (1.0 + (question_embedding <-> %s::vector)))) +
                ({context_weight} * (1.0 / (1.0 + (context_embedding <-> %s::vector)))) as combined_similarity,
                {is_valid_expr} as is_valid
            FROM {self.table_name}
            WHERE genie_space_id = %s
            ORDER BY combined_similarity DESC
            LIMIT 1
        """

        question_emb_str: str = f"[{','.join(str(x) for x in question_embedding)}]"
        context_emb_str: str = f"[{','.join(str(x) for x in context_embedding)}]"

        with self._pool.connection() as conn:
            with conn.cursor() as cur:
                cur.execute(
                    search_sql,
                    (
                        question_emb_str,
                        context_emb_str,
                        question_emb_str,
                        context_emb_str,
                        self.space_id,
                    ),
                )
                row: DbRow | None = cur.fetchone()

                if row is None:
                    logger.info(
                        f"[{self.name}] MISS (no entries): "
                        f"question='{question[:50]}...' space='{self.space_id}'"
                    )
                    return None

                # Extract values from dict row
                entry_id: Any = row.get("id")
                cached_question: str = row.get("question", "")
                cached_context: str = row.get("conversation_context", "")
                sql_query: str = row["sql_query"]
                description: str = row.get("description", "")
                conversation_id_cached: str = row.get("conversation_id", "")
                created_at: Any = row["created_at"]
                question_similarity: float = row["question_similarity"]
                context_similarity: float = row["context_similarity"]
                combined_similarity: float = row["combined_similarity"]
                is_valid: bool = row.get("is_valid", False)

                # Log best match info
                logger.info(
                    f"[{self.name}] Best match: "
                    f"question_sim={question_similarity:.4f}, context_sim={context_similarity:.4f}, "
                    f"combined_sim={combined_similarity:.4f}, is_valid={is_valid}, "
                    f"question='{cached_question[:50]}...', context='{cached_context[:80]}...'"
                )

                # Check BOTH similarity thresholds (dual embedding precision check)
                if question_similarity < self.parameters.similarity_threshold:
                    logger.info(
                        f"[{self.name}] MISS (question similarity too low): "
                        f"question_sim={question_similarity:.4f} < threshold={self.parameters.similarity_threshold}"
                    )
                    return None

                if context_similarity < self.parameters.context_similarity_threshold:
                    logger.info(
                        f"[{self.name}] MISS (context similarity too low): "
                        f"context_sim={context_similarity:.4f} < threshold={self.parameters.context_similarity_threshold}"
                    )
                    return None

                # Check TTL - refresh-on-hit strategy
                if not is_valid:
                    # Entry is expired - delete it and return miss to trigger refresh
                    delete_sql = f"DELETE FROM {self.table_name} WHERE id = %s"
                    cur.execute(delete_sql, (entry_id,))
                    logger.info(
                        f"[{self.name}] MISS (expired, deleted for refresh): "
                        f"combined_sim={combined_similarity:.4f}, ttl={ttl_seconds}s, question='{cached_question[:50]}...'"
                    )
                    return None

                logger.info(
                    f"[{self.name}] HIT: question_sim={question_similarity:.4f}, context_sim={context_similarity:.4f}, "
                    f"combined_sim={combined_similarity:.4f} (cached_question='{cached_question[:50]}...')"
                )

                entry = SQLCacheEntry(
                    query=sql_query,
                    description=description,
                    conversation_id=conversation_id_cached,
                    created_at=created_at,
                )
                return entry, combined_similarity

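    # The distance-to-similarity mapping used in search_sql, shown standalone:
    #
    #     def l2_to_similarity(distance: float) -> float:
    #         return 1.0 / (1.0 + distance)
    #
    #     l2_to_similarity(0.0)  # -> 1.0 (identical vectors)
    #     l2_to_similarity(1.0)  # -> 0.5
    #     # With similarity_threshold=0.85, only L2 distances below ~0.176 can hit.
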
    def _store_entry(
        self,
        question: str,
        conversation_context: str,
        question_embedding: list[float],
        context_embedding: list[float],
        response: GenieResponse,
    ) -> None:
        """Store a new cache entry with dual embeddings for this Genie space."""
        insert_sql: str = f"""
            INSERT INTO {self.table_name}
                (genie_space_id, question, conversation_context, context_string,
                 question_embedding, context_embedding, sql_query, description, conversation_id)
            VALUES (%s, %s, %s, %s, %s::vector, %s::vector, %s, %s, %s)
        """
        question_emb_str: str = f"[{','.join(str(x) for x in question_embedding)}]"
        context_emb_str: str = f"[{','.join(str(x) for x in context_embedding)}]"

        # Build full context string for backward compatibility (used in logging)
        if conversation_context:
            full_context_string = f"{conversation_context}\nCurrent: {question}"
        else:
            full_context_string = question

        with self._pool.connection() as conn:
            with conn.cursor() as cur:
                cur.execute(
                    insert_sql,
                    (
                        self.space_id,
                        question,
                        conversation_context,
                        full_context_string,
                        question_emb_str,
                        context_emb_str,
                        response.query,
                        response.description,
                        response.conversation_id,
                    ),
                )
        logger.info(
            f"[{self.name}] Stored cache entry: question='{question[:50]}...' "
            f"context='{conversation_context[:80]}...' "
            f"sql='{response.query[:50]}...' (space={self.space_id}, table={self.table_name})"
        )

    @mlflow.trace(name="execute_cached_sql_semantic")
    def _execute_sql(self, sql: str) -> pd.DataFrame | str:
        """Execute SQL using the warehouse and return results."""
        client: WorkspaceClient = self.warehouse.workspace_client
        warehouse_id: str = str(self.warehouse.warehouse_id)

        statement_response: StatementResponse = (
            client.statement_execution.execute_statement(
                warehouse_id=warehouse_id,
                statement=sql,
                wait_timeout="30s",
            )
        )

        if (
            statement_response.status is not None
            and statement_response.status.state != StatementState.SUCCEEDED
        ):
            error_msg: str = (
                f"SQL execution failed: {statement_response.status.error.message}"
                if statement_response.status.error is not None
                else f"SQL execution failed with state: {statement_response.status.state}"
            )
            logger.error(f"[{self.name}] {error_msg}")
            return error_msg

        if statement_response.result and statement_response.result.data_array:
            columns: list[str] = []
            if (
                statement_response.manifest
                and statement_response.manifest.schema
                and statement_response.manifest.schema.columns
            ):
                columns = [
                    col.name
                    for col in statement_response.manifest.schema.columns
                    if col.name is not None
                ]

            data: list[list[Any]] = statement_response.result.data_array
            if columns:
                return pd.DataFrame(data, columns=columns)
            else:
                return pd.DataFrame(data)

        return pd.DataFrame()

    def ask_question(
        self, question: str, conversation_id: str | None = None
    ) -> CacheResult:
        """
        Ask a question, using the semantic cache if a similar query exists.

        On cache hit, re-executes the cached SQL to get fresh data.
        Returns CacheResult with cache metadata.
        """
        return self.ask_question_with_cache_info(question, conversation_id)

    @mlflow.trace(name="genie_semantic_cache_lookup")
    def ask_question_with_cache_info(
        self,
        question: str,
        conversation_id: str | None = None,
    ) -> CacheResult:
        """
        Ask a question with detailed cache hit information.

        On cache hit, the cached SQL is re-executed to return fresh data, but the
        conversation_id returned is the current conversation_id (not the cached one).

        Args:
            question: The question to ask
            conversation_id: Optional conversation ID for context and continuation

        Returns:
            CacheResult with fresh response and cache metadata
        """
        # Ensure initialization (lazy init if initialize() wasn't called)
        self._setup()

        # Generate dual embeddings for the question and conversation context
        question_embedding: list[float]
        context_embedding: list[float]
        conversation_context: str
        question_embedding, context_embedding, conversation_context = (
            self._embed_question(question, conversation_id)
        )

        # Check cache using dual embedding similarity
        cache_result: tuple[SQLCacheEntry, float] | None = self._find_similar(
            question, conversation_context, question_embedding, context_embedding
        )

        if cache_result is not None:
            cached, combined_similarity = cache_result
            logger.debug(
                f"[{self.name}] Semantic cache hit (combined_similarity={combined_similarity:.3f}): {question[:50]}..."
            )

            # Re-execute the cached SQL to get fresh data
            sql_result: pd.DataFrame | str = self._execute_sql(cached.query)

            # IMPORTANT: Use the current conversation_id (from the request), not the cached one.
            # This ensures the conversation continues properly.
            response: GenieResponse = GenieResponse(
                result=sql_result,
                query=cached.query,
                description=cached.description,
                conversation_id=conversation_id
                if conversation_id
                else cached.conversation_id,
            )

            return CacheResult(response=response, cache_hit=True, served_by=self.name)

        # Cache miss - delegate to the wrapped service
        logger.debug(f"[{self.name}] Miss: {question[:50]}...")

        result: CacheResult = self.impl.ask_question(question, conversation_id)

        # Store in cache if we got a SQL query
        if result.response.query:
            logger.info(
                f"[{self.name}] Storing new cache entry for question: '{question[:50]}...' "
                f"(space={self.space_id})"
            )
            self._store_entry(
                question,
                conversation_context,
                question_embedding,
                context_embedding,
                result.response,
            )
        else:
            logger.warning(
                f"[{self.name}] Not caching: response has no SQL query "
                f"(question='{question[:50]}...')"
            )

        return CacheResult(response=result.response, cache_hit=False, served_by=None)

    @property
    def space_id(self) -> str:
        return self.impl.space_id

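    # Hit/miss flow in brief (using the `genie` instance from the class
    # docstring example above):
    #
    #     res = genie.ask_question("total sales by region?")
    #     res.cache_hit        # False on first call; True for similar follow-ups
    #     res.served_by        # cache layer name on a hit, e.g. "SemanticCacheService"
    #     res.response.result  # fresh data (the cached SQL is re-executed on hits)
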
    def invalidate_expired(self) -> int:
        """Remove expired entries from the cache for this Genie space.

        Returns 0 if TTL is disabled (entries never expire).
        """
        self._setup()
        ttl_seconds = self.parameters.time_to_live_seconds

        # If TTL is disabled, nothing can expire
        if ttl_seconds is None or ttl_seconds < 0:
            logger.debug(
                f"[{self.name}] TTL disabled, no entries to expire for space {self.space_id}"
            )
            return 0

        # make_interval() keeps the TTL a bound parameter; a %s placeholder
        # inside a quoted INTERVAL literal would not bind correctly.
        delete_sql: str = f"""
            DELETE FROM {self.table_name}
            WHERE genie_space_id = %s
              AND created_at < NOW() - make_interval(secs => %s)
        """

        with self._pool.connection() as conn:
            with conn.cursor() as cur:
                cur.execute(delete_sql, (self.space_id, ttl_seconds))
                deleted: int = cur.rowcount
                logger.debug(
                    f"[{self.name}] Deleted {deleted} expired entries for space {self.space_id}"
                )
                return deleted

    def clear(self) -> int:
        """Clear all entries from the cache for this Genie space."""
        self._setup()
        delete_sql: str = f"DELETE FROM {self.table_name} WHERE genie_space_id = %s"

        with self._pool.connection() as conn:
            with conn.cursor() as cur:
                cur.execute(delete_sql, (self.space_id,))
                deleted: int = cur.rowcount
                logger.debug(
                    f"[{self.name}] Cleared {deleted} entries for space {self.space_id}"
                )
                return deleted

    @property
    def size(self) -> int:
        """Current number of entries in the cache for this Genie space."""
        self._setup()
        count_sql: str = (
            f"SELECT COUNT(*) as count FROM {self.table_name} WHERE genie_space_id = %s"
        )

        with self._pool.connection() as conn:
            with conn.cursor() as cur:
                cur.execute(count_sql, (self.space_id,))
                row: DbRow | None = cur.fetchone()
                return row.get("count", 0) if row else 0

    def stats(self) -> dict[str, int | float | None]:
        """Return cache statistics for this Genie space."""
        self._setup()
        ttl_seconds = self.parameters.time_to_live_seconds
        ttl = self.time_to_live

        # If TTL is disabled, all entries are valid
        if ttl_seconds is None or ttl_seconds < 0:
            count_sql: str = f"""
                SELECT COUNT(*) as total FROM {self.table_name}
                WHERE genie_space_id = %s
            """
            with self._pool.connection() as conn:
                with conn.cursor() as cur:
                    cur.execute(count_sql, (self.space_id,))
                    row: DbRow | None = cur.fetchone()
                    total = row.get("total", 0) if row else 0
                    return {
                        "size": total,
                        "ttl_seconds": None,
                        "similarity_threshold": self.similarity_threshold,
                        "expired_entries": 0,
                        "valid_entries": total,
                    }

        # make_interval() for the same parameter-binding reason as above
        stats_sql: str = f"""
            SELECT
                COUNT(*) as total,
                COUNT(*) FILTER (WHERE created_at > NOW() - make_interval(secs => %s)) as valid,
                COUNT(*) FILTER (WHERE created_at <= NOW() - make_interval(secs => %s)) as expired
            FROM {self.table_name}
            WHERE genie_space_id = %s
        """

        with self._pool.connection() as conn:
            with conn.cursor() as cur:
                cur.execute(stats_sql, (ttl_seconds, ttl_seconds, self.space_id))
                stats_row: DbRow | None = cur.fetchone()
                return {
                    "size": stats_row.get("total", 0) if stats_row else 0,
                    "ttl_seconds": ttl.total_seconds() if ttl else None,
                    "similarity_threshold": self.similarity_threshold,
                    "expired_entries": stats_row.get("expired", 0) if stats_row else 0,
                    "valid_entries": stats_row.get("valid", 0) if stats_row else 0,
                }
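
A minimal maintenance sketch against the service above, assuming the `genie` instance from the class docstring example:

    genie.initialize()                    # create the table and indexes up front
    print(genie.size)                     # entries cached for this Genie space
    print(genie.stats())                  # size / valid / expired / ttl_seconds
    removed = genie.invalidate_expired()  # prune entries older than the TTL
    genie.clear()                         # drop all entries for this space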