dao-ai 0.0.36__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dao_ai/__init__.py +29 -0
- dao_ai/cli.py +195 -30
- dao_ai/config.py +770 -244
- dao_ai/genie/__init__.py +1 -22
- dao_ai/genie/cache/__init__.py +1 -2
- dao_ai/genie/cache/base.py +20 -70
- dao_ai/genie/cache/core.py +75 -0
- dao_ai/genie/cache/lru.py +44 -21
- dao_ai/genie/cache/semantic.py +390 -109
- dao_ai/genie/core.py +35 -0
- dao_ai/graph.py +27 -253
- dao_ai/hooks/__init__.py +9 -6
- dao_ai/hooks/core.py +22 -190
- dao_ai/memory/__init__.py +10 -0
- dao_ai/memory/core.py +23 -5
- dao_ai/memory/databricks.py +389 -0
- dao_ai/memory/postgres.py +2 -2
- dao_ai/messages.py +6 -4
- dao_ai/middleware/__init__.py +125 -0
- dao_ai/middleware/assertions.py +778 -0
- dao_ai/middleware/base.py +50 -0
- dao_ai/middleware/core.py +61 -0
- dao_ai/middleware/guardrails.py +415 -0
- dao_ai/middleware/human_in_the_loop.py +228 -0
- dao_ai/middleware/message_validation.py +554 -0
- dao_ai/middleware/summarization.py +192 -0
- dao_ai/models.py +1177 -108
- dao_ai/nodes.py +118 -161
- dao_ai/optimization.py +664 -0
- dao_ai/orchestration/__init__.py +52 -0
- dao_ai/orchestration/core.py +287 -0
- dao_ai/orchestration/supervisor.py +264 -0
- dao_ai/orchestration/swarm.py +226 -0
- dao_ai/prompts.py +126 -29
- dao_ai/providers/databricks.py +126 -381
- dao_ai/state.py +139 -21
- dao_ai/tools/__init__.py +8 -5
- dao_ai/tools/core.py +57 -4
- dao_ai/tools/email.py +280 -0
- dao_ai/tools/genie.py +47 -24
- dao_ai/tools/mcp.py +4 -3
- dao_ai/tools/memory.py +50 -0
- dao_ai/tools/python.py +4 -12
- dao_ai/tools/search.py +14 -0
- dao_ai/tools/slack.py +1 -1
- dao_ai/tools/unity_catalog.py +8 -6
- dao_ai/tools/vector_search.py +16 -9
- dao_ai/utils.py +72 -8
- dao_ai-0.1.1.dist-info/METADATA +1878 -0
- dao_ai-0.1.1.dist-info/RECORD +62 -0
- dao_ai/chat_models.py +0 -204
- dao_ai/guardrails.py +0 -112
- dao_ai/tools/genie/__init__.py +0 -236
- dao_ai/tools/human_in_the_loop.py +0 -100
- dao_ai-0.0.36.dist-info/METADATA +0 -951
- dao_ai-0.0.36.dist-info/RECORD +0 -47
- {dao_ai-0.0.36.dist-info → dao_ai-0.1.1.dist-info}/WHEEL +0 -0
- {dao_ai-0.0.36.dist-info → dao_ai-0.1.1.dist-info}/entry_points.txt +0 -0
- {dao_ai-0.0.36.dist-info → dao_ai-0.1.1.dist-info}/licenses/LICENSE +0 -0
dao_ai/genie/cache/semantic.py
CHANGED

@@ -4,6 +4,10 @@ Semantic cache implementation for Genie SQL queries using PostgreSQL pg_vector.
 This module provides a semantic cache that uses embeddings and similarity search
 to find cached queries that match the intent of new questions. Cache entries are
 partitioned by genie_space_id to ensure proper isolation between Genie spaces.
+
+The cache supports conversation-aware embedding using a rolling window approach
+to capture context from recent conversation turns, improving accuracy for
+multi-turn conversations with anaphoric references.
 """

 from datetime import timedelta
@@ -12,14 +16,18 @@ from typing import Any
 import mlflow
 import pandas as pd
 from databricks.sdk import WorkspaceClient
+from databricks.sdk.service.dashboards import (
+    GenieListConversationMessagesResponse,
+    GenieMessage,
+)
 from databricks.sdk.service.sql import StatementResponse, StatementState
 from databricks_ai_bridge.genie import GenieResponse
 from loguru import logger
-from mlflow.entities import SpanType

 from dao_ai.config import (
     DatabaseModel,
     GenieSemanticCacheParametersModel,
+    LLMModel,
     WarehouseModel,
 )
 from dao_ai.genie.cache.base import (
@@ -32,6 +40,112 @@ from dao_ai.genie.cache.base import (
 DbRow = dict[str, Any]


+def get_conversation_history(
+    workspace_client: WorkspaceClient,
+    space_id: str,
+    conversation_id: str,
+    max_messages: int = 10,
+) -> list[GenieMessage]:
+    """
+    Retrieve conversation history from Genie.
+
+    Args:
+        workspace_client: The Databricks workspace client
+        space_id: The Genie space ID
+        conversation_id: The conversation ID to retrieve
+        max_messages: Maximum number of messages to retrieve
+
+    Returns:
+        List of GenieMessage objects representing the conversation history
+    """
+    try:
+        # Use the Genie API to retrieve conversation messages
+        response: GenieListConversationMessagesResponse = (
+            workspace_client.genie.list_conversation_messages(
+                space_id=space_id,
+                conversation_id=conversation_id,
+            )
+        )
+
+        # Return the most recent messages up to max_messages
+        if response.messages is not None:
+            all_messages: list[GenieMessage] = list(response.messages)
+            return (
+                all_messages[-max_messages:]
+                if len(all_messages) > max_messages
+                else all_messages
+            )
+        return []
+    except Exception as e:
+        logger.warning(
+            f"Failed to retrieve conversation history for conversation_id={conversation_id}: {e}"
+        )
+        return []
+
+
+def build_context_string(
+    question: str,
+    conversation_messages: list[GenieMessage],
+    window_size: int,
+    max_tokens: int = 2000,
+) -> str:
+    """
+    Build a context-aware question string using rolling window.
+
+    This function creates a concatenated string that includes recent conversation
+    turns to provide context for semantic similarity matching.
+
+    Args:
+        question: The current question
+        conversation_messages: List of previous conversation messages
+        window_size: Number of previous turns to include
+        max_tokens: Maximum estimated tokens (rough approximation: 4 chars = 1 token)
+
+    Returns:
+        Context-aware question string formatted for embedding
+    """
+    if window_size <= 0 or not conversation_messages:
+        return question
+
+    # Take the last window_size messages (most recent)
+    recent_messages = (
+        conversation_messages[-window_size:]
+        if len(conversation_messages) > window_size
+        else conversation_messages
+    )
+
+    # Build context parts
+    context_parts: list[str] = []
+
+    for msg in recent_messages:
+        # Only include messages with content from the history
+        if msg.content:
+            # Limit message length to prevent token overflow
+            content: str = msg.content
+            if len(content) > 500:  # Truncate very long messages
+                content = content[:500] + "..."
+            context_parts.append(f"Previous: {content}")
+
+    # Add current question
+    context_parts.append(f"Current: {question}")
+
+    # Join with newlines
+    context_string = "\n".join(context_parts)
+
+    # Rough token limit check (4 chars ≈ 1 token)
+    estimated_tokens = len(context_string) / 4
+    if estimated_tokens > max_tokens:
+        # Truncate to fit max_tokens
+        target_chars = max_tokens * 4
+        context_string = context_string[:target_chars] + "..."
+        logger.debug(
+            f"Truncated context string from {len(context_string)} to {target_chars} chars "
+            f"(estimated {max_tokens} tokens)"
+        )
+
+    return context_string
+
+
 class SemanticCacheService(GenieServiceBase):
     """
     Semantic caching decorator that uses PostgreSQL pg_vector for similarity lookup.
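The two module-level helpers above carry the conversation-aware behavior: history retrieval is best-effort (failures degrade to an empty list), and the context string is a rolling window of recent turns. A minimal, self-contained sketch of the windowing logic — a simplified re-implementation for illustration, not the packaged code; message objects only need a `content` attribute, mirroring `GenieMessage`:

```python
from types import SimpleNamespace


def rolling_context(question: str, messages: list, window_size: int) -> str:
    """Simplified re-implementation of build_context_string's core behavior."""
    parts = [
        f"Previous: {m.content[:500]}"  # long turns are truncated at 500 chars
        for m in messages[-window_size:]  # keep only the most recent turns
        if m.content
    ]
    parts.append(f"Current: {question}")
    return "\n".join(parts)


messages = [
    SimpleNamespace(content="What were total sales in Q3?"),
    SimpleNamespace(content="Break that down by region"),
]
# The anaphoric follow-up only makes sense with the prior turns attached:
print(rolling_context("And by month?", messages, window_size=2))
# Previous: What were total sales in Q3?
# Previous: Break that down by region
# Current: And by month?
```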
@@ -59,8 +173,7 @@ class SemanticCacheService(GenieServiceBase):
         )
         genie = SemanticCacheService(
             impl=GenieService(Genie(space_id="my-space")),
-            parameters=cache_params
-            genie_space_id="my-space"
+            parameters=cache_params
         )

     Thread-safe: Uses connection pooling from psycopg_pool.
@@ -68,7 +181,7 @@ class SemanticCacheService(GenieServiceBase):

     impl: GenieServiceBase
     parameters: GenieSemanticCacheParametersModel
-    genie_space_id: str
+    workspace_client: WorkspaceClient | None
     name: str
     _embeddings: Any  # DatabricksEmbeddings
     _pool: Any  # ConnectionPool
@@ -79,21 +192,23 @@ class SemanticCacheService(GenieServiceBase):
         self,
         impl: GenieServiceBase,
         parameters: GenieSemanticCacheParametersModel,
-        genie_space_id: str,
+        workspace_client: WorkspaceClient | None = None,
         name: str | None = None,
     ) -> None:
         """
         Initialize the semantic cache service.

         Args:
-            impl: The underlying GenieServiceBase to delegate to on cache miss
+            impl: The underlying GenieServiceBase to delegate to on cache miss.
+                The space_id will be obtained from impl.space_id.
             parameters: Cache configuration including database, warehouse, embedding model
-            genie_space_id: The Genie space ID
+            workspace_client: Optional WorkspaceClient for retrieving conversation history.
+                If None, conversation context will not be used.
             name: Name for this cache layer (for logging). Defaults to class name.
         """
         self.impl = impl
         self.parameters = parameters
-        self.genie_space_id = genie_space_id
+        self.workspace_client = workspace_client
         self.name = name if name is not None else self.__class__.__name__
         self._embeddings = None
         self._pool = None
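The constructor change is breaking: `genie_space_id` is gone, the space ID now comes from the wrapped service, and the workspace client is a new optional dependency. A sketch of the 0.1.1 wiring, reusing the names from the class docstring example (`cache_params` is assumed to be a configured `GenieSemanticCacheParametersModel`):

```python
from databricks.sdk import WorkspaceClient

# Names from the docstring example above; cache_params is assumed configured.
genie = SemanticCacheService(
    impl=GenieService(Genie(space_id="my-space")),
    parameters=cache_params,
    workspace_client=WorkspaceClient(),  # optional; omit to disable conversation context
)
assert genie.space_id == "my-space"  # delegated to impl.space_id (see the property added below)
```

Passing `workspace_client=None` (the default) keeps the cache purely question-based: `_embed_question` then skips history retrieval and substitutes a zero vector for the context embedding.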
@@ -120,17 +235,16 @@ class SemanticCacheService(GenieServiceBase):
         if self._setup_complete:
             return

-        from databricks_langchain import DatabricksEmbeddings
-
         from dao_ai.memory.postgres import PostgresPoolManager

         # Initialize embeddings
-        embedding_model
-
+        # Convert embedding_model to LLMModel if it's a string
+        embedding_model: LLMModel = (
+            LLMModel(name=self.parameters.embedding_model)
             if isinstance(self.parameters.embedding_model, str)
-            else self.parameters.embedding_model
+            else self.parameters.embedding_model
         )
-        self._embeddings =
+        self._embeddings = embedding_model.as_embeddings_model()

         # Auto-detect embedding dimensions if not provided
         if self.parameters.embedding_dims is None:
@@ -150,7 +264,7 @@ class SemanticCacheService(GenieServiceBase):

         self._setup_complete = True
         logger.debug(
-            f"[{self.name}] Semantic cache initialized for space '{self.genie_space_id}' "
+            f"[{self.name}] Semantic cache initialized for space '{self.space_id}' "
             f"with table '{self.table_name}' (dims={self._embedding_dims})"
         )

@@ -212,7 +326,10 @@ class SemanticCacheService(GenieServiceBase):
             id SERIAL PRIMARY KEY,
             genie_space_id TEXT NOT NULL,
             question TEXT NOT NULL,
+            conversation_context TEXT,
+            context_string TEXT,
             question_embedding vector({self.embedding_dims}),
+            context_embedding vector({self.embedding_dims}),
             sql_query TEXT NOT NULL,
             description TEXT,
             conversation_id TEXT,
@@ -221,12 +338,18 @@ class SemanticCacheService(GenieServiceBase):
         """
         # Index for efficient similarity search partitioned by genie_space_id
         # Use L2 (Euclidean) distance - optimal for Databricks GTE embeddings
-
-        CREATE INDEX IF NOT EXISTS {self.table_name}
+        create_question_embedding_index_sql: str = f"""
+        CREATE INDEX IF NOT EXISTS {self.table_name}_question_embedding_idx
         ON {self.table_name}
         USING ivfflat (question_embedding vector_l2_ops)
         WITH (lists = 100)
         """
+        create_context_embedding_index_sql: str = f"""
+        CREATE INDEX IF NOT EXISTS {self.table_name}_context_embedding_idx
+        ON {self.table_name}
+        USING ivfflat (context_embedding vector_l2_ops)
+        WITH (lists = 100)
+        """
         # Index for filtering by genie_space_id
         create_space_index_sql: str = f"""
         CREATE INDEX IF NOT EXISTS {self.table_name}_space_idx
@@ -257,35 +380,132 @@ class SemanticCacheService(GenieServiceBase):

                 cur.execute(create_table_sql)
                 cur.execute(create_space_index_sql)
-                cur.execute(
+                cur.execute(create_question_embedding_index_sql)
+                cur.execute(create_context_embedding_index_sql)
+
+    def _embed_question(
+        self, question: str, conversation_id: str | None = None
+    ) -> tuple[list[float], list[float], str]:
+        """
+        Generate dual embeddings: one for the question, one for the conversation context.
+
+        This enables separate matching of question similarity vs context similarity,
+        improving precision by ensuring both the question AND the conversation context
+        are semantically similar before returning a cached result.
+
+        Args:
+            question: The question to embed
+            conversation_id: Optional conversation ID for retrieving context
+
+        Returns:
+            Tuple of (question_embedding, context_embedding, conversation_context_string)
+            - question_embedding: Vector embedding of just the question
+            - context_embedding: Vector embedding of the conversation context (or zero vector if no context)
+            - conversation_context_string: The conversation context string (empty if no context)
+        """
+        conversation_context = ""
+
+        # If conversation context is enabled and available
+        if (
+            self.workspace_client is not None
+            and conversation_id is not None
+            and self.parameters.context_window_size > 0
+        ):
+            try:
+                # Retrieve conversation history
+                conversation_messages = get_conversation_history(
+                    workspace_client=self.workspace_client,
+                    space_id=self.space_id,
+                    conversation_id=conversation_id,
+                    max_messages=self.parameters.context_window_size
+                    * 2,  # Get extra for safety
+                )
+
+                # Build context string (just the "Previous:" messages, not the current question)
+                if conversation_messages:
+                    recent_messages = (
+                        conversation_messages[-self.parameters.context_window_size :]
+                        if len(conversation_messages)
+                        > self.parameters.context_window_size
+                        else conversation_messages
+                    )
+
+                    context_parts: list[str] = []
+                    for msg in recent_messages:
+                        if msg.content:
+                            content: str = msg.content
+                            if len(content) > 500:
+                                content = content[:500] + "..."
+                            context_parts.append(f"Previous: {content}")
+
+                    conversation_context = "\n".join(context_parts)

-
-
-
-
+                    # Truncate if too long
+                    estimated_tokens = len(conversation_context) / 4
+                    if estimated_tokens > self.parameters.max_context_tokens:
+                        target_chars = self.parameters.max_context_tokens * 4
+                        conversation_context = (
+                            conversation_context[:target_chars] + "..."
+                        )
+
+                logger.debug(
+                    f"[{self.name}] Using conversation context: {len(conversation_messages)} messages "
+                    f"(window_size={self.parameters.context_window_size})"
+                )
+            except Exception as e:
+                logger.warning(
+                    f"[{self.name}] Failed to build conversation context, using question only: {e}"
+                )
+                conversation_context = ""
+
+        # Generate dual embeddings
+        if conversation_context:
+            # Embed both question and context
+            embeddings: list[list[float]] = self._embeddings.embed_documents(
+                [question, conversation_context]
+            )
+            question_embedding = embeddings[0]
+            context_embedding = embeddings[1]
+        else:
+            # Only embed question, use zero vector for context
+            embeddings = self._embeddings.embed_documents([question])
+            question_embedding = embeddings[0]
+            context_embedding = [0.0] * len(question_embedding)  # Zero vector
+
+        return question_embedding, context_embedding, conversation_context

     @mlflow.trace(name="semantic_search")
     def _find_similar(
-        self,
+        self,
+        question: str,
+        conversation_context: str,
+        question_embedding: list[float],
+        context_embedding: list[float],
     ) -> tuple[SQLCacheEntry, float] | None:
         """
-        Find a semantically similar cached entry
+        Find a semantically similar cached entry using dual embedding matching.
+
+        This method matches BOTH the question AND the conversation context separately,
+        ensuring high precision by requiring both to be semantically similar.

         Args:
-            question: The question
-
+            question: The original question (for logging)
+            conversation_context: The conversation context string
+            question_embedding: The embedding vector of just the question
+            context_embedding: The embedding vector of the conversation context

         Returns:
-            Tuple of (SQLCacheEntry,
+            Tuple of (SQLCacheEntry, combined_similarity_score) if found, None otherwise
         """
         # Use L2 (Euclidean) distance - optimal for Databricks GTE embeddings
         # pg_vector's <-> operator returns L2 distance (0 = identical)
         # Convert to similarity: 1 / (1 + distance) gives range [0, 1]
         #
-        #
-        # 1.
-        # 2.
-        # 3.
+        # Dual embedding strategy:
+        # 1. Calculate separate similarities for question and context
+        # 2. BOTH must exceed their respective thresholds
+        # 3. Combined score is weighted average
+        # 4. Refresh-on-hit: check TTL after similarity check
         ttl_seconds = self.parameters.time_to_live_seconds
         ttl_disabled = ttl_seconds is None or ttl_seconds < 0

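A detail worth noting in `_embed_question`: both texts go through a single `embed_documents` batch call when context exists, and a zero vector stands in for "no context". A condensed sketch of that dispatch, written against any `embed_documents`-style callable (the toy embedder is invented purely so the sketch runs; the real code uses the configured Databricks embeddings):

```python
def dual_embed(embed_documents, question: str, context: str):
    """Condensed sketch of _embed_question's embedding dispatch."""
    if context:
        q_emb, c_emb = embed_documents([question, context])  # one batched call
    else:
        (q_emb,) = embed_documents([question])
        c_emb = [0.0] * len(q_emb)  # zero vector marks "no context"
    return q_emb, c_emb


# Toy embedder so the sketch is runnable (text length + a constant dimension).
def toy_embed(texts):
    return [[float(len(t)), 1.0] for t in texts]


print(dual_embed(toy_embed, "total sales?", ""))
# ([12.0, 1.0], [0.0, 0.0])
print(dual_embed(toy_embed, "total sales?", "Previous: Break that down by region"))
# ([12.0, 1.0], [35.0, 1.0])
```

One consequence of the zero vector: an entry cached without context has context similarity 1/(1 + ‖e‖) against any real context embedding e, which for unit-norm embeddings is a constant 0.5, so such entries tend to fall below a context threshold when a context-bearing query arrives.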
@@ -295,63 +515,87 @@ class SemanticCacheService(GenieServiceBase):
         else:
             is_valid_expr = f"created_at > NOW() - INTERVAL '{ttl_seconds} seconds'"

+        # Weighted combined similarity for ordering
+        question_weight: float = self.parameters.question_weight
+        context_weight: float = self.parameters.context_weight
+
         search_sql: str = f"""
         SELECT
             id,
             question,
+            conversation_context,
             sql_query,
             description,
             conversation_id,
             created_at,
-            1.0 / (1.0 + (question_embedding <-> %s::vector)) as
+            1.0 / (1.0 + (question_embedding <-> %s::vector)) as question_similarity,
+            1.0 / (1.0 + (context_embedding <-> %s::vector)) as context_similarity,
+            ({question_weight} * (1.0 / (1.0 + (question_embedding <-> %s::vector)))) +
+            ({context_weight} * (1.0 / (1.0 + (context_embedding <-> %s::vector)))) as combined_similarity,
             {is_valid_expr} as is_valid
         FROM {self.table_name}
         WHERE genie_space_id = %s
-        ORDER BY
+        ORDER BY combined_similarity DESC
         LIMIT 1
         """

-
+        question_emb_str: str = f"[{','.join(str(x) for x in question_embedding)}]"
+        context_emb_str: str = f"[{','.join(str(x) for x in context_embedding)}]"

         with self._pool.connection() as conn:
             with conn.cursor() as cur:
                 cur.execute(
                     search_sql,
-                    (
+                    (
+                        question_emb_str,
+                        context_emb_str,
+                        question_emb_str,
+                        context_emb_str,
+                        self.space_id,
+                    ),
                 )
                 row: DbRow | None = cur.fetchone()

                 if row is None:
                     logger.info(
                         f"[{self.name}] MISS (no entries): "
-                        f"question='{question[:50]}...' space='{self.genie_space_id}'"
+                        f"question='{question[:50]}...' space='{self.space_id}'"
                     )
                     return None

                 # Extract values from dict row
-                entry_id = row.get("id")
-                cached_question = row.get("question", "")
-
-
-
-
-
-
-
-
-
-
-
-
+                entry_id: Any = row.get("id")
+                cached_question: str = row.get("question", "")
+                cached_context: str = row.get("conversation_context", "")
+                sql_query: str = row["sql_query"]
+                description: str = row.get("description", "")
+                conversation_id_cached: str = row.get("conversation_id", "")
+                created_at: Any = row["created_at"]
+                question_similarity: float = row["question_similarity"]
+                context_similarity: float = row["context_similarity"]
+                combined_similarity: float = row["combined_similarity"]
+                is_valid: bool = row.get("is_valid", False)
+
+                # Log best match info
                 logger.info(
-                    f"[{self.name}] Best match:
-                    f"
+                    f"[{self.name}] Best match: "
+                    f"question_sim={question_similarity:.4f}, context_sim={context_similarity:.4f}, "
+                    f"combined_sim={combined_similarity:.4f}, is_valid={is_valid}, "
+                    f"question='{cached_question[:50]}...', context='{cached_context[:80]}...'"
                 )

-                # Check similarity
-                if
+                # Check BOTH similarity thresholds (dual embedding precision check)
+                if question_similarity < self.parameters.similarity_threshold:
+                    logger.info(
+                        f"[{self.name}] MISS (question similarity too low): "
+                        f"question_sim={question_similarity:.4f} < threshold={self.parameters.similarity_threshold}"
+                    )
+                    return None
+
+                if context_similarity < self.parameters.context_similarity_threshold:
                     logger.info(
-                        f"[{self.name}] MISS (
-                        f"
+                        f"[{self.name}] MISS (context similarity too low): "
+                        f"context_sim={context_similarity:.4f} < threshold={self.parameters.context_similarity_threshold}"
                     )
                     return None

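The scoring in the SQL above reads as: map each L2 distance to a similarity in (0, 1], gate on two separate thresholds, and rank by a weighted blend. The same arithmetic in plain Python (the threshold and weight values here are illustrative; in the cache they come from `GenieSemanticCacheParametersModel`):

```python
import math


def l2(a: list[float], b: list[float]) -> float:
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def sim(distance: float) -> float:
    return 1.0 / (1.0 + distance)  # same mapping as the SQL; 1.0 = identical


def combined(q_sim: float, c_sim: float,
             q_thresh=0.85, c_thresh=0.70, q_w=0.7, c_w=0.3) -> float | None:
    """Both thresholds must pass; ordering uses the weighted blend."""
    if q_sim < q_thresh or c_sim < c_thresh:
        return None  # a miss, even if the blended score would be high
    return q_w * q_sim + c_w * c_sim


q_sim = sim(l2([1.0, 0.0], [0.9, 0.1]))  # near-identical question: ~0.876
c_sim = sim(l2([0.5, 0.5], [0.5, 0.4]))  # similar context: ~0.909
print(combined(q_sim, c_sim))            # ~0.886
```

Because the SQL orders by `combined_similarity` and fetches a single row, a candidate that would pass both thresholds can still be shadowed by a higher-blend row that then fails one of them — a precision-over-recall lean inherent to the `LIMIT 1` query.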
@@ -361,43 +605,59 @@ class SemanticCacheService(GenieServiceBase):
                     delete_sql = f"DELETE FROM {self.table_name} WHERE id = %s"
                     cur.execute(delete_sql, (entry_id,))
                     logger.info(
-                        f"[{self.name}] MISS (expired, deleted for refresh):
-                        f"ttl={ttl_seconds}s, question='{cached_question[:50]}...'"
+                        f"[{self.name}] MISS (expired, deleted for refresh): "
+                        f"combined_sim={combined_similarity:.4f}, ttl={ttl_seconds}s, question='{cached_question[:50]}...'"
                     )
                     return None

                 logger.info(
-                    f"[{self.name}] HIT:
-                    f"(cached_question='{cached_question[:50]}...')"
+                    f"[{self.name}] HIT: question_sim={question_similarity:.4f}, context_sim={context_similarity:.4f}, "
+                    f"combined_sim={combined_similarity:.4f} (cached_question='{cached_question[:50]}...')"
                 )

                 entry = SQLCacheEntry(
                     query=sql_query,
                     description=description,
-                    conversation_id=
+                    conversation_id=conversation_id_cached,
                     created_at=created_at,
                 )
-                return entry,
+                return entry, combined_similarity

     def _store_entry(
-        self,
+        self,
+        question: str,
+        conversation_context: str,
+        question_embedding: list[float],
+        context_embedding: list[float],
+        response: GenieResponse,
     ) -> None:
-        """Store a new cache entry for this Genie space."""
+        """Store a new cache entry with dual embeddings for this Genie space."""
         insert_sql: str = f"""
         INSERT INTO {self.table_name}
-        (genie_space_id, question,
-
+        (genie_space_id, question, conversation_context, context_string,
+         question_embedding, context_embedding, sql_query, description, conversation_id)
+        VALUES (%s, %s, %s, %s, %s::vector, %s::vector, %s, %s, %s)
         """
-
+        question_emb_str: str = f"[{','.join(str(x) for x in question_embedding)}]"
+        context_emb_str: str = f"[{','.join(str(x) for x in context_embedding)}]"
+
+        # Build full context string for backward compatibility (used in logging)
+        if conversation_context:
+            full_context_string = f"{conversation_context}\nCurrent: {question}"
+        else:
+            full_context_string = question

         with self._pool.connection() as conn:
             with conn.cursor() as cur:
                 cur.execute(
                     insert_sql,
                     (
-                        self.genie_space_id,
+                        self.space_id,
                         question,
-
+                        conversation_context,
+                        full_context_string,
+                        question_emb_str,
+                        context_emb_str,
                         response.query,
                         response.description,
                         response.conversation_id,
@@ -405,14 +665,15 @@ class SemanticCacheService(GenieServiceBase):
                 )
                 logger.info(
                     f"[{self.name}] Stored cache entry: question='{question[:50]}...' "
-                    f"
+                    f"context='{conversation_context[:80]}...' "
+                    f"sql='{response.query[:50]}...' (space={self.space_id}, table={self.table_name})"
                 )

     @mlflow.trace(name="execute_cached_sql_semantic")
     def _execute_sql(self, sql: str) -> pd.DataFrame | str:
         """Execute SQL using the warehouse and return results."""
         client: WorkspaceClient = self.warehouse.workspace_client
-        warehouse_id: str = self.warehouse.warehouse_id
+        warehouse_id: str = str(self.warehouse.warehouse_id)

         statement_response: StatementResponse = (
             client.statement_execution.execute_statement(
@@ -422,10 +683,13 @@ class SemanticCacheService(GenieServiceBase):
             )
         )

-        if
+        if (
+            statement_response.status is not None
+            and statement_response.status.state != StatementState.SUCCEEDED
+        ):
             error_msg: str = (
                 f"SQL execution failed: {statement_response.status.error.message}"
-                if statement_response.status.error
+                if statement_response.status.error is not None
                 else f"SQL execution failed with state: {statement_response.status.state}"
             )
             logger.error(f"[{self.name}] {error_msg}")
@@ -439,10 +703,10 @@ class SemanticCacheService(GenieServiceBase):
             and statement_response.manifest.schema.columns
         ):
             columns = [
-                col.name
+                col.name
+                for col in statement_response.manifest.schema.columns
+                if col.name is not None
             ]
-        elif hasattr(statement_response.result, "schema"):
-            columns = [col.name for col in statement_response.result.schema.columns]

         data: list[list[Any]] = statement_response.result.data_array
         if columns:
@@ -454,19 +718,16 @@ class SemanticCacheService(GenieServiceBase):

     def ask_question(
         self, question: str, conversation_id: str | None = None
-    ) ->
+    ) -> CacheResult:
         """
         Ask a question, using semantic cache if a similar query exists.

         On cache hit, re-executes the cached SQL to get fresh data.
-
+        Returns CacheResult with cache metadata.
         """
-
-            question, conversation_id
-        )
-        return result.response
+        return self.ask_question_with_cache_info(question, conversation_id)

-    @mlflow.trace(name="genie_semantic_cache_lookup"
+    @mlflow.trace(name="genie_semantic_cache_lookup")
     def ask_question_with_cache_info(
         self,
         question: str,
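`ask_question` is now a thin alias for `ask_question_with_cache_info`, so callers always get a `CacheResult` instead of a bare `GenieResponse`. A usage sketch with the field names visible in this diff:

```python
# `genie` is the SemanticCacheService constructed earlier.
result = genie.ask_question("What were total sales in Q3?")

if result.cache_hit:
    print(f"Served by cache layer: {result.served_by}")
print(result.response.query)   # the SQL, cached or freshly generated
print(result.response.result)  # pd.DataFrame | str from (re-)execution
```

Code written against 0.0.36 that did `response = genie.ask_question(...)` and read `response.query` directly needs updating to `result.response.query`.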
@@ -475,11 +736,12 @@ class SemanticCacheService(GenieServiceBase):
         """
         Ask a question with detailed cache hit information.

-        On cache hit, the cached SQL is re-executed to return fresh data
+        On cache hit, the cached SQL is re-executed to return fresh data, but the
+        conversation_id returned is the current conversation_id (not the cached one).

         Args:
             question: The question to ask
-            conversation_id: Optional conversation ID
+            conversation_id: Optional conversation ID for context and continuation

         Returns:
             CacheResult with fresh response and cache metadata
@@ -487,28 +749,37 @@ class SemanticCacheService(GenieServiceBase):
         # Ensure initialization (lazy init if initialize() wasn't called)
         self._setup()

-        # Generate
-
+        # Generate dual embeddings for the question and conversation context
+        question_embedding: list[float]
+        context_embedding: list[float]
+        conversation_context: str
+        question_embedding, context_embedding, conversation_context = (
+            self._embed_question(question, conversation_id)
+        )

-        # Check cache
+        # Check cache using dual embedding similarity
         cache_result: tuple[SQLCacheEntry, float] | None = self._find_similar(
-            question,
+            question, conversation_context, question_embedding, context_embedding
         )

         if cache_result is not None:
-            cached,
+            cached, combined_similarity = cache_result
             logger.debug(
-                f"[{self.name}] Semantic cache hit (
+                f"[{self.name}] Semantic cache hit (combined_similarity={combined_similarity:.3f}): {question[:50]}..."
             )

             # Re-execute the cached SQL to get fresh data
             result: pd.DataFrame | str = self._execute_sql(cached.query)

+            # IMPORTANT: Use the current conversation_id (from the request), not the cached one
+            # This ensures the conversation continues properly
             response: GenieResponse = GenieResponse(
                 result=result,
                 query=cached.query,
                 description=cached.description,
-                conversation_id=
+                conversation_id=conversation_id
+                if conversation_id
+                else cached.conversation_id,
             )

             return CacheResult(response=response, cache_hit=True, served_by=self.name)
@@ -516,22 +787,32 @@ class SemanticCacheService(GenieServiceBase):
         # Cache miss - delegate to wrapped service
         logger.debug(f"[{self.name}] Miss: {question[:50]}...")

-
+        result: CacheResult = self.impl.ask_question(question, conversation_id)

         # Store in cache if we got a SQL query
-        if response.query:
+        if result.response.query:
             logger.info(
                 f"[{self.name}] Storing new cache entry for question: '{question[:50]}...' "
-                f"(space={self.genie_space_id})"
+                f"(space={self.space_id})"
             )
-            self._store_entry(
-
+            self._store_entry(
+                question,
+                conversation_context,
+                question_embedding,
+                context_embedding,
+                result.response,
+            )
+        elif not result.response.query:
             logger.warning(
                 f"[{self.name}] Not caching: response has no SQL query "
                 f"(question='{question[:50]}...')"
             )

-        return CacheResult(response=response, cache_hit=False, served_by=None)
+        return CacheResult(response=result.response, cache_hit=False, served_by=None)
+
+    @property
+    def space_id(self) -> str:
+        return self.impl.space_id

     def invalidate_expired(self) -> int:
         """Remove expired entries from the cache for this Genie space.
@@ -544,7 +825,7 @@ class SemanticCacheService(GenieServiceBase):
         # If TTL is disabled, nothing can expire
         if ttl_seconds is None or ttl_seconds < 0:
             logger.debug(
-                f"[{self.name}] TTL disabled, no entries to expire for space {self.genie_space_id}"
+                f"[{self.name}] TTL disabled, no entries to expire for space {self.space_id}"
             )
             return 0

@@ -556,10 +837,10 @@ class SemanticCacheService(GenieServiceBase):

         with self._pool.connection() as conn:
             with conn.cursor() as cur:
-                cur.execute(delete_sql, (self.genie_space_id, ttl_seconds))
+                cur.execute(delete_sql, (self.space_id, ttl_seconds))
                 deleted: int = cur.rowcount
                 logger.debug(
-                    f"[{self.name}] Deleted {deleted} expired entries for space {self.genie_space_id}"
+                    f"[{self.name}] Deleted {deleted} expired entries for space {self.space_id}"
                 )
                 return deleted

@@ -570,10 +851,10 @@ class SemanticCacheService(GenieServiceBase):

         with self._pool.connection() as conn:
             with conn.cursor() as cur:
-                cur.execute(delete_sql, (self.genie_space_id,))
+                cur.execute(delete_sql, (self.space_id,))
                 deleted: int = cur.rowcount
                 logger.debug(
-                    f"[{self.name}] Cleared {deleted} entries for space {self.genie_space_id}"
+                    f"[{self.name}] Cleared {deleted} entries for space {self.space_id}"
                 )
                 return deleted

@@ -587,7 +868,7 @@ class SemanticCacheService(GenieServiceBase):

         with self._pool.connection() as conn:
             with conn.cursor() as cur:
-                cur.execute(count_sql, (self.genie_space_id,))
+                cur.execute(count_sql, (self.space_id,))
                 row: DbRow | None = cur.fetchone()
                 return row.get("count", 0) if row else 0

@@ -605,7 +886,7 @@ class SemanticCacheService(GenieServiceBase):
         """
         with self._pool.connection() as conn:
             with conn.cursor() as cur:
-                cur.execute(count_sql, (self.genie_space_id,))
+                cur.execute(count_sql, (self.space_id,))
                 row: DbRow | None = cur.fetchone()
                 total = row.get("total", 0) if row else 0
                 return {

@@ -627,12 +908,12 @@ class SemanticCacheService(GenieServiceBase):

         with self._pool.connection() as conn:
             with conn.cursor() as cur:
-                cur.execute(stats_sql, (ttl_seconds, ttl_seconds, self.genie_space_id))
-
+                cur.execute(stats_sql, (ttl_seconds, ttl_seconds, self.space_id))
+                stats_row: DbRow | None = cur.fetchone()
                 return {
-                    "size":
+                    "size": stats_row.get("total", 0) if stats_row else 0,
                     "ttl_seconds": ttl.total_seconds() if ttl else None,
                     "similarity_threshold": self.similarity_threshold,
-                    "expired_entries":
-                    "valid_entries":
+                    "expired_entries": stats_row.get("expired", 0) if stats_row else 0,
+                    "valid_entries": stats_row.get("valid", 0) if stats_row else 0,
                 }
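The maintenance helpers close out the class; they are unchanged in shape but now key off the delegated `space_id`. A sketch of periodic upkeep (the stats keys are the ones shown in the final hunk; the exact name of the stats accessor is not visible in this diff):

```python
# Periodic maintenance: drop rows past the TTL for this Genie space.
deleted: int = genie.invalidate_expired()
print(f"Removed {deleted} expired entries for space {genie.space_id}")

# The stats query above yields a summary dict of this shape:
# {
#     "size": 42,                    # total rows for this space
#     "ttl_seconds": 3600.0,         # None when TTL is disabled
#     "similarity_threshold": 0.85,  # from the cache parameters
#     "expired_entries": 3,
#     "valid_entries": 39,
# }
```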