dao-ai 0.1.2__py3-none-any.whl → 0.1.20__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. dao_ai/apps/__init__.py +24 -0
  2. dao_ai/apps/handlers.py +105 -0
  3. dao_ai/apps/model_serving.py +29 -0
  4. dao_ai/apps/resources.py +1122 -0
  5. dao_ai/apps/server.py +39 -0
  6. dao_ai/cli.py +546 -37
  7. dao_ai/config.py +1179 -139
  8. dao_ai/evaluation.py +543 -0
  9. dao_ai/genie/__init__.py +55 -7
  10. dao_ai/genie/cache/__init__.py +34 -7
  11. dao_ai/genie/cache/base.py +143 -2
  12. dao_ai/genie/cache/context_aware/__init__.py +31 -0
  13. dao_ai/genie/cache/context_aware/base.py +1151 -0
  14. dao_ai/genie/cache/context_aware/in_memory.py +609 -0
  15. dao_ai/genie/cache/context_aware/persistent.py +802 -0
  16. dao_ai/genie/cache/context_aware/postgres.py +1166 -0
  17. dao_ai/genie/cache/core.py +1 -1
  18. dao_ai/genie/cache/lru.py +257 -75
  19. dao_ai/genie/cache/optimization.py +890 -0
  20. dao_ai/genie/core.py +235 -11
  21. dao_ai/memory/postgres.py +175 -39
  22. dao_ai/middleware/__init__.py +38 -0
  23. dao_ai/middleware/assertions.py +3 -3
  24. dao_ai/middleware/context_editing.py +230 -0
  25. dao_ai/middleware/core.py +4 -4
  26. dao_ai/middleware/guardrails.py +3 -3
  27. dao_ai/middleware/human_in_the_loop.py +3 -2
  28. dao_ai/middleware/message_validation.py +4 -4
  29. dao_ai/middleware/model_call_limit.py +77 -0
  30. dao_ai/middleware/model_retry.py +121 -0
  31. dao_ai/middleware/pii.py +157 -0
  32. dao_ai/middleware/summarization.py +1 -1
  33. dao_ai/middleware/tool_call_limit.py +210 -0
  34. dao_ai/middleware/tool_retry.py +174 -0
  35. dao_ai/middleware/tool_selector.py +129 -0
  36. dao_ai/models.py +327 -370
  37. dao_ai/nodes.py +9 -16
  38. dao_ai/orchestration/core.py +33 -9
  39. dao_ai/orchestration/supervisor.py +29 -13
  40. dao_ai/orchestration/swarm.py +6 -1
  41. dao_ai/{prompts.py → prompts/__init__.py} +12 -61
  42. dao_ai/prompts/instructed_retriever_decomposition.yaml +58 -0
  43. dao_ai/prompts/instruction_reranker.yaml +14 -0
  44. dao_ai/prompts/router.yaml +37 -0
  45. dao_ai/prompts/verifier.yaml +46 -0
  46. dao_ai/providers/base.py +28 -2
  47. dao_ai/providers/databricks.py +363 -33
  48. dao_ai/state.py +1 -0
  49. dao_ai/tools/__init__.py +5 -3
  50. dao_ai/tools/genie.py +103 -26
  51. dao_ai/tools/instructed_retriever.py +366 -0
  52. dao_ai/tools/instruction_reranker.py +202 -0
  53. dao_ai/tools/mcp.py +539 -97
  54. dao_ai/tools/router.py +89 -0
  55. dao_ai/tools/slack.py +13 -2
  56. dao_ai/tools/sql.py +7 -3
  57. dao_ai/tools/unity_catalog.py +32 -10
  58. dao_ai/tools/vector_search.py +493 -160
  59. dao_ai/tools/verifier.py +159 -0
  60. dao_ai/utils.py +182 -2
  61. dao_ai/vector_search.py +46 -1
  62. {dao_ai-0.1.2.dist-info → dao_ai-0.1.20.dist-info}/METADATA +45 -9
  63. dao_ai-0.1.20.dist-info/RECORD +89 -0
  64. dao_ai/agent_as_code.py +0 -22
  65. dao_ai/genie/cache/semantic.py +0 -970
  66. dao_ai-0.1.2.dist-info/RECORD +0 -64
  67. {dao_ai-0.1.2.dist-info → dao_ai-0.1.20.dist-info}/WHEEL +0 -0
  68. {dao_ai-0.1.2.dist-info → dao_ai-0.1.20.dist-info}/entry_points.txt +0 -0
  69. {dao_ai-0.1.2.dist-info → dao_ai-0.1.20.dist-info}/licenses/LICENSE +0 -0
