dao-ai 0.1.19__py3-none-any.whl → 0.1.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dao_ai/cli.py +37 -7
- dao_ai/config.py +265 -10
- dao_ai/genie/__init__.py +55 -7
- dao_ai/genie/cache/__init__.py +36 -9
- dao_ai/genie/cache/base.py +143 -2
- dao_ai/genie/cache/context_aware/__init__.py +52 -0
- dao_ai/genie/cache/context_aware/base.py +1204 -0
- dao_ai/genie/cache/{in_memory_semantic.py → context_aware/in_memory.py} +233 -383
- dao_ai/genie/cache/context_aware/optimization.py +930 -0
- dao_ai/genie/cache/context_aware/persistent.py +802 -0
- dao_ai/genie/cache/context_aware/postgres.py +1343 -0
- dao_ai/genie/cache/lru.py +248 -70
- dao_ai/genie/core.py +235 -11
- dao_ai/middleware/__init__.py +8 -1
- dao_ai/middleware/tool_call_observability.py +227 -0
- dao_ai/nodes.py +4 -4
- dao_ai/tools/__init__.py +2 -2
- dao_ai/tools/genie.py +10 -10
- dao_ai/utils.py +7 -3
- {dao_ai-0.1.19.dist-info → dao_ai-0.1.21.dist-info}/METADATA +1 -1
- {dao_ai-0.1.19.dist-info → dao_ai-0.1.21.dist-info}/RECORD +24 -19
- dao_ai/genie/cache/semantic.py +0 -1004
- {dao_ai-0.1.19.dist-info → dao_ai-0.1.21.dist-info}/WHEEL +0 -0
- {dao_ai-0.1.19.dist-info → dao_ai-0.1.21.dist-info}/entry_points.txt +0 -0
- {dao_ai-0.1.19.dist-info → dao_ai-0.1.21.dist-info}/licenses/LICENSE +0 -0
dao_ai/genie/cache/lru.py
CHANGED
@@ -6,15 +6,16 @@ by Genie. On cache hit, the cached SQL is re-executed against the warehouse
 to return fresh data while avoiding the Genie NL-to-SQL translation cost.
 """
 
+from __future__ import annotations
+
 from collections import OrderedDict
 from datetime import datetime, timedelta
 from threading import Lock
-from typing import Any
 
 import mlflow
 import pandas as pd
 from databricks.sdk import WorkspaceClient
-from databricks.sdk.service.
+from databricks.sdk.service.dashboards import GenieFeedbackRating
 from databricks_ai_bridge.genie import GenieResponse
 from loguru import logger
 
@@ -24,6 +25,7 @@ from dao_ai.genie.cache.base import (
     GenieServiceBase,
     SQLCacheEntry,
 )
+from dao_ai.genie.cache.core import execute_sql_via_warehouse
 
 
 class LRUCacheService(GenieServiceBase):
@@ -141,8 +143,20 @@ class LRUCacheService(GenieServiceBase):
         self._cache.move_to_end(key)
         return entry
 
-    def _put(
+    def _put(
+        self, key: str, response: GenieResponse, message_id: str | None = None
+    ) -> None:
         """Store SQL query in cache, evicting if at capacity."""
+        # Skip caching if query is empty or whitespace
+        if not response.query or not response.query.strip():
+            logger.warning(
+                "Not caching: response has no SQL query",
+                layer=self.name,
+                key=key[:50],
+                description=response.description[:80] if response.description else None,
+            )
+            return
+
         if key in self._cache:
             del self._cache[key]
 
@@ -154,6 +168,9 @@ class LRUCacheService(GenieServiceBase):
             description=response.description,
             conversation_id=response.conversation_id,
             created_at=datetime.now(),
+            message_id=message_id,
+            # LRU cache is in-memory only, no database row ID
+            cache_entry_id=None,
         )
         logger.debug(
             "Stored cache entry",
@@ -162,6 +179,7 @@ class LRUCacheService(GenieServiceBase):
             sql=response.query[:50] if response.query else None,
             cache_size=len(self._cache),
             capacity=self.capacity,
+            message_id=message_id,
         )
 
     @mlflow.trace(name="execute_cached_sql")
@@ -175,50 +193,22 @@ class LRUCacheService(GenieServiceBase):
         Returns:
             DataFrame with results, or error message string
         """
