dao-ai 0.0.36__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59)
  1. dao_ai/__init__.py +29 -0
  2. dao_ai/cli.py +195 -30
  3. dao_ai/config.py +770 -244
  4. dao_ai/genie/__init__.py +1 -22
  5. dao_ai/genie/cache/__init__.py +1 -2
  6. dao_ai/genie/cache/base.py +20 -70
  7. dao_ai/genie/cache/core.py +75 -0
  8. dao_ai/genie/cache/lru.py +44 -21
  9. dao_ai/genie/cache/semantic.py +390 -109
  10. dao_ai/genie/core.py +35 -0
  11. dao_ai/graph.py +27 -253
  12. dao_ai/hooks/__init__.py +9 -6
  13. dao_ai/hooks/core.py +22 -190
  14. dao_ai/memory/__init__.py +10 -0
  15. dao_ai/memory/core.py +23 -5
  16. dao_ai/memory/databricks.py +389 -0
  17. dao_ai/memory/postgres.py +2 -2
  18. dao_ai/messages.py +6 -4
  19. dao_ai/middleware/__init__.py +125 -0
  20. dao_ai/middleware/assertions.py +778 -0
  21. dao_ai/middleware/base.py +50 -0
  22. dao_ai/middleware/core.py +61 -0
  23. dao_ai/middleware/guardrails.py +415 -0
  24. dao_ai/middleware/human_in_the_loop.py +228 -0
  25. dao_ai/middleware/message_validation.py +554 -0
  26. dao_ai/middleware/summarization.py +192 -0
  27. dao_ai/models.py +1177 -108
  28. dao_ai/nodes.py +118 -161
  29. dao_ai/optimization.py +664 -0
  30. dao_ai/orchestration/__init__.py +52 -0
  31. dao_ai/orchestration/core.py +287 -0
  32. dao_ai/orchestration/supervisor.py +264 -0
  33. dao_ai/orchestration/swarm.py +226 -0
  34. dao_ai/prompts.py +126 -29
  35. dao_ai/providers/databricks.py +126 -381
  36. dao_ai/state.py +139 -21
  37. dao_ai/tools/__init__.py +8 -5
  38. dao_ai/tools/core.py +57 -4
  39. dao_ai/tools/email.py +280 -0
  40. dao_ai/tools/genie.py +47 -24
  41. dao_ai/tools/mcp.py +4 -3
  42. dao_ai/tools/memory.py +50 -0
  43. dao_ai/tools/python.py +4 -12
  44. dao_ai/tools/search.py +14 -0
  45. dao_ai/tools/slack.py +1 -1
  46. dao_ai/tools/unity_catalog.py +8 -6
  47. dao_ai/tools/vector_search.py +16 -9
  48. dao_ai/utils.py +72 -8
  49. dao_ai-0.1.0.dist-info/METADATA +1878 -0
  50. dao_ai-0.1.0.dist-info/RECORD +62 -0
  51. dao_ai/chat_models.py +0 -204
  52. dao_ai/guardrails.py +0 -112
  53. dao_ai/tools/genie/__init__.py +0 -236
  54. dao_ai/tools/human_in_the_loop.py +0 -100
  55. dao_ai-0.0.36.dist-info/METADATA +0 -951
  56. dao_ai-0.0.36.dist-info/RECORD +0 -47
  57. {dao_ai-0.0.36.dist-info → dao_ai-0.1.0.dist-info}/WHEEL +0 -0
  58. {dao_ai-0.0.36.dist-info → dao_ai-0.1.0.dist-info}/entry_points.txt +0 -0
  59. {dao_ai-0.0.36.dist-info → dao_ai-0.1.0.dist-info}/licenses/LICENSE +0 -0
dao_ai/genie/__init__.py CHANGED
@@ -17,9 +17,6 @@ Example usage:
17
17
  from dao_ai.genie.cache import LRUCacheService, SemanticCacheService