@@ -38,7 +38,7 @@ def execute_sql_via_warehouse(
38
38
  w: WorkspaceClient = warehouse.workspace_client
39
39
  warehouse_id: str = str(warehouse.warehouse_id)
40
40
 
41
- logger.trace("Executing cached SQL", layer=layer_name, sql_prefix=sql[:100])
41
+ logger.trace("Executing cached SQL", layer=layer_name, sql=sql[:100])
42
42
 
43
43
  statement_response: StatementResponse = w.statement_execution.execute_statement(
44
44
  statement=sql,
dao_ai/genie/cache/lru.py CHANGED
@@ -6,15 +6,16 @@ by Genie. On cache hit, the cached SQL is re-executed against the warehouse
6
6
  to return fresh data while avoiding the Genie NL-to-SQL translation cost.
7
7
  """
8
8
 
9
+ from __future__ import annotations
10
+
9
11
  from collections import OrderedDict
10
12
  from datetime import datetime, timedelta
11
13
  from threading import Lock
12
- from typing import Any
13
14
 
14
15
  import mlflow
15
16
  import pandas as pd
16
17
  from databricks.sdk import WorkspaceClient
17
- from databricks.sdk.service.sql import StatementResponse, StatementState
18
+ from databricks.sdk.service.dashboards import GenieFeedbackRating
18
19
  from databricks_ai_bridge.genie import GenieResponse
19
20
  from loguru import logger
20
21
 
@@ -24,6 +25,7 @@ from dao_ai.genie.cache.base import (
24
25
  GenieServiceBase,
25
26
  SQLCacheEntry,
26
27
  )
28
+ from dao_ai.genie.cache.core import execute_sql_via_warehouse
27
29
 
28
30
 
29
31
  class LRUCacheService(GenieServiceBase):
@@ -124,9 +126,7 @@ class LRUCacheService(GenieServiceBase):
124
126
  if self._cache:
125
127
  oldest_key: str = next(iter(self._cache))
126
128
  del self._cache[oldest_key]
127
- logger.trace(
128
- "Evicted cache entry", layer=self.name, key_prefix=oldest_key[:50]
129
- )
129
+ logger.trace("Evicted cache entry", layer=self.name, key=oldest_key[:50])
130
130
 
131
131
  def _get(self, key: str) -> SQLCacheEntry | None:
132
132
  """Get from cache, returning None if not found or expired."""
@@ -137,14 +137,26 @@ class LRUCacheService(GenieServiceBase):
137
137
 
138
138
  if self._is_expired(entry):
139
139
  del self._cache[key]
140
- logger.trace("Expired cache entry", layer=self.name, key_prefix=key[:50])
140
+ logger.trace("Expired cache entry", layer=self.name, key=key[:50])
141
141
  return None
142
142
 
143
143
  self._cache.move_to_end(key)
144
144
  return entry
145
145
 
146
- def _put(self, key: str, response: GenieResponse) -> None:
146
+ def _put(
147
+ self, key: str, response: GenieResponse, message_id: str | None = None
148
+ ) -> None:
147
149
  """Store SQL query in cache, evicting if at capacity."""
150
+ # Skip caching if query is empty or whitespace
151
+ if not response.query or not response.query.strip():
152
+ logger.warning(
153
+ "Not caching: response has no SQL query",
154
+ layer=self.name,
155
+ key=key[:50],
156
+ description=response.description[:80] if response.description else None,
157
+ )
158
+ return
159
+
148
160
  if key in self._cache:
149
161
  del self._cache[key]
150
162
 
@@ -156,14 +168,18 @@ class LRUCacheService(GenieServiceBase):
156
168
  description=response.description,
157
169
  conversation_id=response.conversation_id,
158
170
  created_at=datetime.now(),
171
+ message_id=message_id,
172
+ # LRU cache is in-memory only, no database row ID
173
+ cache_entry_id=None,
159
174
  )
160
- logger.info(
175
+ logger.debug(
161
176
  "Stored cache entry",
162
177
  layer=self.name,
163
- key_prefix=key[:50],
164
- sql_prefix=response.query[:50] if response.query else None,
178
+ key=key[:50],
179
+ sql=response.query[:50] if response.query else None,
165
180
  cache_size=len(self._cache),
166
181
  capacity=self.capacity,
182
+ message_id=message_id,
167
183
  )
168
184
 
169
185
  @mlflow.trace(name="execute_cached_sql")
@@ -177,50 +193,22 @@ class LRUCacheService(GenieServiceBase):
177
193
  Returns:
178
194
  DataFrame with results, or error message string
179
195
  """
180
- w: WorkspaceClient = self.warehouse.workspace_client
181
- warehouse_id: str = str(self.warehouse.warehouse_id)
182
-
183
- logger.trace("Executing cached SQL", layer=self.name, sql_prefix=sql[:100])
184
-
185
- statement_response: StatementResponse = w.statement_execution.execute_statement(
186
- statement=sql,
187
- warehouse_id=warehouse_id,
188
- wait_timeout="30s",
189
- )
190
-
191
- # Poll for completion if still running
192
- while statement_response.status.state in [
193
- StatementState.PENDING,
194
- StatementState.RUNNING,
195
- ]:
196
- statement_response = w.statement_execution.get_statement(
197
- statement_response.statement_id
198
- )
199
-
200
- if statement_response.status.state != StatementState.SUCCEEDED:
201
- error_msg: str = f"SQL execution failed: {statement_response.status}"
196
+ # Validate SQL is not empty
197
+ if not sql or not sql.strip():
198
+ error_msg: str = "Cannot execute empty SQL query"
202
199
  logger.error(
203
- "SQL execution failed",
200
+ "SQL execution failed: empty query",
204
201
  layer=self.name,
205
- status=str(statement_response.status),
202
+ sql=repr(sql),
206
203
  )
207
204
  return error_msg
208
205
 
209
- # Convert to DataFrame
210
- if statement_response.result and statement_response.result.data_array:
211
- columns: list[str] = []
212
- if statement_response.manifest and statement_response.manifest.schema:
213
- columns = [
214
- col.name for col in statement_response.manifest.schema.columns
215
- ]
216
-
217
- data: list[list[Any]] = statement_response.result.data_array
218
- if columns:
219
- return pd.DataFrame(data, columns=columns)
220
- else:
221
- return pd.DataFrame(data)
222
-
223
- return pd.DataFrame()
206
+ # Use shared utility function for SQL execution
207
+ return execute_sql_via_warehouse(
208
+ warehouse=self.warehouse,
209
+ sql=sql,
210
+ layer_name=self.name,
211
+ )
224
212
 
225
213
  def ask_question(
226
214
  self, question: str, conversation_id: str | None = None
@@ -258,50 +246,159 @@ class LRUCacheService(GenieServiceBase):
258
246
  cached: SQLCacheEntry | None = self._get(key)
259
247
 
260
248
  if cached is not None:
261
- logger.info(
262
- "Cache HIT",
263
- layer=self.name,
264
- question_prefix=question[:50],
265
- conversation_id=conversation_id,
266
- cache_size=self.size,
267
- capacity=self.capacity,
268
- )
269
-
270
- # Re-execute the cached SQL to get fresh data
271
- result: pd.DataFrame | str = self._execute_sql(cached.query)
272
-
273
- # Use current conversation_id, not the cached one
274
- response: GenieResponse = GenieResponse(
275
- result=result,
276
- query=cached.query,
277
- description=cached.description,
278
- conversation_id=conversation_id
279
- if conversation_id
280
- else cached.conversation_id,
281
- )
282
-
283
- return CacheResult(response=response, cache_hit=True, served_by=self.name)
249
+ # Defensive check: if cached query is empty, treat as cache miss
250
+ if not cached.query or not cached.query.strip():
251
+ logger.warning(
252
+ "Cache HIT but query is empty, treating as MISS",
253
+ layer=self.name,
254
+ question=question[:80],
255
+ conversation_id=conversation_id,
256
+ key=key[:50],
257
+ )
258
+ # Invalidate this bad cache entry
259
+ with self._lock:
260
+ if key in self._cache:
261
+ del self._cache[key]
262
+ # Fall through to cache miss logic below
263
+ else:
264
+ cache_age_seconds = (datetime.now() - cached.created_at).total_seconds()
265
+ logger.info(
266
+ "Cache HIT",
267
+ layer=self.name,
268
+ question=question[:80],
269
+ conversation_id=conversation_id,
270
+ cached_sql=cached.query[:80] if cached.query else None,
271
+ cache_age_seconds=round(cache_age_seconds, 1),
272
+ cache_size=self.size,
273
+ capacity=self.capacity,
274
+ ttl_seconds=self.parameters.time_to_live_seconds,
275
+ )
276
+
277
+ # Re-execute the cached SQL to get fresh data
278
+ result: pd.DataFrame | str = self._execute_sql(cached.query)
279
+
280
+ # Check if SQL execution failed (returns error string instead of DataFrame)
281
+ if isinstance(result, str):
282
+ logger.warning(
283
+ "Cached SQL execution failed, falling back to Genie",
284
+ layer=self.name,
285
+ question=question[:80],
286
+ conversation_id=conversation_id,
287
+ cached_sql=cached.query[:80],
288
+ error=result[:200],
289
+ cache_key=key[:50],
290
+ )
291
+
292
+ # Invalidate the bad cache entry
293
+ with self._lock:
294
+ if key in self._cache:
295
+ del self._cache[key]
296
+ logger.info(
297
+ "Invalidated stale cache entry",
298
+ layer=self.name,
299
+ cache_key=key[:50],
300
+ cache_size=len(self._cache),
301
+ capacity=self.capacity,
302
+ )
303
+
304
+ # Fall back to Genie to get fresh SQL
305
+ logger.info(
306
+ "Delegating to Genie for fresh SQL",
307
+ layer=self.name,
308
+ question=question[:80],
309
+ delegating_to=type(self.impl).__name__,
310
+ )
311
+ fallback_result: CacheResult = self.impl.ask_question(
312
+ question, conversation_id
313
+ )
314
+
315
+ # Store the fresh SQL in cache (including message_id for feedback)
316
+ if fallback_result.response.query:
317
+ with self._lock:
318
+ self._put(
319
+ key,
320
+ fallback_result.response,
321
+ message_id=fallback_result.message_id,
322
+ )
323
+ logger.info(
324
+ "Stored fresh SQL from fallback",
325
+ layer=self.name,
326
+ fresh_sql=fallback_result.response.query[:80],
327
+ cache_size=len(self._cache),
328
+ capacity=self.capacity,
329
+ message_id=fallback_result.message_id,
330
+ )
331
+ else:
332
+ logger.warning(
333
+ "Fallback response has no SQL query to cache",
334
+ layer=self.name,
335
+ question=question[:80],
336
+ )
337
+
338
+ logger.info(
339
+ "Fallback completed successfully",
340
+ layer=self.name,
341
+ question=question[:80],
342
+ fallback_from="stale_cache",
343
+ has_result=fallback_result.response.result is not None,
344
+ )
345
+
346
+ # Return as cache miss (fallback scenario)
347
+ # Propagate message_id from fallback result
348
+ return CacheResult(
349
+ response=fallback_result.response,
350
+ cache_hit=False,
351
+ served_by=None,
352
+ message_id=fallback_result.message_id,
353
+ )
354
+
355
+ # Use current conversation_id, not the cached one
356
+ response: GenieResponse = GenieResponse(
357
+ result=result,
358
+ query=cached.query,
359
+ description=cached.description,
360
+ conversation_id=conversation_id
361
+ if conversation_id
362
+ else cached.conversation_id,
363
+ )
364
+
365
+ # Cache hit - include message_id from original response for feedback support
366
+ return CacheResult(
367
+ response=response,
368
+ cache_hit=True,
369
+ served_by=self.name,
370
+ message_id=cached.message_id,
371
+ # LRU cache is in-memory only, no cache_entry_id for traceability
372
+ cache_entry_id=None,
373
+ )
284
374
 
285
375
  # Cache miss - delegate to wrapped service
286
376
  logger.info(
287
377
  "Cache MISS",
288
378
  layer=self.name,
289
- question_prefix=question[:50],
379
+ question=question[:80],
290
380
  conversation_id=conversation_id,
291
381
  cache_size=self.size,
292
382
  capacity=self.capacity,
383
+ ttl_seconds=self.parameters.time_to_live_seconds,
293
384
  delegating_to=type(self.impl).__name__,
294
385
  )
295
386
 
296
387
  result: CacheResult = self.impl.ask_question(question, conversation_id)
297
388
  with self._lock:
298
- self._put(key, result.response)
299
- return CacheResult(response=result.response, cache_hit=False, served_by=None)
389
+ self._put(key, result.response, message_id=result.message_id)
390
+ # Propagate the inner cache's result - if it was a hit there, preserve that info
391
+ return result
300
392
 
301
393
  @property
302
394
  def space_id(self) -> str:
303
395
  return self.impl.space_id
304
396
 
397
+ @property
398
+ def workspace_client(self) -> WorkspaceClient | None:
399
+ """Get workspace client by delegating to impl."""
400
+ return self.impl.workspace_client
401
+
305
402
  def invalidate(self, question: str, conversation_id: str | None = None) -> bool:
306
403
  """
307
404
  Remove a specific entry from the cache.
@@ -345,3 +442,88 @@ class LRUCacheService(GenieServiceBase):
345
442
  "expired_entries": expired,
346
443
  "valid_entries": len(self._cache) - expired,
347
444
  }
445
+
446
+ @mlflow.trace(name="genie_lru_cache_send_feedback")
447
+ def send_feedback(
448
+ self,
449
+ conversation_id: str,
450
+ rating: GenieFeedbackRating,
451
+ message_id: str | None = None,
452
+ was_cache_hit: bool = False,
453
+ ) -> None:
454
+ """
455
+ Send feedback for a Genie message with cache invalidation.
456
+
457
+ For LRU cache, this method:
458
+ 1. If was_cache_hit is False: forwards feedback to the underlying service
459
+ 2. If rating is NEGATIVE: invalidates any matching cache entries
460
+
461
+ Args:
462
+ conversation_id: The conversation containing the message
463
+ rating: The feedback rating (POSITIVE, NEGATIVE, or NONE)
464
+ message_id: Optional message ID. If None, looks up the most recent message.
465
+ was_cache_hit: Whether the response being rated was served from cache.
466
+
467
+ Note:
468
+ For cached responses (was_cache_hit=True), only cache invalidation is
469
+ performed. No feedback is sent to the Genie API because cached responses
470
+ don't have a corresponding Genie message.
471
+
472
+ Future Enhancement: To enable full Genie feedback for cached responses,
473
+ the cache would need to store the original message_id. See GenieServiceBase
474
+ docstring for details on required changes.
475
+ """
476
+ # Handle cache invalidation on negative feedback
477
+ invalidated = False
478
+ if rating == GenieFeedbackRating.NEGATIVE:
479
+ # For LRU cache, we invalidate by conversation_id since that's part of the key
480
+ # Iterate through cache and remove entries matching the conversation_id
481
+ with self._lock:
482
+ keys_to_remove: list[str] = []
483
+ for key, entry in self._cache.items():
484
+ if entry.conversation_id == conversation_id:
485
+ keys_to_remove.append(key)
486
+
487
+ for key in keys_to_remove:
488
+ del self._cache[key]
489
+ invalidated = True
490
+ logger.info(
491
+ "Invalidated cache entry due to negative feedback",
492
+ layer=self.name,
493
+ cache_key=key[:50],
494
+ conversation_id=conversation_id,
495
+ )
496
+
497
+ if not keys_to_remove:
498
+ logger.debug(
499
+ "No cache entries found to invalidate for negative feedback",
500
+ layer=self.name,
501
+ conversation_id=conversation_id,
502
+ )
503
+
504
+ # Forward feedback to underlying service if not a cache hit
505
+ # For cache hits, there's no Genie message to provide feedback on
506
+ if was_cache_hit:
507
+ logger.info(
508
+ "Skipping Genie API feedback - response was served from cache",
509
+ layer=self.name,
510
+ conversation_id=conversation_id,
511
+ rating=rating.value if rating else None,
512
+ cache_invalidated=invalidated,
513
+ )
514
+ return
515
+
516
+ # Forward to underlying service
517
+ logger.debug(
518
+ "Forwarding feedback to underlying service",
519
+ layer=self.name,
520
+ conversation_id=conversation_id,
521
+ rating=rating.value if rating else None,
522
+ delegating_to=type(self.impl).__name__,
523
+ )
524
+ self.impl.send_feedback(
525
+ conversation_id=conversation_id,
526
+ rating=rating,
527
+ message_id=message_id,
528
+ was_cache_hit=False, # Already handled, so pass False
529
+ )