dao-ai 0.1.2__py3-none-any.whl → 0.1.20__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dao_ai/apps/__init__.py +24 -0
- dao_ai/apps/handlers.py +105 -0
- dao_ai/apps/model_serving.py +29 -0
- dao_ai/apps/resources.py +1122 -0
- dao_ai/apps/server.py +39 -0
- dao_ai/cli.py +546 -37
- dao_ai/config.py +1179 -139
- dao_ai/evaluation.py +543 -0
- dao_ai/genie/__init__.py +55 -7
- dao_ai/genie/cache/__init__.py +34 -7
- dao_ai/genie/cache/base.py +143 -2
- dao_ai/genie/cache/context_aware/__init__.py +31 -0
- dao_ai/genie/cache/context_aware/base.py +1151 -0
- dao_ai/genie/cache/context_aware/in_memory.py +609 -0
- dao_ai/genie/cache/context_aware/persistent.py +802 -0
- dao_ai/genie/cache/context_aware/postgres.py +1166 -0
- dao_ai/genie/cache/core.py +1 -1
- dao_ai/genie/cache/lru.py +257 -75
- dao_ai/genie/cache/optimization.py +890 -0
- dao_ai/genie/core.py +235 -11
- dao_ai/memory/postgres.py +175 -39
- dao_ai/middleware/__init__.py +38 -0
- dao_ai/middleware/assertions.py +3 -3
- dao_ai/middleware/context_editing.py +230 -0
- dao_ai/middleware/core.py +4 -4
- dao_ai/middleware/guardrails.py +3 -3
- dao_ai/middleware/human_in_the_loop.py +3 -2
- dao_ai/middleware/message_validation.py +4 -4
- dao_ai/middleware/model_call_limit.py +77 -0
- dao_ai/middleware/model_retry.py +121 -0
- dao_ai/middleware/pii.py +157 -0
- dao_ai/middleware/summarization.py +1 -1
- dao_ai/middleware/tool_call_limit.py +210 -0
- dao_ai/middleware/tool_retry.py +174 -0
- dao_ai/middleware/tool_selector.py +129 -0
- dao_ai/models.py +327 -370
- dao_ai/nodes.py +9 -16
- dao_ai/orchestration/core.py +33 -9
- dao_ai/orchestration/supervisor.py +29 -13
- dao_ai/orchestration/swarm.py +6 -1
- dao_ai/{prompts.py → prompts/__init__.py} +12 -61
- dao_ai/prompts/instructed_retriever_decomposition.yaml +58 -0
- dao_ai/prompts/instruction_reranker.yaml +14 -0
- dao_ai/prompts/router.yaml +37 -0
- dao_ai/prompts/verifier.yaml +46 -0
- dao_ai/providers/base.py +28 -2
- dao_ai/providers/databricks.py +363 -33
- dao_ai/state.py +1 -0
- dao_ai/tools/__init__.py +5 -3
- dao_ai/tools/genie.py +103 -26
- dao_ai/tools/instructed_retriever.py +366 -0
- dao_ai/tools/instruction_reranker.py +202 -0
- dao_ai/tools/mcp.py +539 -97
- dao_ai/tools/router.py +89 -0
- dao_ai/tools/slack.py +13 -2
- dao_ai/tools/sql.py +7 -3
- dao_ai/tools/unity_catalog.py +32 -10
- dao_ai/tools/vector_search.py +493 -160
- dao_ai/tools/verifier.py +159 -0
- dao_ai/utils.py +182 -2
- dao_ai/vector_search.py +46 -1
- {dao_ai-0.1.2.dist-info → dao_ai-0.1.20.dist-info}/METADATA +45 -9
- dao_ai-0.1.20.dist-info/RECORD +89 -0
- dao_ai/agent_as_code.py +0 -22
- dao_ai/genie/cache/semantic.py +0 -970
- dao_ai-0.1.2.dist-info/RECORD +0 -64
- {dao_ai-0.1.2.dist-info → dao_ai-0.1.20.dist-info}/WHEEL +0 -0
- {dao_ai-0.1.2.dist-info → dao_ai-0.1.20.dist-info}/entry_points.txt +0 -0
- {dao_ai-0.1.2.dist-info → dao_ai-0.1.20.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,1151 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Abstract base class for context-aware Genie cache implementations.
|
|
3
|
+
|
|
4
|
+
This module provides the foundational abstract base class for all context-aware
|
|
5
|
+
cache implementations. It extracts common code for:
|
|
6
|
+
- Dual embedding generation (question + conversation context)
|
|
7
|
+
- Ask question flow with error handling and graceful fallback
|
|
8
|
+
- SQL execution with retry logic
|
|
9
|
+
- Common properties and initialization patterns
|
|
10
|
+
|
|
11
|
+
Subclasses must implement storage-specific methods:
|
|
12
|
+
- _find_similar(): Find semantically similar cached entry
|
|
13
|
+
- _store_entry(): Store new cache entry
|
|
14
|
+
- _setup(): Initialize resources (embeddings, storage)
|
|
15
|
+
- _delete_expired_entries(): Remove expired entries (called by invalidate_expired())
|
|
16
|
+
- _delete_all_entries(): Clear all entries for space (called by clear())
|
|
17
|
+
- _count_all_entries() / _count_entries_with_ttl(): Entry counts used by stats()
|
|
18
|
+
"""
|
|
19
|
+
|
|
20
|
+
from __future__ import annotations
|
|
21
|
+
|
|
22
|
+
from abc import abstractmethod
|
|
23
|
+
from datetime import timedelta
|
|
24
|
+
from typing import Any, Self, TypeVar
|
|
25
|
+
|
|
26
|
+
import mlflow
|
|
27
|
+
import pandas as pd
|
|
28
|
+
from databricks.sdk import WorkspaceClient
|
|
29
|
+
from databricks.sdk.service.dashboards import (
|
|
30
|
+
GenieFeedbackRating,
|
|
31
|
+
GenieListConversationMessagesResponse,
|
|
32
|
+
GenieMessage,
|
|
33
|
+
)
|
|
34
|
+
from databricks_ai_bridge.genie import GenieResponse
|
|
35
|
+
from loguru import logger
|
|
36
|
+
|
|
37
|
+
from dao_ai.config import LLMModel, WarehouseModel
|
|
38
|
+
from dao_ai.genie.cache.base import (
|
|
39
|
+
CacheResult,
|
|
40
|
+
GenieServiceBase,
|
|
41
|
+
SQLCacheEntry,
|
|
42
|
+
)
|
|
43
|
+
from dao_ai.genie.cache.core import execute_sql_via_warehouse
|
|
44
|
+
|
|
45
|
+
# Type variable for subclass return types
|
|
46
|
+
T = TypeVar("T", bound="ContextAwareGenieService")
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def get_conversation_history(
    workspace_client: WorkspaceClient,
    space_id: str,
    conversation_id: str,
    max_messages: int = 10,
) -> list[GenieMessage]:
    """
    Retrieve the trailing messages of a Genie conversation.

    The lookup is best-effort: any exception raised while calling the Genie
    API is logged as a warning and an empty list is returned instead.

    Args:
        workspace_client: The Databricks workspace client
        space_id: The Genie space ID
        conversation_id: The conversation ID to retrieve
        max_messages: Maximum number of messages to return. Zero or a
            negative value yields an empty list without calling the API.

    Returns:
        At most max_messages GenieMessage objects taken from the end of the
        list returned by the API (assumed to be the most recent ones), or an
        empty list when there are no messages or the request fails.
    """
    # A non-positive limit must never reach the slice below: ``[-0:]`` would
    # return every message and a negative limit would drop the oldest ones.
    if max_messages <= 0:
        return []

    try:
        response: GenieListConversationMessagesResponse = (
            workspace_client.genie.list_conversation_messages(
                space_id=space_id,
                conversation_id=conversation_id,
            )
        )

        if response.messages is None:
            return []

        # Slicing from the end keeps the tail and also copes with lists that
        # are shorter than the limit (the whole list is returned).
        return list(response.messages)[-max_messages:]
    except Exception as e:
        logger.warning(
            "Failed to retrieve conversation history",
            conversation_id=conversation_id,
            error=str(e),
        )
        return []
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def build_context_string(
    question: str,
    conversation_messages: list[GenieMessage],
    window_size: int,
    max_tokens: int = 2000,
) -> str:
    """
    Combine recent conversation turns and the current question into one string.

    Each retained message becomes a "Previous: ..." line and the question is
    appended as a final "Current: ..." line; lines are joined with newlines.

    Args:
        question: The current question
        conversation_messages: Earlier messages of the conversation
        window_size: How many trailing messages to consider
        max_tokens: Upper bound on estimated tokens (estimated as chars / 4)

    Returns:
        The question unchanged when window_size is not positive or there are
        no messages; otherwise the combined string, cut to max_tokens * 4
        characters plus "..." when the estimate exceeds max_tokens.
    """
    if window_size <= 0 or not conversation_messages:
        return question

    # A negative-start slice yields the whole list when it is shorter than
    # the window, so no explicit length comparison is required.
    window = conversation_messages[-window_size:]

    lines: list[str] = []
    for message in window:
        text = message.content
        if not text:
            # Messages without content contribute nothing to the context.
            continue
        if len(text) > 500:
            # Cap individual messages so one long turn cannot dominate.
            text = text[:500] + "..."
        lines.append(f"Previous: {text}")
    lines.append(f"Current: {question}")

    combined = "\n".join(lines)

    # Rough estimate: four characters per token.
    if len(combined) / 4 > max_tokens:
        char_budget = max_tokens * 4
        full_length = len(combined)
        combined = combined[:char_budget] + "..."
        logger.trace(
            "Truncated context string",
            original_chars=full_length,
            target_chars=char_budget,
            max_tokens=max_tokens,
        )

    return combined
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
class ContextAwareGenieService(GenieServiceBase):
|
|
161
|
+
"""
|
|
162
|
+
Abstract base class for context-aware Genie cache implementations.
|
|
163
|
+
|
|
164
|
+
This class provides shared implementation for:
|
|
165
|
+
- Dual embedding generation (question + conversation context)
|
|
166
|
+
- Main ask_question flow with error handling
|
|
167
|
+
- SQL execution with warehouse
|
|
168
|
+
- Common properties (time_to_live, similarity_threshold, etc.)
|
|
169
|
+
|
|
170
|
+
Subclasses must implement storage-specific methods for finding similar
|
|
171
|
+
entries, storing new entries, and managing cache lifecycle.
|
|
172
|
+
|
|
173
|
+
Error Handling:
|
|
174
|
+
All cache operations are wrapped in try/except to ensure graceful
|
|
175
|
+
degradation. If any cache operation fails, the request is delegated
|
|
176
|
+
to the underlying service without caching.
|
|
177
|
+
|
|
178
|
+
Thread Safety:
|
|
179
|
+
Subclasses are responsible for thread safety of storage operations.
|
|
180
|
+
This base class does not provide synchronization primitives.
|
|
181
|
+
"""
|
|
182
|
+
|
|
183
|
+
# Common attributes - subclasses should define these
|
|
184
|
+
impl: GenieServiceBase
|
|
185
|
+
_workspace_client: WorkspaceClient | None
|
|
186
|
+
name: str
|
|
187
|
+
_embeddings: Any # DatabricksEmbeddings
|
|
188
|
+
_embedding_dims: int | None
|
|
189
|
+
_setup_complete: bool
|
|
190
|
+
|
|
191
|
+
# Abstract methods that subclasses must implement
|
|
192
|
+
@abstractmethod
|
|
193
|
+
def _setup(self) -> None:
|
|
194
|
+
"""
|
|
195
|
+
Initialize resources required by the cache implementation.
|
|
196
|
+
|
|
197
|
+
This method is called lazily before first use. Implementations should:
|
|
198
|
+
- Initialize embedding model
|
|
199
|
+
- Set up storage (database connection, in-memory structures, etc.)
|
|
200
|
+
- Create necessary tables/indexes if applicable
|
|
201
|
+
|
|
202
|
+
This method should be idempotent (safe to call multiple times).
|
|
203
|
+
"""
|
|
204
|
+
pass
|
|
205
|
+
|
|
206
|
+
@abstractmethod
|
|
207
|
+
def _find_similar(
|
|
208
|
+
self,
|
|
209
|
+
question: str,
|
|
210
|
+
conversation_context: str,
|
|
211
|
+
question_embedding: list[float],
|
|
212
|
+
context_embedding: list[float],
|
|
213
|
+
conversation_id: str | None = None,
|
|
214
|
+
) -> tuple[SQLCacheEntry, float] | None:
|
|
215
|
+
"""
|
|
216
|
+
Find a semantically similar cached entry using dual embedding matching.
|
|
217
|
+
|
|
218
|
+
Args:
|
|
219
|
+
question: The original question (for logging)
|
|
220
|
+
conversation_context: The conversation context string
|
|
221
|
+
question_embedding: The embedding vector of just the question
|
|
222
|
+
context_embedding: The embedding vector of the conversation context
|
|
223
|
+
conversation_id: Optional conversation ID (for logging)
|
|
224
|
+
|
|
225
|
+
Returns:
|
|
226
|
+
Tuple of (SQLCacheEntry, combined_similarity_score) if found, None otherwise
|
|
227
|
+
"""
|
|
228
|
+
pass
|
|
229
|
+
|
|
230
|
+
@abstractmethod
|
|
231
|
+
def _store_entry(
|
|
232
|
+
self,
|
|
233
|
+
question: str,
|
|
234
|
+
conversation_context: str,
|
|
235
|
+
question_embedding: list[float],
|
|
236
|
+
context_embedding: list[float],
|
|
237
|
+
response: GenieResponse,
|
|
238
|
+
message_id: str | None = None,
|
|
239
|
+
) -> None:
|
|
240
|
+
"""
|
|
241
|
+
Store a new cache entry with dual embeddings.
|
|
242
|
+
|
|
243
|
+
Args:
|
|
244
|
+
question: The user's question
|
|
245
|
+
conversation_context: Previous conversation context string
|
|
246
|
+
question_embedding: Embedding of the question
|
|
247
|
+
context_embedding: Embedding of the conversation context
|
|
248
|
+
response: The GenieResponse containing query, description, etc.
|
|
249
|
+
message_id: The Genie message ID from the original API response.
|
|
250
|
+
Stored with the cache entry to enable feedback on cache hits.
|
|
251
|
+
"""
|
|
252
|
+
pass
|
|
253
|
+
|
|
254
|
+
def invalidate_expired(self) -> int | dict[str, int]:
    """
    Remove expired entries from the cache (template method).

    Runs _setup(), then checks the configured TTL. When the TTL is None or
    negative, expiration is disabled and _get_empty_expiration_result() is
    returned; otherwise the deletion is delegated to
    _delete_expired_entries().

    Returns:
        Number of entries deleted, or dict with counts by category
    """
    self._setup()
    ttl = self.time_to_live_seconds

    expiration_enabled = ttl is not None and ttl >= 0
    if expiration_enabled:
        return self._delete_expired_entries(ttl)

    return self._get_empty_expiration_result()
|
|
271
|
+
|
|
272
|
+
@abstractmethod
|
|
273
|
+
def _delete_expired_entries(self, ttl_seconds: int) -> int | dict[str, int]:
|
|
274
|
+
"""
|
|
275
|
+
Delete expired entries from storage.
|
|
276
|
+
|
|
277
|
+
Args:
|
|
278
|
+
ttl_seconds: TTL in seconds for determining expiration
|
|
279
|
+
|
|
280
|
+
Returns:
|
|
281
|
+
Number of entries deleted, or dict with counts by category
|
|
282
|
+
"""
|
|
283
|
+
pass
|
|
284
|
+
|
|
285
|
+
def _get_empty_expiration_result(self) -> int | dict[str, int]:
|
|
286
|
+
"""
|
|
287
|
+
Return the empty result for invalidate_expired when TTL is disabled.
|
|
288
|
+
|
|
289
|
+
Override this in subclasses that return dict to return appropriate empty dict.
|
|
290
|
+
|
|
291
|
+
Returns:
|
|
292
|
+
0 by default, or empty dict for subclasses that return dict
|
|
293
|
+
"""
|
|
294
|
+
return 0
|
|
295
|
+
|
|
296
|
+
def clear(self) -> int:
    """
    Remove every entry from the cache (template method).

    Ensures the service is set up via _setup() and then hands the actual
    deletion to _delete_all_entries().

    Returns:
        Number of entries deleted
    """
    self._setup()
    deleted_count = self._delete_all_entries()
    return deleted_count
|
|
307
|
+
|
|
308
|
+
@abstractmethod
|
|
309
|
+
def _delete_all_entries(self) -> int:
|
|
310
|
+
"""
|
|
311
|
+
Delete all entries for this Genie space from storage.
|
|
312
|
+
|
|
313
|
+
Returns:
|
|
314
|
+
Number of entries deleted
|
|
315
|
+
"""
|
|
316
|
+
pass
|
|
317
|
+
|
|
318
|
+
def stats(self) -> dict[str, Any]:
    """
    Return cache statistics (template method).

    Runs _setup(), then gathers counts through the abstract counting hooks:
    _count_all_entries() when the TTL is disabled (None or negative) and
    _count_entries_with_ttl() otherwise. Whatever _get_additional_stats()
    returns is merged on top and may override the base keys.

    Returns:
        Dict with the keys "size", "ttl_seconds", "similarity_threshold",
        "expired_entries" and "valid_entries", plus any subclass additions.
        "ttl_seconds" is None when the TTL is disabled and the TTL as a
        float number of seconds otherwise (0.0 for a zero TTL).
    """
    self._setup()
    ttl_seconds = self.time_to_live_seconds
    ttl = self.time_to_live

    if ttl_seconds is None or ttl_seconds < 0:
        # TTL disabled: nothing can be expired, every entry is valid.
        total = self._count_all_entries()
        expired = 0
        reported_ttl = None
    else:
        total, expired = self._count_entries_with_ttl(ttl_seconds)
        # Compare against None explicitly: timedelta(0) is falsy, so a
        # truthiness test would report a zero TTL as "no TTL".
        reported_ttl = ttl.total_seconds() if ttl is not None else None

    base_stats: dict[str, Any] = {
        "size": total,
        "ttl_seconds": reported_ttl,
        "similarity_threshold": self.similarity_threshold,
        "expired_entries": expired,
        "valid_entries": total - expired,
    }

    # Let subclasses contribute implementation-specific statistics.
    base_stats.update(self._get_additional_stats())

    return base_stats
|
|
358
|
+
|
|
359
|
+
@abstractmethod
|
|
360
|
+
def _count_all_entries(self) -> int:
|
|
361
|
+
"""
|
|
362
|
+
Count all cache entries for this Genie space.
|
|
363
|
+
|
|
364
|
+
Returns:
|
|
365
|
+
Total number of cache entries
|
|
366
|
+
"""
|
|
367
|
+
pass
|
|
368
|
+
|
|
369
|
+
@abstractmethod
|
|
370
|
+
def _count_entries_with_ttl(self, ttl_seconds: int) -> tuple[int, int]:
|
|
371
|
+
"""
|
|
372
|
+
Count total and expired entries for this Genie space.
|
|
373
|
+
|
|
374
|
+
Args:
|
|
375
|
+
ttl_seconds: TTL in seconds for determining expiration
|
|
376
|
+
|
|
377
|
+
Returns:
|
|
378
|
+
Tuple of (total_entries, expired_entries)
|
|
379
|
+
"""
|
|
380
|
+
pass
|
|
381
|
+
|
|
382
|
+
def _get_additional_stats(self) -> dict[str, Any]:
|
|
383
|
+
"""
|
|
384
|
+
Hook method to add additional stats from subclasses.
|
|
385
|
+
|
|
386
|
+
Override this method to add subclass-specific statistics like
|
|
387
|
+
capacity (in-memory) or prompt history stats (postgres).
|
|
388
|
+
|
|
389
|
+
Returns:
|
|
390
|
+
Dict with additional stats to merge into base stats
|
|
391
|
+
"""
|
|
392
|
+
return {}
|
|
393
|
+
|
|
394
|
+
# Properties that subclasses should implement or inherit
|
|
395
|
+
@property
@abstractmethod
def warehouse(self) -> WarehouseModel:
    """Warehouse that _execute_sql() passes on to run cached SQL queries."""
|
|
400
|
+
|
|
401
|
+
@property
@abstractmethod
def time_to_live(self) -> timedelta | None:
    """Lifetime of cache entries; None means entries never expire."""
|
|
406
|
+
|
|
407
|
+
@property
@abstractmethod
def similarity_threshold(self) -> float:
    """
    Minimum similarity a cached entry needs to count as a cache hit.

    NOTE(review): the earlier docstring describes the score as an L2
    distance converted to similarity; confirm against the storage
    implementations.
    """
|
|
412
|
+
|
|
413
|
+
@property
def embedding_dims(self) -> int:
    """
    Dimension size of the embedding vectors.

    Raises:
        RuntimeError: If the dimensions have not been set yet, i.e.
            _embedding_dims is still None.
    """
    dims = self._embedding_dims
    if dims is not None:
        return dims

    raise RuntimeError(
        "Embedding dimensions not yet initialized. Call _setup() first."
    )
|
|
421
|
+
|
|
422
|
+
@property
def space_id(self) -> str:
    """Genie space ID, read from the wrapped service (impl)."""
    wrapped = self.impl
    return wrapped.space_id
|
|
426
|
+
|
|
427
|
+
@property
def workspace_client(self) -> WorkspaceClient | None:
    """Workspace client of this service, or the wrapped service's when unset."""
    own_client = self._workspace_client
    if own_client is None:
        # Nothing configured locally: defer to the wrapped service.
        return self.impl.workspace_client
    return own_client
|
|
433
|
+
|
|
434
|
+
@property
def time_to_live_seconds(self) -> int | None:
    """
    TTL expressed in whole seconds.

    Returns None when time_to_live is None; otherwise the total seconds
    converted with int(). Callers in this class treat None or a negative
    value as "never expires".
    """
    ttl = self.time_to_live
    return None if ttl is None else int(ttl.total_seconds())
|
|
441
|
+
|
|
442
|
+
# Abstract method for embedding - subclasses must implement
|
|
443
|
+
@abstractmethod
|
|
444
|
+
def _embed_question(
|
|
445
|
+
self, question: str, conversation_id: str | None = None
|
|
446
|
+
) -> tuple[list[float], list[float], str]:
|
|
447
|
+
"""
|
|
448
|
+
Generate dual embeddings for a question with conversation context.
|
|
449
|
+
|
|
450
|
+
Args:
|
|
451
|
+
question: The question to embed
|
|
452
|
+
conversation_id: Optional conversation ID for retrieving context
|
|
453
|
+
|
|
454
|
+
Returns:
|
|
455
|
+
Tuple of (question_embedding, context_embedding, conversation_context_string)
|
|
456
|
+
"""
|
|
457
|
+
pass
|
|
458
|
+
|
|
459
|
+
# Shared implementation methods
|
|
460
|
+
def initialize(self) -> Self:
    """
    Run the setup eagerly instead of waiting for the first request.

    Intended to be called during tool creation so that:
    - configuration problems surface early (fail fast)
    - resources exist before any request arrives
    - the first request does not pay the lazy-initialization cost

    Returns:
        self, to allow method chaining
    """
    self._setup()
    return self
|
|
474
|
+
|
|
475
|
+
def _initialize_embeddings(
|
|
476
|
+
self,
|
|
477
|
+
embedding_model: str | LLMModel,
|
|
478
|
+
embedding_dims: int | None = None,
|
|
479
|
+
) -> None:
|
|
480
|
+
"""
|
|
481
|
+
Initialize the embeddings model and detect dimensions.
|
|
482
|
+
|
|
483
|
+
This helper method handles embedding model initialization for subclasses.
|
|
484
|
+
|
|
485
|
+
Args:
|
|
486
|
+
embedding_model: The embedding model name or LLMModel instance
|
|
487
|
+
embedding_dims: Optional pre-configured embedding dimensions
|
|
488
|
+
"""
|
|
489
|
+
# Convert embedding_model to LLMModel if it's a string
|
|
490
|
+
model: LLMModel = (
|
|
491
|
+
LLMModel(name=embedding_model)
|
|
492
|
+
if isinstance(embedding_model, str)
|
|
493
|
+
else embedding_model
|
|
494
|
+
)
|
|
495
|
+
self._embeddings = model.as_embeddings_model()
|
|
496
|
+
|
|
497
|
+
# Auto-detect embedding dimensions if not provided
|
|
498
|
+
if embedding_dims is None:
|
|
499
|
+
sample_embedding: list[float] = self._embeddings.embed_query("test")
|
|
500
|
+
self._embedding_dims = len(sample_embedding)
|
|
501
|
+
logger.debug(
|
|
502
|
+
"Auto-detected embedding dimensions",
|
|
503
|
+
layer=self.name,
|
|
504
|
+
dims=self._embedding_dims,
|
|
505
|
+
)
|
|
506
|
+
else:
|
|
507
|
+
self._embedding_dims = embedding_dims
|
|
508
|
+
|
|
509
|
+
def _embed_question_with_genie_history(
|
|
510
|
+
self,
|
|
511
|
+
question: str,
|
|
512
|
+
conversation_id: str | None,
|
|
513
|
+
context_window_size: int,
|
|
514
|
+
max_context_tokens: int,
|
|
515
|
+
) -> tuple[list[float], list[float], str]:
|
|
516
|
+
"""
|
|
517
|
+
Generate dual embeddings using Genie API for conversation history.
|
|
518
|
+
|
|
519
|
+
This method retrieves conversation history from the Genie API and
|
|
520
|
+
generates dual embeddings for semantic matching.
|
|
521
|
+
|
|
522
|
+
Args:
|
|
523
|
+
question: The question to embed
|
|
524
|
+
conversation_id: Optional conversation ID for retrieving context
|
|
525
|
+
context_window_size: Number of previous messages to include
|
|
526
|
+
max_context_tokens: Maximum tokens for context string
|
|
527
|
+
|
|
528
|
+
Returns:
|
|
529
|
+
Tuple of (question_embedding, context_embedding, conversation_context_string)
|
|
530
|
+
"""
|
|
531
|
+
conversation_context = ""
|
|
532
|
+
|
|
533
|
+
# If conversation context is enabled and available
|
|
534
|
+
if (
|
|
535
|
+
self.workspace_client is not None
|
|
536
|
+
and conversation_id is not None
|
|
537
|
+
and context_window_size > 0
|
|
538
|
+
):
|
|
539
|
+
try:
|
|
540
|
+
# Retrieve conversation history from Genie API
|
|
541
|
+
conversation_messages = get_conversation_history(
|
|
542
|
+
workspace_client=self.workspace_client,
|
|
543
|
+
space_id=self.space_id,
|
|
544
|
+
conversation_id=conversation_id,
|
|
545
|
+
max_messages=context_window_size * 2, # Get extra for safety
|
|
546
|
+
)
|
|
547
|
+
|
|
548
|
+
# Build context string
|
|
549
|
+
if conversation_messages:
|
|
550
|
+
recent_messages = (
|
|
551
|
+
conversation_messages[-context_window_size:]
|
|
552
|
+
if len(conversation_messages) > context_window_size
|
|
553
|
+
else conversation_messages
|
|
554
|
+
)
|
|
555
|
+
|
|
556
|
+
context_parts: list[str] = []
|
|
557
|
+
for msg in recent_messages:
|
|
558
|
+
if msg.content:
|
|
559
|
+
content: str = msg.content
|
|
560
|
+
if len(content) > 500:
|
|
561
|
+
content = content[:500] + "..."
|
|
562
|
+
context_parts.append(f"Previous: {content}")
|
|
563
|
+
|
|
564
|
+
conversation_context = "\n".join(context_parts)
|
|
565
|
+
|
|
566
|
+
# Truncate if too long
|
|
567
|
+
estimated_tokens = len(conversation_context) / 4
|
|
568
|
+
if estimated_tokens > max_context_tokens:
|
|
569
|
+
target_chars = max_context_tokens * 4
|
|
570
|
+
conversation_context = (
|
|
571
|
+
conversation_context[:target_chars] + "..."
|
|
572
|
+
)
|
|
573
|
+
|
|
574
|
+
logger.trace(
|
|
575
|
+
"Using conversation context from Genie API",
|
|
576
|
+
layer=self.name,
|
|
577
|
+
messages_count=len(conversation_messages),
|
|
578
|
+
window_size=context_window_size,
|
|
579
|
+
)
|
|
580
|
+
except Exception as e:
|
|
581
|
+
logger.warning(
|
|
582
|
+
"Failed to build conversation context, using question only",
|
|
583
|
+
layer=self.name,
|
|
584
|
+
error=str(e),
|
|
585
|
+
)
|
|
586
|
+
conversation_context = ""
|
|
587
|
+
|
|
588
|
+
return self._generate_dual_embeddings(question, conversation_context)
|
|
589
|
+
|
|
590
|
+
def _generate_dual_embeddings(
|
|
591
|
+
self, question: str, conversation_context: str
|
|
592
|
+
) -> tuple[list[float], list[float], str]:
|
|
593
|
+
"""
|
|
594
|
+
Generate dual embeddings for question and conversation context.
|
|
595
|
+
|
|
596
|
+
Args:
|
|
597
|
+
question: The question to embed
|
|
598
|
+
conversation_context: The conversation context string
|
|
599
|
+
|
|
600
|
+
Returns:
|
|
601
|
+
Tuple of (question_embedding, context_embedding, conversation_context)
|
|
602
|
+
"""
|
|
603
|
+
if conversation_context:
|
|
604
|
+
# Embed both question and context
|
|
605
|
+
embeddings: list[list[float]] = self._embeddings.embed_documents(
|
|
606
|
+
[question, conversation_context]
|
|
607
|
+
)
|
|
608
|
+
question_embedding = embeddings[0]
|
|
609
|
+
context_embedding = embeddings[1]
|
|
610
|
+
else:
|
|
611
|
+
# Only embed question, use zero vector for context
|
|
612
|
+
embeddings = self._embeddings.embed_documents([question])
|
|
613
|
+
question_embedding = embeddings[0]
|
|
614
|
+
context_embedding = [0.0] * len(question_embedding) # Zero vector
|
|
615
|
+
|
|
616
|
+
return question_embedding, context_embedding, conversation_context
|
|
617
|
+
|
|
618
|
+
@mlflow.trace(name="execute_cached_sql")
def _execute_sql(self, sql: str) -> pd.DataFrame | str:
    """
    Run a SQL statement on this service's warehouse.

    Thin wrapper around execute_sql_via_warehouse(), traced in MLflow under
    the name "execute_cached_sql".

    Args:
        sql: The SQL query to execute

    Returns:
        Whatever execute_sql_via_warehouse() returns; per the annotation, a
        DataFrame with results or an error message string.
    """
    outcome = execute_sql_via_warehouse(
        warehouse=self.warehouse,
        sql=sql,
        layer_name=self.name,
    )
    return outcome
|
|
634
|
+
|
|
635
|
+
def _build_cache_hit_response(
    self,
    cached: SQLCacheEntry,
    result: pd.DataFrame,
    conversation_id: str | None,
) -> CacheResult:
    """
    Assemble the CacheResult returned for a cache hit.

    Args:
        cached: The cached SQL entry
        result: The fresh DataFrame from SQL execution
        conversation_id: The current conversation ID

    Returns:
        CacheResult with cache_hit=True, served_by set to this layer's name,
        and the cached entry's message_id and cache_entry_id (kept for
        feedback support and traceability).
    """
    # The request's conversation_id takes precedence over the cached one so
    # the ongoing conversation continues; the cached ID is only a fallback.
    return CacheResult(
        response=GenieResponse(
            result=result,
            query=cached.query,
            description=cached.description,
            conversation_id=conversation_id or cached.conversation_id,
        ),
        cache_hit=True,
        served_by=self.name,
        message_id=cached.message_id,
        cache_entry_id=cached.cache_entry_id,
    )
|
|
672
|
+
|
|
673
|
+
def ask_question(
    self, question: str, conversation_id: str | None = None
) -> CacheResult:
    """
    Ask a question, using semantic cache if a similar query exists.

    Delegates to ask_question_with_cache_info() and degrades gracefully:
    if that call raises, the failure is logged together with its traceback
    and the question is forwarded to the wrapped service (impl) instead.

    Args:
        question: The question to ask
        conversation_id: Optional conversation ID for context

    Returns:
        The CacheResult from the cache-aware path, or whatever
        impl.ask_question() returns when the cache path fails
    """
    try:
        return self.ask_question_with_cache_info(question, conversation_id)
    except Exception as e:
        # loguru does not understand the stdlib ``exc_info`` keyword (it
        # would only become an extra field); the traceback has to be
        # attached through ``opt(exception=...)``.
        logger.opt(exception=e).warning(
            "Context-aware cache operation failed, delegating to underlying service",
            layer=self.name,
            error=str(e),
        )
        # Graceful degradation: fall back to underlying service
        return self.impl.ask_question(question, conversation_id)
|
|
703
|
+
|
|
704
|
+
def ask_question_with_cache_info(
    self,
    question: str,
    conversation_id: str | None = None,
) -> CacheResult:
    """
    Answer a question through the cache (template method).

    Fixed sequence of steps:
    1. _setup()
    2. _embed_question() to obtain both embeddings and the context string
    3. _before_cache_lookup() hook
    4. _find_similar() to search the cache
    5. On a match: _handle_cache_hit() followed by the _after_cache_hit() hook
       On no match: _handle_cache_miss() followed by the _after_cache_miss() hook

    Subclasses customize behavior by overriding the hook methods.

    Args:
        question: The question to ask
        conversation_id: Optional conversation ID for context and continuation

    Returns:
        The CacheResult produced by the hit or miss handler
    """
    self._setup()

    question_embedding, context_embedding, conversation_context = (
        self._embed_question(question, conversation_id)
    )

    # Pre-lookup hook (e.g., store prompt in history).
    self._before_cache_lookup(question, conversation_id)

    match = self._find_similar(
        question,
        conversation_context,
        question_embedding,
        context_embedding,
        conversation_id,
    )

    if match is None:
        # Cache miss: let the miss handler produce the result.
        outcome = self._handle_cache_miss(
            question,
            conversation_id,
            conversation_context,
            question_embedding,
            context_embedding,
        )
        self._after_cache_miss(question, conversation_id, outcome)
        return outcome

    # Cache hit: the match carries the entry and its combined similarity.
    cached, combined_similarity = match
    outcome = self._handle_cache_hit(
        question,
        conversation_id,
        cached,
        combined_similarity,
        conversation_context,
        question_embedding,
        context_embedding,
    )
    self._after_cache_hit(question, conversation_id, outcome)
    return outcome
|
|
776
|
+
|
|
777
|
+
def _before_cache_lookup(self, question: str, conversation_id: str | None) -> None:
|
|
778
|
+
"""
|
|
779
|
+
Hook method called before cache lookup.
|
|
780
|
+
|
|
781
|
+
Override this method to perform actions before searching the cache,
|
|
782
|
+
such as storing the prompt in history.
|
|
783
|
+
|
|
784
|
+
Args:
|
|
785
|
+
question: The question being asked
|
|
786
|
+
conversation_id: Optional conversation ID
|
|
787
|
+
"""
|
|
788
|
+
pass
|
|
789
|
+
|
|
790
|
+
def _after_cache_hit(
|
|
791
|
+
self,
|
|
792
|
+
question: str,
|
|
793
|
+
conversation_id: str | None,
|
|
794
|
+
result: CacheResult,
|
|
795
|
+
) -> None:
|
|
796
|
+
"""
|
|
797
|
+
Hook method called after a cache hit.
|
|
798
|
+
|
|
799
|
+
Override this method to perform actions after a successful cache hit,
|
|
800
|
+
such as updating prompt history flags.
|
|
801
|
+
|
|
802
|
+
Args:
|
|
803
|
+
question: The question that was asked
|
|
804
|
+
conversation_id: Optional conversation ID
|
|
805
|
+
result: The cache result
|
|
806
|
+
"""
|
|
807
|
+
pass
|
|
808
|
+
|
|
809
|
+
def _after_cache_miss(
|
|
810
|
+
self,
|
|
811
|
+
question: str,
|
|
812
|
+
conversation_id: str | None,
|
|
813
|
+
result: CacheResult,
|
|
814
|
+
) -> None:
|
|
815
|
+
"""
|
|
816
|
+
Hook method called after a cache miss.
|
|
817
|
+
|
|
818
|
+
Override this method to perform actions after a cache miss,
|
|
819
|
+
such as storing prompt history if not done earlier.
|
|
820
|
+
|
|
821
|
+
Args:
|
|
822
|
+
question: The question that was asked
|
|
823
|
+
conversation_id: Optional conversation ID
|
|
824
|
+
result: The cache result
|
|
825
|
+
"""
|
|
826
|
+
pass
|
|
827
|
+
|
|
828
|
+
def _handle_cache_hit(
    self,
    question: str,
    conversation_id: str | None,
    cached: SQLCacheEntry,
    combined_similarity: float,
    conversation_context: str,
    question_embedding: list[float],
    context_embedding: list[float],
) -> CacheResult:
    """
    Serve a cache hit by re-running the cached SQL, with a stale-entry fallback.

    The cached query is executed again so the caller receives fresh data.
    When execution reports an error (a string instead of a DataFrame) the
    entry is treated as stale: the stale-entry hook is fired, the question
    is delegated to the wrapped service, and any SQL it returns is stored
    as a new cache entry.

    Args:
        question: The original question
        conversation_id: The conversation ID
        cached: The cached SQL entry
        combined_similarity: The similarity score
        conversation_context: The conversation context string
        question_embedding: The question embedding
        context_embedding: The context embedding

    Returns:
        CacheResult built from the cached entry on success, or a
        CacheResult flagged ``cache_hit=False`` carrying the wrapped
        service's response when the cached SQL could not be executed
    """
    logger.debug(
        "Cache hit",
        layer=self.name,
        combined_similarity=f"{combined_similarity:.3f}",
        question=question[:50],
        conversation_id=conversation_id,
    )

    # Run the stored query again so the data is current.
    sql_outcome: pd.DataFrame | str = self._execute_sql(cached.query)

    # Happy path: anything that is not an error string is served as a hit.
    if not isinstance(sql_outcome, str):
        return self._build_cache_hit_response(cached, sql_outcome, conversation_id)

    # From here on the cached SQL is considered stale.
    logger.warning(
        "Cached SQL execution failed, falling back to Genie",
        layer=self.name,
        question=question[:80],
        conversation_id=conversation_id,
        cached_sql=cached.query[:80],
        error=sql_outcome[:200],
        space_id=self.space_id,
    )

    # Give the storage backend a chance to drop the stale entry.
    self._on_stale_cache_entry(question)

    logger.info(
        "Delegating to Genie for fresh SQL",
        layer=self.name,
        question=question[:80],
        conversation_id=conversation_id,
        space_id=self.space_id,
        delegating_to=type(self.impl).__name__,
    )
    fresh: CacheResult = self.impl.ask_question(question, conversation_id)

    if not fresh.response.query:
        # Without SQL there is nothing that could be cached.
        logger.warning(
            "Fallback response has no SQL query to cache",
            layer=self.name,
            question=question[:80],
            space_id=self.space_id,
        )
    else:
        # Replace the stale SQL with the freshly generated one.
        self._store_entry(
            question,
            conversation_context,
            question_embedding,
            context_embedding,
            fresh.response,
            message_id=fresh.message_id,
        )
        logger.info(
            "Stored fresh SQL from fallback",
            layer=self.name,
            fresh_sql=fresh.response.query[:80],
            space_id=self.space_id,
            message_id=fresh.message_id,
        )

    # Reported as a miss; the message_id of the fallback call is propagated.
    return CacheResult(
        response=fresh.response,
        cache_hit=False,
        served_by=None,
        message_id=fresh.message_id,
    )
|
|
931
|
+
|
|
932
|
+
def _on_stale_cache_entry(self, question: str) -> None:
|
|
933
|
+
"""
|
|
934
|
+
Called when a stale cache entry is detected (SQL execution failed).
|
|
935
|
+
|
|
936
|
+
Subclasses can override this to clean up the stale entry from storage.
|
|
937
|
+
|
|
938
|
+
Args:
|
|
939
|
+
question: The question that had a stale cache entry
|
|
940
|
+
"""
|
|
941
|
+
# Default implementation does nothing - subclasses should override
|
|
942
|
+
pass
|
|
943
|
+
|
|
944
|
+
def _handle_cache_miss(
    self,
    question: str,
    conversation_id: str | None,
    conversation_context: str,
    question_embedding: list[float],
    context_embedding: list[float],
) -> CacheResult:
    """
    Resolve a cache miss by asking the wrapped service and caching its SQL.

    The question is forwarded to ``self.impl``. If the response carries a
    SQL query, a new cache entry is stored for it; otherwise nothing is
    cached and a warning is logged.

    Args:
        question: The original question
        conversation_id: The conversation ID
        conversation_context: The conversation context string
        question_embedding: The question embedding
        context_embedding: The context embedding

    Returns:
        CacheResult flagged ``cache_hit=False`` that carries the wrapped
        service's response and message_id
    """
    logger.info(
        "Cache MISS",
        layer=self.name,
        question=question[:80],
        conversation_id=conversation_id,
        space_id=self.space_id,
        similarity_threshold=self.similarity_threshold,
        delegating_to=type(self.impl).__name__,
    )

    upstream: CacheResult = self.impl.ask_question(question, conversation_id)

    if not upstream.response.query:
        # Only responses that contain SQL are cacheable.
        logger.warning(
            "Not caching: response has no SQL query",
            layer=self.name,
            question=question[:50],
        )
    else:
        logger.debug(
            "Storing new cache entry",
            layer=self.name,
            question=question[:50],
            conversation_id=conversation_id,
            space=self.space_id,
            message_id=upstream.message_id,
        )
        self._store_entry(
            question,
            conversation_context,
            question_embedding,
            context_embedding,
            upstream.response,
            message_id=upstream.message_id,
        )

    # Re-wrap so the miss flags are explicit; message_id is passed through.
    return CacheResult(
        response=upstream.response,
        cache_hit=False,
        served_by=None,
        message_id=upstream.message_id,
    )
|
|
1009
|
+
|
|
1010
|
+
@abstractmethod
|
|
1011
|
+
def _invalidate_by_question(self, question: str) -> bool:
|
|
1012
|
+
"""
|
|
1013
|
+
Invalidate cache entries matching a specific question.
|
|
1014
|
+
|
|
1015
|
+
This method is called when negative feedback is received to remove
|
|
1016
|
+
the corresponding cache entry.
|
|
1017
|
+
|
|
1018
|
+
Args:
|
|
1019
|
+
question: The question text to match and invalidate
|
|
1020
|
+
|
|
1021
|
+
Returns:
|
|
1022
|
+
True if an entry was found and invalidated, False otherwise
|
|
1023
|
+
"""
|
|
1024
|
+
pass
|
|
1025
|
+
|
|
1026
|
+
@mlflow.trace(name="genie_context_aware_send_feedback")
def send_feedback(
    self,
    conversation_id: str,
    rating: GenieFeedbackRating,
    message_id: str | None = None,
    was_cache_hit: bool = False,
) -> None:
    """
    Send feedback for a Genie message with cache invalidation.

    For context-aware caches, this method runs two steps in this order:
        1. If rating is NEGATIVE: invalidates any matching cache entries
        2. If was_cache_hit is False: forwards feedback to the underlying
           service; if it is True, forwarding is skipped

    Args:
        conversation_id: The conversation containing the message
        rating: The feedback rating (POSITIVE, NEGATIVE, or NONE)
        message_id: Optional message ID. If None, looks up the most recent message.
        was_cache_hit: Whether the response being rated was served from cache.

    Note:
        For cached responses (was_cache_hit=True), only cache invalidation is
        performed. No feedback is sent to the Genie API because cached responses
        don't have a corresponding Genie message.

        Future Enhancement: To enable full Genie feedback for cached responses,
        the cache would need to store the original message_id. This would require:
        1. Adding message_id column to cache tables
        2. Adding message_id field to SQLCacheEntry dataclass
        3. Capturing message_id from the original Genie API response
           (databricks_ai_bridge.genie.GenieResponse doesn't expose this)
        4. Using WorkspaceClient directly instead of databricks_ai_bridge
    """
    # NOTE(review): other methods of this class already pass a message_id to
    # `_store_entry`; confirm whether the "Future Enhancement" note above is
    # still current.

    # Step 1: negative feedback drops the matching cache entry, if any.
    invalidated = False
    if rating == GenieFeedbackRating.NEGATIVE:
        invalidated = self._invalidate_for_negative_feedback(
            conversation_id, message_id
        )

    # Step 2: a cached response has no Genie message to rate, so stop here.
    if was_cache_hit:
        logger.info(
            "Skipping Genie API feedback - response was served from cache",
            layer=self.name,
            conversation_id=conversation_id,
            rating=rating.value if rating else None,
            cache_invalidated=invalidated,
        )
        return

    # Forward to underlying service
    logger.debug(
        "Forwarding feedback to underlying service",
        layer=self.name,
        conversation_id=conversation_id,
        rating=rating.value if rating else None,
        delegating_to=type(self.impl).__name__,
    )
    self.impl.send_feedback(
        conversation_id=conversation_id,
        rating=rating,
        message_id=message_id,
        was_cache_hit=False,  # Already handled, so pass False
    )


def _invalidate_for_negative_feedback(
    self,
    conversation_id: str,
    message_id: str | None,
) -> bool:
    """
    Invalidate the cache entry for the question behind a negatively rated message.

    The message's question text is looked up through the workspace client
    and handed to ``_invalidate_by_question``. Each precondition that is
    not met logs a warning and ends the attempt without invalidating.

    Args:
        conversation_id: The conversation containing the message
        message_id: Optional message ID. If None, the most recent message
            of the conversation is looked up.

    Returns:
        The value returned by ``_invalidate_by_question`` when a lookup was
        possible, False when no invalidation could be attempted
    """
    if self.workspace_client is None:
        logger.warning(
            "No workspace_client available for cache invalidation",
            layer=self.name,
            conversation_id=conversation_id,
        )
        return False

    # Imported lazily, only once a lookup is actually going to happen.
    from dao_ai.genie.cache.base import (
        get_latest_message_id,
        get_message_content,
    )

    # Get message_id if not provided
    target_message_id = message_id
    if target_message_id is None:
        target_message_id = get_latest_message_id(
            workspace_client=self.workspace_client,
            space_id=self.space_id,
            conversation_id=conversation_id,
        )

    if not target_message_id:
        logger.warning(
            "Could not find message_id for cache invalidation",
            layer=self.name,
            conversation_id=conversation_id,
        )
        return False

    # Get the message content (question) to find matching cache entries
    question = get_message_content(
        workspace_client=self.workspace_client,
        space_id=self.space_id,
        conversation_id=conversation_id,
        message_id=target_message_id,
    )

    if not question:
        logger.warning(
            "Could not retrieve message content for cache invalidation",
            layer=self.name,
            conversation_id=conversation_id,
            message_id=target_message_id,
        )
        return False

    invalidated = self._invalidate_by_question(question)
    if invalidated:
        logger.info(
            "Invalidated cache entry due to negative feedback",
            layer=self.name,
            question=question[:80],
            conversation_id=conversation_id,
            message_id=target_message_id,
        )
    else:
        logger.debug(
            "No cache entry found to invalidate for negative feedback",
            layer=self.name,
            question=question[:80],
            conversation_id=conversation_id,
        )
    return invalidated
|