18
18
  """
19
19
 
20
- import mlflow
21
- from databricks_ai_bridge.genie import Genie, GenieResponse
22
-
23
20
  from dao_ai.genie.cache import (
24
21
  CacheResult,
25
22
  GenieServiceBase,
@@ -27,25 +24,7 @@ from dao_ai.genie.cache import (
27
24
  SemanticCacheService,
28
25
  SQLCacheEntry,
29
26
  )
30
-
31
-
32
- class GenieService(GenieServiceBase):
33
- """Concrete implementation of GenieServiceBase using the Genie SDK."""
34
-
35
- genie: Genie
36
-
37
- def __init__(self, genie: Genie) -> None:
38
- self.genie = genie
39
-
40
- @mlflow.trace(name="genie_ask_question")
41
- def ask_question(
42
- self, question: str, conversation_id: str | None = None
43
- ) -> GenieResponse:
44
- response: GenieResponse = self.genie.ask_question(
45
- question, conversation_id=conversation_id
46
- )
47
- return response
48
-
27
+ from dao_ai.genie.core import GenieService
49
28
 
50
29
  __all__ = [
51
30
  # Service classes
@@ -15,7 +15,6 @@ Example usage:
15
15
  genie_service = SemanticCacheService(
16
16
  impl=GenieService(genie),
17
17
  parameters=semantic_params,
18
- genie_space_id=space_id,
19
18
  )
20
19
  genie_service = LRUCacheService(
21
20
  impl=genie_service,
@@ -27,8 +26,8 @@ from dao_ai.genie.cache.base import (
27
26
  CacheResult,
28
27
  GenieServiceBase,
29
28
  SQLCacheEntry,
30
- execute_sql_via_warehouse,
31
29
  )
30
+ from dao_ai.genie.cache.core import execute_sql_via_warehouse
32
31
  from dao_ai.genie.cache.lru import LRUCacheService
33
32
  from dao_ai.genie.cache.semantic import SemanticCacheService
34
33
 
@@ -2,21 +2,21 @@
2
2
  Base classes and types for Genie cache implementations.
3
3
 
4
4
  This module provides the foundational types used across different cache
5
- implementations (LRU, Semantic, etc.).
5
+ implementations (LRU, Semantic, etc.). It contains only abstract base classes
6
+ and data structures - no concrete implementations.
6
7
  """
7
8
 
9
+ from __future__ import annotations
10
+
8
11
  from abc import ABC, abstractmethod
9
12
  from dataclasses import dataclass
10
13
  from datetime import datetime
11
- from typing import Any
14
+ from typing import TYPE_CHECKING
12
15
 
13
- import pandas as pd
14
- from databricks.sdk import WorkspaceClient
15
- from databricks.sdk.service.sql import StatementResponse, StatementState
16
16
  from databricks_ai_bridge.genie import GenieResponse
17
- from loguru import logger
18
17
 
19
- from dao_ai.config import WarehouseModel
18
+ if TYPE_CHECKING:
19
+ from dao_ai.genie.cache.base import CacheResult
20
20
 
21
21
 
22
22
  class GenieServiceBase(ABC):
@@ -25,8 +25,19 @@ class GenieServiceBase(ABC):
25
25
  @abstractmethod
26
26
  def ask_question(
27
27
  self, question: str, conversation_id: str | None = None
28
- ) -> GenieResponse:
29
- """Ask a question to Genie and return the response."""
28
+ ) -> "CacheResult":
29
+ """
30
+ Ask a question to Genie and return the response with cache metadata.
31
+
32
+ All implementations return CacheResult to provide consistent cache information,
33
+ even when caching is disabled (cache_hit=False, served_by=None).
34
+ """
35
+ pass
36
+
37
+ @property
38
+ @abstractmethod
39
+ def space_id(self) -> str:
40
+ """The space ID for the Genie service."""
30
41
  pass
31
42
 
32
43
 
@@ -59,64 +70,3 @@ class CacheResult:
59
70
  response: GenieResponse
60
71
  cache_hit: bool
61
72
  served_by: str | None = None