-
-
-
-        logger.trace("Executing cached SQL", layer=self.name, sql=sql[:100])
-
-        statement_response: StatementResponse = w.statement_execution.execute_statement(
-            statement=sql,
-            warehouse_id=warehouse_id,
-            wait_timeout="30s",
-        )
-
-        # Poll for completion if still running
-        while statement_response.status.state in [
-            StatementState.PENDING,
-            StatementState.RUNNING,
-        ]:
-            statement_response = w.statement_execution.get_statement(
-                statement_response.statement_id
-            )
-
-        if statement_response.status.state != StatementState.SUCCEEDED:
-            error_msg: str = f"SQL execution failed: {statement_response.status}"
+        # Validate SQL is not empty
+        if not sql or not sql.strip():
+            error_msg: str = "Cannot execute empty SQL query"
             logger.error(
-                "SQL execution failed",
+                "SQL execution failed: empty query",
                 layer=self.name,
-
+                sql=repr(sql),
             )
             return error_msg
 
-        #
-
-
-
-
-
-        ]
-
-        data: list[list[Any]] = statement_response.result.data_array
-        if columns:
-            return pd.DataFrame(data, columns=columns)
-        else:
-            return pd.DataFrame(data)
-
-        return pd.DataFrame()
+        # Use shared utility function for SQL execution
+        return execute_sql_via_warehouse(
+            warehouse=self.warehouse,
+            sql=sql,
+            layer_name=self.name,
+        )
 
     def ask_question(
         self, question: str, conversation_id: str | None = None
@@ -256,33 +246,131 @@ class LRUCacheService(GenieServiceBase):
         cached: SQLCacheEntry | None = self._get(key)
 
         if cached is not None:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            # Defensive check: if cached query is empty, treat as cache miss
+            if not cached.query or not cached.query.strip():
+                logger.warning(
+                    "Cache HIT but query is empty, treating as MISS",
+                    layer=self.name,
+                    question=question[:80],
+                    conversation_id=conversation_id,
+                    key=key[:50],
+                )
+                # Invalidate this bad cache entry
+                with self._lock:
+                    if key in self._cache:
+                        del self._cache[key]
+                # Fall through to cache miss logic below
+            else:
+                cache_age_seconds = (datetime.now() - cached.created_at).total_seconds()
+                logger.info(
+                    "Cache HIT",
+                    layer=self.name,
+                    question=question[:80],
+                    conversation_id=conversation_id,
+                    cached_sql=cached.query[:80] if cached.query else None,
+                    cache_age_seconds=round(cache_age_seconds, 1),
+                    cache_size=self.size,
+                    capacity=self.capacity,
+                    ttl_seconds=self.parameters.time_to_live_seconds,
+                )
+
+                # Re-execute the cached SQL to get fresh data
+                result: pd.DataFrame | str = self._execute_sql(cached.query)
+
+                # Check if SQL execution failed (returns error string instead of DataFrame)
+                if isinstance(result, str):
+                    logger.warning(
+                        "Cached SQL execution failed, falling back to Genie",
+                        layer=self.name,
+                        question=question[:80],
+                        conversation_id=conversation_id,
+                        cached_sql=cached.query[:80],
+                        error=result[:200],
+                        cache_key=key[:50],
+                    )
+
+                    # Invalidate the bad cache entry
+                    with self._lock:
+                        if key in self._cache:
+                            del self._cache[key]
+                            logger.info(
+                                "Invalidated stale cache entry",
+                                layer=self.name,
+                                cache_key=key[:50],
+                                cache_size=len(self._cache),
+                                capacity=self.capacity,
+                            )
+
+                    # Fall back to Genie to get fresh SQL
+                    logger.info(
+                        "Delegating to Genie for fresh SQL",
+                        layer=self.name,
+                        question=question[:80],
+                        delegating_to=type(self.impl).__name__,
+                    )
+                    fallback_result: CacheResult = self.impl.ask_question(
+                        question, conversation_id
+                    )
+
+                    # Store the fresh SQL in cache (including message_id for feedback)
+                    if fallback_result.response.query:
+                        with self._lock:
+                            self._put(
+                                key,
+                                fallback_result.response,
+                                message_id=fallback_result.message_id,
+                            )
+                        logger.info(
+                            "Stored fresh SQL from fallback",
+                            layer=self.name,
+                            fresh_sql=fallback_result.response.query[:80],
+                            cache_size=len(self._cache),
+                            capacity=self.capacity,
+                            message_id=fallback_result.message_id,
+                        )
+                    else:
+                        logger.warning(
+                            "Fallback response has no SQL query to cache",
+                            layer=self.name,
+                            question=question[:80],
+                        )
+
+                    logger.info(
+                        "Fallback completed successfully",
+                        layer=self.name,
+                        question=question[:80],
+                        fallback_from="stale_cache",
+                        has_result=fallback_result.response.result is not None,
+                    )
+
+                    # Return as cache miss (fallback scenario)
+                    # Propagate message_id from fallback result
+                    return CacheResult(
+                        response=fallback_result.response,
+                        cache_hit=False,
+                        served_by=None,
+                        message_id=fallback_result.message_id,
+                    )
+
+                # Use current conversation_id, not the cached one
+                response: GenieResponse = GenieResponse(
+                    result=result,
+                    query=cached.query,
+                    description=cached.description,
+                    conversation_id=conversation_id
+                    if conversation_id
+                    else cached.conversation_id,
+                )
+
+                # Cache hit - include message_id from original response for feedback support
+                return CacheResult(
+                    response=response,
+                    cache_hit=True,
+                    served_by=self.name,
+                    message_id=cached.message_id,
+                    # LRU cache is in-memory only, no cache_entry_id for traceability
+                    cache_entry_id=None,
+                )
 
         # Cache miss - delegate to wrapped service
         logger.info(
@@ -298,7 +386,7 @@ class LRUCacheService(GenieServiceBase):
 
         result: CacheResult = self.impl.ask_question(question, conversation_id)
         with self._lock:
-            self._put(key, result.response)
+            self._put(key, result.response, message_id=result.message_id)
         # Propagate the inner cache's result - if it was a hit there, preserve that info
         return result
 
@@ -306,6 +394,11 @@ class LRUCacheService(GenieServiceBase):
     def space_id(self) -> str:
         return self.impl.space_id
 
+    @property
+    def workspace_client(self) -> WorkspaceClient | None:
+        """Get workspace client by delegating to impl."""
+        return self.impl.workspace_client
+
     def invalidate(self, question: str, conversation_id: str | None = None) -> bool:
         """
         Remove a specific entry from the cache.
@@ -349,3 +442,88 @@ class LRUCacheService(GenieServiceBase):
             "expired_entries": expired,
             "valid_entries": len(self._cache) - expired,
         }
+
+    @mlflow.trace(name="genie_lru_cache_send_feedback")
+    def send_feedback(
+        self,
+        conversation_id: str,
+        rating: GenieFeedbackRating,
+        message_id: str | None = None,
+        was_cache_hit: bool = False,
+    ) -> None:
+        """
+        Send feedback for a Genie message with cache invalidation.
+
+        For LRU cache, this method:
+        1. If was_cache_hit is False: forwards feedback to the underlying service
+        2. If rating is NEGATIVE: invalidates any matching cache entries
+
+        Args:
+            conversation_id: The conversation containing the message
+            rating: The feedback rating (POSITIVE, NEGATIVE, or NONE)
+            message_id: Optional message ID. If None, looks up the most recent message.
+            was_cache_hit: Whether the response being rated was served from cache.
+
+        Note:
+            For cached responses (was_cache_hit=True), only cache invalidation is
+            performed. No feedback is sent to the Genie API because cached responses
+            don't have a corresponding Genie message.
+
+            Future Enhancement: To enable full Genie feedback for cached responses,
+            the cache would need to store the original message_id. See GenieServiceBase
+            docstring for details on required changes.
+        """
+        # Handle cache invalidation on negative feedback
+        invalidated = False
+        if rating == GenieFeedbackRating.NEGATIVE:
+            # For LRU cache, we invalidate by conversation_id since that's part of the key
+            # Iterate through cache and remove entries matching the conversation_id
+            with self._lock:
+                keys_to_remove: list[str] = []
+                for key, entry in self._cache.items():
+                    if entry.conversation_id == conversation_id:
+                        keys_to_remove.append(key)
+
+                for key in keys_to_remove:
+                    del self._cache[key]
+                    invalidated = True
+                    logger.info(
+                        "Invalidated cache entry due to negative feedback",
+                        layer=self.name,
+                        cache_key=key[:50],
+                        conversation_id=conversation_id,
+                    )
+
+            if not keys_to_remove:
+                logger.debug(
+                    "No cache entries found to invalidate for negative feedback",
+                    layer=self.name,
+                    conversation_id=conversation_id,
+                )
+
+        # Forward feedback to underlying service if not a cache hit
+        # For cache hits, there's no Genie message to provide feedback on
+        if was_cache_hit:
+            logger.info(
+                "Skipping Genie API feedback - response was served from cache",
+                layer=self.name,
+                conversation_id=conversation_id,
+                rating=rating.value if rating else None,
+                cache_invalidated=invalidated,
+            )
+            return
+
+        # Forward to underlying service
+        logger.debug(
+            "Forwarding feedback to underlying service",
+            layer=self.name,
+            conversation_id=conversation_id,
+            rating=rating.value if rating else None,
+            delegating_to=type(self.impl).__name__,
+        )
+        self.impl.send_feedback(
+            conversation_id=conversation_id,
+            rating=rating,
+            message_id=message_id,
+            was_cache_hit=False,  # Already handled, so pass False
+        )
dao_ai/genie/core.py
CHANGED
@@ -1,35 +1,259 @@
 """
 Core Genie service implementation.
 
-This module provides
-
+This module provides:
+- Extended Genie and GenieResponse classes that capture message_id
+- GenieService: Concrete implementation of GenieServiceBase
+
+The extended classes wrap the databricks_ai_bridge versions to add message_id
+support, which is needed for sending feedback to the Genie API.
 """
 
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Union
+
 import mlflow
-
+import pandas as pd
+from databricks.sdk import WorkspaceClient
+from databricks.sdk.service.dashboards import GenieFeedbackRating
+from databricks_ai_bridge.genie import Genie as DatabricksGenie
+from databricks_ai_bridge.genie import GenieResponse as DatabricksGenieResponse
+from loguru import logger
 
 from dao_ai.genie.cache import CacheResult, GenieServiceBase
+from dao_ai.genie.cache.base import get_latest_message_id
+
+if TYPE_CHECKING:
+    from typing import Optional
+
+
+# =============================================================================
+# Extended Genie Classes with message_id Support
+# =============================================================================
+
+
+@dataclass
+class GenieResponse(DatabricksGenieResponse):
+    """
+    Extended GenieResponse that includes message_id.
+
+    This extends the databricks_ai_bridge GenieResponse to capture the message_id
+    from API responses, which is required for sending feedback to the Genie API.
+
+    Attributes:
+        result: The query result as string or DataFrame
+        query: The generated SQL query
+        description: Description of the query
+        conversation_id: The conversation ID
+        message_id: The message ID (NEW - enables feedback without extra API call)
+    """
+
+    result: Union[str, pd.DataFrame] = ""
+    query: Optional[str] = ""
+    description: Optional[str] = ""
+    conversation_id: Optional[str] = None
+    message_id: Optional[str] = None
+
+
+class Genie(DatabricksGenie):
+    """
+    Extended Genie that captures message_id in responses.
+
+    This extends the databricks_ai_bridge Genie to return GenieResponse objects
+    that include the message_id from the API response. This enables sending
+    feedback without requiring an additional API call to look up the message ID.
+
+    Usage:
+        genie = Genie(space_id="my-space")
+        response = genie.ask_question("What are total sales?")
+        print(response.message_id)  # Now available!
+
+    The original databricks_ai_bridge classes are available as:
+    - DatabricksGenie
+    - DatabricksGenieResponse
+    """
+
+    def ask_question(
+        self, question: str, conversation_id: str | None = None
+    ) -> GenieResponse:
+        """
+        Ask a question and return response with message_id.
+
+        This overrides the parent method to capture the message_id from the
+        API response and include it in the returned GenieResponse.
+
+        Args:
+            question: The question to ask
+            conversation_id: Optional conversation ID for follow-up questions
+
+        Returns:
+            GenieResponse with message_id populated
+        """
+        with mlflow.start_span(name="ask_question"):
+            # Start or continue conversation
+            if not conversation_id:
+                resp = self.start_conversation(question)
+            else:
+                resp = self.create_message(conversation_id, question)
+
+            # Capture message_id from the API response
+            message_id = resp.get("message_id")
+
+            # Poll for the result using parent's method
+            genie_response = self.poll_for_result(resp["conversation_id"], message_id)
+
+            # Ensure conversation_id is set
+            if not genie_response.conversation_id:
+                genie_response.conversation_id = resp["conversation_id"]
+
+            # Return our extended response with message_id
+            return GenieResponse(
+                result=genie_response.result,
+                query=genie_response.query,
+                description=genie_response.description,
+                conversation_id=genie_response.conversation_id,
+                message_id=message_id,
+            )
+
+
+# =============================================================================
+# GenieService Implementation
+# =============================================================================
 
 
 class GenieService(GenieServiceBase):
-    """
+    """
+    Concrete implementation of GenieServiceBase using the extended Genie.
+
+    This service wraps the extended Genie class and provides the GenieServiceBase
+    interface for use with cache layers.
+    """
 
     genie: Genie
+    _workspace_client: WorkspaceClient | None
+
+    def __init__(
+        self,
+        genie: Genie | DatabricksGenie,
+        workspace_client: WorkspaceClient | None = None,
+    ) -> None:
+        """
+        Initialize the GenieService.
+
+        Args:
+            genie: The Genie instance for asking questions. Can be either our
+                extended Genie or the original DatabricksGenie.
+            workspace_client: Optional WorkspaceClient for feedback API.
+                If not provided, one will be created lazily when needed.
+        """
+        self.genie = genie  # type: ignore[assignment]
+        self._workspace_client = workspace_client
+
+    @property
+    def workspace_client(self) -> WorkspaceClient:
+        """
+        Get or create a WorkspaceClient for API calls.
 
-
-
+        Lazily creates a WorkspaceClient using default credentials if not provided.
+        """
+        if self._workspace_client is None:
+            self._workspace_client = WorkspaceClient()
+        return self._workspace_client
 
     @mlflow.trace(name="genie_ask_question")
     def ask_question(
         self, question: str, conversation_id: str | None = None
     ) -> CacheResult:
-        """
-
-
-
+        """
+        Ask question to Genie and return CacheResult.
+
+        No caching at this level - returns cache miss with fresh response.
+        If using our extended Genie, the message_id will be captured in the response.
+        """
+        response = self.genie.ask_question(question, conversation_id=conversation_id)
+
+        # Extract message_id if available (from our extended GenieResponse)
+        message_id = getattr(response, "message_id", None)
+
         # No caching at this level - return cache miss
-        return CacheResult(
+        return CacheResult(
+            response=response,
+            cache_hit=False,
+            served_by=None,
+            message_id=message_id,
+        )
 
     @property
     def space_id(self) -> str:
         return self.genie.space_id
+
+    @mlflow.trace(name="genie_send_feedback")
+    def send_feedback(
+        self,
+        conversation_id: str,
+        rating: GenieFeedbackRating,
+        message_id: str | None = None,
+        was_cache_hit: bool = False,
+    ) -> None:
+        """
+        Send feedback for a Genie message.
+
+        For the core GenieService, this always sends feedback to the Genie API
+        (the was_cache_hit parameter is ignored here - it's used by cache wrappers).
+
+        Args:
+            conversation_id: The conversation containing the message
+            rating: The feedback rating (POSITIVE, NEGATIVE, or NONE)
+            message_id: Optional message ID. If None, looks up the most recent message.
+            was_cache_hit: Ignored by GenieService. Cache wrappers use this to decide
+                whether to forward feedback to the underlying service.
+        """
+        # Look up message_id if not provided
+        if message_id is None:
+            message_id = get_latest_message_id(
+                workspace_client=self.workspace_client,
+                space_id=self.space_id,
+                conversation_id=conversation_id,
+            )
+            if message_id is None:
+                logger.warning(
+                    "Could not find message_id for feedback, skipping",
+                    space_id=self.space_id,
+                    conversation_id=conversation_id,
+                    rating=rating.value if rating else None,
+                )
+                return
+
+        logger.info(
+            "Sending feedback to Genie",
+            space_id=self.space_id,
+            conversation_id=conversation_id,
+            message_id=message_id,
+            rating=rating.value if rating else None,
+        )
+
+        try:
+            self.workspace_client.genie.send_message_feedback(
+                space_id=self.space_id,
+                conversation_id=conversation_id,
+                message_id=message_id,
+                rating=rating,
+            )
+            logger.debug(
+                "Feedback sent successfully",
+                space_id=self.space_id,
+                conversation_id=conversation_id,
+                message_id=message_id,
+            )
+        except Exception as e:
+            logger.error(
+                "Failed to send feedback to Genie",
+                space_id=self.space_id,
+                conversation_id=conversation_id,
+                message_id=message_id,
+                rating=rating.value if rating else None,
+                error=str(e),
+                exc_info=True,
+            )
dao_ai/middleware/__init__.py
CHANGED
@@ -1,5 +1,5 @@
 # DAO AI Middleware Module
-#
+# Middleware implementations compatible with LangChain v1's create_agent
 
 # Re-export LangChain built-in middleware
 from langchain.agents.middleware import (
@@ -82,6 +82,10 @@ from dao_ai.middleware.summarization import (
     create_summarization_middleware,
 )
 from dao_ai.middleware.tool_call_limit import create_tool_call_limit_middleware
+from dao_ai.middleware.tool_call_observability import (
+    ToolCallObservabilityMiddleware,
+    create_tool_call_observability_middleware,
+)
 from dao_ai.middleware.tool_retry import create_tool_retry_middleware
 from dao_ai.middleware.tool_selector import create_llm_tool_selector_middleware
 
@@ -160,4 +164,7 @@ __all__ = [
     "create_clear_tool_uses_edit",
     # PII middleware factory functions
     "create_pii_middleware",
+    # Tool call observability middleware
+    "ToolCallObservabilityMiddleware",
+    "create_tool_call_observability_middleware",
 ]