62
-
63
-
64
- def execute_sql_via_warehouse(
65
- warehouse: WarehouseModel,
66
- sql: str,
67
- layer_name: str = "cache",
68
- ) -> pd.DataFrame | str:
69
- """
70
- Execute SQL using a Databricks warehouse and return results as DataFrame.
71
-
72
- This is a shared utility for cache implementations that need to re-execute
73
- cached SQL queries.
74
-
75
- Args:
76
- warehouse: The warehouse configuration for SQL execution
77
- sql: The SQL query to execute
78
- layer_name: Name of the cache layer (for logging)
79
-
80
- Returns:
81
- DataFrame with results, or error message string
82
- """
83
- w: WorkspaceClient = warehouse.workspace_client
84
- warehouse_id: str = str(warehouse.warehouse_id)
85
-
86
- logger.debug(f"[{layer_name}] Executing cached SQL: {sql[:100]}...")
87
-
88
- statement_response: StatementResponse = w.statement_execution.execute_statement(
89
- statement=sql,
90
- warehouse_id=warehouse_id,
91
- wait_timeout="30s",
92
- )
93
-
94
- # Poll for completion if still running
95
- while statement_response.status.state in [
96
- StatementState.PENDING,
97
- StatementState.RUNNING,
98
- ]:
99
- statement_response = w.statement_execution.get_statement(
100
- statement_response.statement_id
101
- )
102
-
103
- if statement_response.status.state != StatementState.SUCCEEDED:
104
- error_msg: str = f"SQL execution failed: {statement_response.status}"
105
- logger.error(f"[{layer_name}] {error_msg}")
106
- return error_msg
107
-
108
- # Convert to DataFrame
109
- if statement_response.result and statement_response.result.data_array:
110
- columns: list[str] = []
111
- if statement_response.manifest and statement_response.manifest.schema:
112
- columns = [col.name for col in statement_response.manifest.schema.columns]
113
- elif hasattr(statement_response.result, "schema"):
114
- columns = [col.name for col in statement_response.result.schema.columns]
115
-
116
- data: list[list[Any]] = statement_response.result.data_array
117
- if columns:
118
- return pd.DataFrame(data, columns=columns)
119
- else:
120
- return pd.DataFrame(data)
121
-
122
- return pd.DataFrame()
@@ -0,0 +1,75 @@
1
+ """
2
+ Core utilities for Genie cache implementations.
3
+
4
+ This module provides shared utility functions used by different cache
5
+ implementations (LRU, Semantic, etc.). These are concrete implementations
6
+ of common operations needed across cache types.
7
+ """
8
+
9
+ from typing import Any
10
+
11
+ import pandas as pd
12
+ from databricks.sdk import WorkspaceClient
13
+ from databricks.sdk.service.sql import StatementResponse, StatementState
14
+ from loguru import logger
15
+
16
+ from dao_ai.config import WarehouseModel
17
+
18
+
19
+ def execute_sql_via_warehouse(
20
+ warehouse: WarehouseModel,
21
+ sql: str,
22
+ layer_name: str = "cache",
23
+ ) -> pd.DataFrame | str:
24
+ """
25
+ Execute SQL using a Databricks warehouse and return results as DataFrame.
26
+
27
+ This is a shared utility for cache implementations that need to re-execute
28
+ cached SQL queries.
29
+
30
+ Args:
31
+ warehouse: The warehouse configuration for SQL execution
32
+ sql: The SQL query to execute
33
+ layer_name: Name of the cache layer (for logging)
34
+
35
+ Returns:
36
+ DataFrame with results, or error message string
37
+ """
38
+ w: WorkspaceClient = warehouse.workspace_client
39
+ warehouse_id: str = str(warehouse.warehouse_id)
40
+
41
+ logger.debug(f"[{layer_name}] Executing cached SQL: {sql[:100]}...")
42
+
43
+ statement_response: StatementResponse = w.statement_execution.execute_statement(
44
+ statement=sql,
45
+ warehouse_id=warehouse_id,
46
+ wait_timeout="30s",
47
+ )
48
+
49
+ # Poll for completion if still running
50
+ while statement_response.status.state in [
51
+ StatementState.PENDING,
52
+ StatementState.RUNNING,
53
+ ]:
54
+ statement_response = w.statement_execution.get_statement(
55
+ statement_response.statement_id
56
+ )
57
+
58
+ if statement_response.status.state != StatementState.SUCCEEDED:
59
+ error_msg: str = f"SQL execution failed: {statement_response.status}"
60
+ logger.error(f"[{layer_name}] {error_msg}")
61
+ return error_msg
62
+
63
+ # Convert to DataFrame
64
+ if statement_response.result and statement_response.result.data_array:
65
+ columns: list[str] = []
66
+ if statement_response.manifest and statement_response.manifest.schema:
67
+ columns = [col.name for col in statement_response.manifest.schema.columns]
68
+
69
+ data: list[list[Any]] = statement_response.result.data_array
70
+ if columns:
71
+ return pd.DataFrame(data, columns=columns)
72
+ else:
73
+ return pd.DataFrame(data)
74
+
75
+ return pd.DataFrame()
dao_ai/genie/cache/lru.py CHANGED
@@ -96,9 +96,21 @@ class LRUCacheService(GenieServiceBase):
96
96
  return timedelta(seconds=ttl)
97
97
 
98
98
  @staticmethod
99
- def _normalize_key(question: str) -> str:
100
- """Normalize the question to create a consistent cache key."""
101
- return question.strip().lower()
99
+ def _normalize_key(question: str, conversation_id: str | None = None) -> str:
100
+ """
101
+ Normalize the question and conversation_id to create a consistent cache key.
102
+
103
+ Args:
104
+ question: The question text
105
+ conversation_id: Optional conversation ID to include in the key
106
+
107
+ Returns:
108
+ A normalized cache key combining question and conversation_id
109
+ """
110
+ normalized_question = question.strip().lower()
111
+ if conversation_id:
112
+ return f"{conversation_id}::{normalized_question}"
113
+ return normalized_question
102
114
 
103
115
  def _is_expired(self, entry: SQLCacheEntry) -> bool:
104
116
  """Check if a cache entry has exceeded its TTL. Returns False if TTL is disabled."""
@@ -192,8 +204,6 @@ class LRUCacheService(GenieServiceBase):
192
204
  columns = [
193
205
  col.name for col in statement_response.manifest.schema.columns
194
206
  ]
195
- elif hasattr(statement_response.result, "schema"):
196
- columns = [col.name for col in statement_response.result.schema.columns]
197
207
 
198
208
  data: list[list[Any]] = statement_response.result.data_array
199
209
  if columns:
@@ -205,17 +215,14 @@ class LRUCacheService(GenieServiceBase):
205
215
 
206
216
  def ask_question(
207
217
  self, question: str, conversation_id: str | None = None
208
- ) -> GenieResponse:
218
+ ) -> CacheResult:
209
219
  """
210
220
  Ask a question, using cached SQL query if available.
211
221
 
212
222
  On cache hit, re-executes the cached SQL to get fresh data.
213
- Implements GenieServiceBase for seamless chaining.
223
+ Returns CacheResult with cache metadata.
214
224
  """
215
- result: CacheResult = self.ask_question_with_cache_info(
216
- question, conversation_id
217
- )
218
- return result.response
225
+ return self.ask_question_with_cache_info(question, conversation_id)
219
226
 
220
227
  @mlflow.trace(name="genie_lru_cache_lookup")
221
228
  def ask_question_with_cache_info(
@@ -235,7 +242,7 @@ class LRUCacheService(GenieServiceBase):
235
242
  Returns:
236
243
  CacheResult with fresh response and cache metadata
237
244
  """
238
- key: str = self._normalize_key(question)
245
+ key: str = self._normalize_key(question, conversation_id)
239
246
 
240
247
  # Check cache
241
248
  with self._lock:
@@ -244,17 +251,20 @@ class LRUCacheService(GenieServiceBase):
244
251
  if cached is not None:
245
252
  logger.info(
246
253
  f"[{self.name}] Cache HIT: '{question[:50]}...' "
247
- f"(cache_size={self.size}/{self.capacity})"
254
+ f"(conversation_id={conversation_id}, cache_size={self.size}/{self.capacity})"
248
255
  )
249
256
 
250
257
  # Re-execute the cached SQL to get fresh data
251
258
  result: pd.DataFrame | str = self._execute_sql(cached.query)
252
259
 
260
+ # Use current conversation_id, not the cached one
253
261
  response: GenieResponse = GenieResponse(
254
262
  result=result,
255
263
  query=cached.query,
256
264
  description=cached.description,
257
- conversation_id=cached.conversation_id,
265
+ conversation_id=conversation_id
266
+ if conversation_id
267
+ else cached.conversation_id,
258
268
  )
259
269
 
260
270
  return CacheResult(response=response, cache_hit=True, served_by=self.name)
@@ -262,17 +272,30 @@ class LRUCacheService(GenieServiceBase):
262
272
  # Cache miss - delegate to wrapped service
263
273
  logger.info(
264
274
  f"[{self.name}] Cache MISS: '{question[:50]}...' "
265
- f"(cache_size={self.size}/{self.capacity}, delegating to {type(self.impl).__name__})"
275
+ f"(conversation_id={conversation_id}, cache_size={self.size}/{self.capacity}, delegating to {type(self.impl).__name__})"
266
276
  )
267
277
 
268
- response = self.impl.ask_question(question, conversation_id)
278
+ result: CacheResult = self.impl.ask_question(question, conversation_id)
269
279
  with self._lock:
270
- self._put(key, response)
271
- return CacheResult(response=response, cache_hit=False, served_by=None)
280
+ self._put(key, result.response)
281
+ return CacheResult(response=result.response, cache_hit=False, served_by=None)
272
282
 
273
- def invalidate(self, question: str) -> bool:
274
- """Remove a specific entry from the cache."""
275
- key: str = self._normalize_key(question)
283
+ @property
284
+ def space_id(self) -> str:
285
+ return self.impl.space_id
286
+
287
+ def invalidate(self, question: str, conversation_id: str | None = None) -> bool:
288
+ """
289
+ Remove a specific entry from the cache.
290
+
291
+ Args:
292
+ question: The question text
293
+ conversation_id: Optional conversation ID to match
294
+
295
+ Returns:
296
+ True if the entry was found and removed, False otherwise
297
+ """
298
+ key: str = self._normalize_key(question, conversation_id)
276
299
  with self._lock:
277
300
  if key in self._cache:
278
301
  del self._cache[key]