dao-ai 0.0.35__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (58)
  1. dao_ai/__init__.py +29 -0
  2. dao_ai/cli.py +195 -30
  3. dao_ai/config.py +797 -242
  4. dao_ai/genie/__init__.py +38 -0
  5. dao_ai/genie/cache/__init__.py +43 -0
  6. dao_ai/genie/cache/base.py +72 -0
  7. dao_ai/genie/cache/core.py +75 -0
  8. dao_ai/genie/cache/lru.py +329 -0
  9. dao_ai/genie/cache/semantic.py +919 -0
  10. dao_ai/genie/core.py +35 -0
  11. dao_ai/graph.py +27 -253
  12. dao_ai/hooks/__init__.py +9 -6
  13. dao_ai/hooks/core.py +22 -190
  14. dao_ai/memory/__init__.py +10 -0
  15. dao_ai/memory/core.py +23 -5
  16. dao_ai/memory/databricks.py +389 -0
  17. dao_ai/memory/postgres.py +2 -2
  18. dao_ai/messages.py +6 -4
  19. dao_ai/middleware/__init__.py +125 -0
  20. dao_ai/middleware/assertions.py +778 -0
  21. dao_ai/middleware/base.py +50 -0
  22. dao_ai/middleware/core.py +61 -0
  23. dao_ai/middleware/guardrails.py +415 -0
  24. dao_ai/middleware/human_in_the_loop.py +228 -0
  25. dao_ai/middleware/message_validation.py +554 -0
  26. dao_ai/middleware/summarization.py +192 -0
  27. dao_ai/models.py +1177 -108
  28. dao_ai/nodes.py +118 -161
  29. dao_ai/optimization.py +664 -0
  30. dao_ai/orchestration/__init__.py +52 -0
  31. dao_ai/orchestration/core.py +287 -0
  32. dao_ai/orchestration/supervisor.py +264 -0
  33. dao_ai/orchestration/swarm.py +226 -0
  34. dao_ai/prompts.py +126 -29
  35. dao_ai/providers/databricks.py +126 -381
  36. dao_ai/state.py +139 -21
  37. dao_ai/tools/__init__.py +11 -5
  38. dao_ai/tools/core.py +57 -4
  39. dao_ai/tools/email.py +280 -0
  40. dao_ai/tools/genie.py +108 -35
  41. dao_ai/tools/mcp.py +4 -3
  42. dao_ai/tools/memory.py +50 -0
  43. dao_ai/tools/python.py +4 -12
  44. dao_ai/tools/search.py +14 -0
  45. dao_ai/tools/slack.py +1 -1
  46. dao_ai/tools/unity_catalog.py +8 -6
  47. dao_ai/tools/vector_search.py +16 -9
  48. dao_ai/utils.py +72 -8
  49. dao_ai-0.1.0.dist-info/METADATA +1878 -0
  50. dao_ai-0.1.0.dist-info/RECORD +62 -0
  51. dao_ai/chat_models.py +0 -204
  52. dao_ai/guardrails.py +0 -112
  53. dao_ai/tools/human_in_the_loop.py +0 -100
  54. dao_ai-0.0.35.dist-info/METADATA +0 -1169
  55. dao_ai-0.0.35.dist-info/RECORD +0 -41
  56. {dao_ai-0.0.35.dist-info → dao_ai-0.1.0.dist-info}/WHEEL +0 -0
  57. {dao_ai-0.0.35.dist-info → dao_ai-0.1.0.dist-info}/entry_points.txt +0 -0
  58. {dao_ai-0.0.35.dist-info → dao_ai-0.1.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,38 @@
"""
Genie service implementations and caching layers.

This package provides core Genie functionality that can be used across
different contexts (tools, direct integration, etc.).

Main exports:
- GenieService: Core service implementation wrapping Databricks Genie SDK
- GenieServiceBase: Abstract base class for service implementations

Cache implementations are available in the cache subpackage:
- dao_ai.genie.cache.lru: LRU (Least Recently Used) cache
- dao_ai.genie.cache.semantic: Semantic similarity cache using pg_vector

Example usage:
    from dao_ai.genie import GenieService
    from dao_ai.genie.cache import LRUCacheService, SemanticCacheService
"""

from dao_ai.genie.cache import (
    CacheResult,
    GenieServiceBase,
    LRUCacheService,
    SemanticCacheService,
    SQLCacheEntry,
)
from dao_ai.genie.core import GenieService

# Explicit public API: keeps star-imports and doc tooling in sync with the
# names re-exported from the core module and the cache subpackage.
__all__ = [
    # Service classes
    "GenieService",
    "GenieServiceBase",
    # Cache types (from cache subpackage)
    "CacheResult",
    "LRUCacheService",
    "SemanticCacheService",
    "SQLCacheEntry",
]
@@ -0,0 +1,43 @@
"""
Genie cache implementations.

This package provides caching layers for Genie SQL queries that can be
chained together using the decorator pattern: each cache wraps another
GenieServiceBase and delegates to it on a miss.

Available cache implementations:
- LRUCacheService: In-memory LRU cache with O(1) exact match lookup
- SemanticCacheService: PostgreSQL pg_vector-based semantic similarity cache

Example usage:
    from dao_ai.genie.cache import LRUCacheService, SemanticCacheService

    # Chain caches: LRU (checked first) -> Semantic (checked second) -> Genie
    genie_service = SemanticCacheService(
        impl=GenieService(genie),
        parameters=semantic_params,
    )
    genie_service = LRUCacheService(
        impl=genie_service,
        parameters=lru_params,
    )
"""

from dao_ai.genie.cache.base import (
    CacheResult,
    GenieServiceBase,
    SQLCacheEntry,
)
from dao_ai.genie.cache.core import execute_sql_via_warehouse
from dao_ai.genie.cache.lru import LRUCacheService
from dao_ai.genie.cache.semantic import SemanticCacheService

# Re-exported public API of the cache subpackage.
__all__ = [
    # Base types
    "CacheResult",
    "GenieServiceBase",
    "SQLCacheEntry",
    "execute_sql_via_warehouse",
    # Cache implementations
    "LRUCacheService",
    "SemanticCacheService",
]
@@ -0,0 +1,72 @@
1
+ """
2
+ Base classes and types for Genie cache implementations.
3
+
4
+ This module provides the foundational types used across different cache
5
+ implementations (LRU, Semantic, etc.). It contains only abstract base classes
6
+ and data structures - no concrete implementations.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ from abc import ABC, abstractmethod
12
+ from dataclasses import dataclass
13
+ from datetime import datetime
14
+ from typing import TYPE_CHECKING
15
+
16
+ from databricks_ai_bridge.genie import GenieResponse
17
+
18
+ if TYPE_CHECKING:
19
+ from dao_ai.genie.cache.base import CacheResult
20
+
21
+
22
+ class GenieServiceBase(ABC):
23
+ """Abstract base class for Genie service implementations."""
24
+
25
+ @abstractmethod
26
+ def ask_question(
27
+ self, question: str, conversation_id: str | None = None
28
+ ) -> "CacheResult":
29
+ """
30
+ Ask a question to Genie and return the response with cache metadata.
31
+
32
+ All implementations return CacheResult to provide consistent cache information,
33
+ even when caching is disabled (cache_hit=False, served_by=None).
34
+ """
35
+ pass
36
+
37
+ @property
38
+ @abstractmethod
39
+ def space_id(self) -> str:
40
+ """The space ID for the Genie service."""
41
+ pass
42
+
43
@dataclass
class SQLCacheEntry:
    """
    A cache entry storing the SQL query metadata for re-execution.

    Instead of caching the full result, we cache the SQL query so that
    on cache hit we can re-execute it to get fresh data.
    """

    # SQL text generated by Genie for the original question.
    query: str
    # Natural-language description of the query, as returned by Genie.
    description: str
    # Genie conversation that produced the SQL (kept so a hit can fall back
    # to it when the caller supplies no conversation_id).
    conversation_id: str
    # Entry creation timestamp, compared against the configured TTL to decide
    # expiry. NOTE(review): produced with naive datetime.now() by callers —
    # assumes a single consistent local clock; confirm if TZ-awareness is needed.
    created_at: datetime
@dataclass
class CacheResult:
    """
    Result of a cache-aware query with metadata about cache behavior.

    Attributes:
        response: The GenieResponse (fresh data, possibly from cached SQL)
        cache_hit: Whether the SQL query came from cache
        served_by: Name of the layer that served the cached SQL (None if from origin)
    """

    # Fresh response; on a cache hit the SQL was cached but re-executed.
    response: GenieResponse
    cache_hit: bool
    # Defaults to None, meaning the answer came from the origin service.
    served_by: str | None = None
@@ -0,0 +1,75 @@
1
+ """
2
+ Core utilities for Genie cache implementations.
3
+
4
+ This module provides shared utility functions used by different cache
5
+ implementations (LRU, Semantic, etc.). These are concrete implementations
6
+ of common operations needed across cache types.
7
+ """
8
+
9
+ from typing import Any
10
+
11
+ import pandas as pd
12
+ from databricks.sdk import WorkspaceClient
13
+ from databricks.sdk.service.sql import StatementResponse, StatementState
14
+ from loguru import logger
15
+
16
+ from dao_ai.config import WarehouseModel
17
+
18
+
def execute_sql_via_warehouse(
    warehouse: WarehouseModel,
    sql: str,
    layer_name: str = "cache",
) -> pd.DataFrame | str:
    """
    Execute SQL using a Databricks warehouse and return results as DataFrame.

    This is a shared utility for cache implementations that need to re-execute
    cached SQL queries.

    Args:
        warehouse: The warehouse configuration for SQL execution
        sql: The SQL query to execute
        layer_name: Name of the cache layer (for logging)

    Returns:
        DataFrame with the query results (empty if the statement returned no
        rows), or an error message string if execution did not succeed.
    """
    # Local import keeps the module's top-level dependency block unchanged.
    import time

    w: WorkspaceClient = warehouse.workspace_client
    warehouse_id: str = str(warehouse.warehouse_id)

    logger.debug(f"[{layer_name}] Executing cached SQL: {sql[:100]}...")

    statement_response: StatementResponse = w.statement_execution.execute_statement(
        statement=sql,
        warehouse_id=warehouse_id,
        wait_timeout="30s",
    )

    # Poll until the statement leaves PENDING/RUNNING. Sleep between polls so
    # we don't busy-wait and hammer the statement-execution API.
    while statement_response.status.state in (
        StatementState.PENDING,
        StatementState.RUNNING,
    ):
        time.sleep(1.0)
        statement_response = w.statement_execution.get_statement(
            statement_response.statement_id
        )

    if statement_response.status.state != StatementState.SUCCEEDED:
        error_msg: str = f"SQL execution failed: {statement_response.status}"
        logger.error(f"[{layer_name}] {error_msg}")
        return error_msg

    # Convert the raw row data to a DataFrame, attaching column names when the
    # result manifest provides a schema (guard against a missing column list).
    if statement_response.result and statement_response.result.data_array:
        columns: list[str] = []
        if (
            statement_response.manifest
            and statement_response.manifest.schema
            and statement_response.manifest.schema.columns
        ):
            columns = [col.name for col in statement_response.manifest.schema.columns]

        data: list[list[Any]] = statement_response.result.data_array
        if columns:
            return pd.DataFrame(data, columns=columns)
        return pd.DataFrame(data)

    return pd.DataFrame()
@@ -0,0 +1,329 @@
1
+ """
2
+ LRU (Least Recently Used) cache implementation for Genie SQL queries.
3
+
4
+ This module provides an in-memory LRU cache that stores SQL queries generated
5
+ by Genie. On cache hit, the cached SQL is re-executed against the warehouse
6
+ to return fresh data while avoiding the Genie NL-to-SQL translation cost.
7
+ """
8
+
9
+ from collections import OrderedDict
10
+ from datetime import datetime, timedelta
11
+ from threading import Lock
12
+ from typing import Any
13
+
14
+ import mlflow
15
+ import pandas as pd
16
+ from databricks.sdk import WorkspaceClient
17
+ from databricks.sdk.service.sql import StatementResponse, StatementState
18
+ from databricks_ai_bridge.genie import GenieResponse
19
+ from loguru import logger
20
+
21
+ from dao_ai.config import GenieLRUCacheParametersModel, WarehouseModel
22
+ from dao_ai.genie.cache.base import (
23
+ CacheResult,
24
+ GenieServiceBase,
25
+ SQLCacheEntry,
26
+ )
27
+
class LRUCacheService(GenieServiceBase):
    """
    LRU caching decorator that caches SQL queries and re-executes them.

    This service caches the SQL query generated by Genie (not the result data).
    On cache hit, it re-executes the cached SQL using the provided warehouse
    to return fresh data while avoiding the Genie NL-to-SQL translation cost.

    Example:
        from dao_ai.config import GenieLRUCacheParametersModel, WarehouseModel
        from dao_ai.genie.cache import LRUCacheService

        cache_params = GenieLRUCacheParametersModel(
            warehouse=warehouse_model,
            capacity=100,
            time_to_live_seconds=86400  # 24 hours
        )
        genie = LRUCacheService(
            impl=GenieService(Genie(space_id="my-space")),
            parameters=cache_params
        )

    Thread-safe: a single Lock protects all mutations of the underlying
    OrderedDict.
    """

    impl: GenieServiceBase
    parameters: GenieLRUCacheParametersModel
    name: str
    _cache: OrderedDict[str, SQLCacheEntry]
    _lock: Lock

    def __init__(
        self,
        impl: GenieServiceBase,
        parameters: GenieLRUCacheParametersModel,
        name: str | None = None,
    ) -> None:
        """
        Initialize the SQL cache service.

        Args:
            impl: The underlying GenieServiceBase to delegate to on cache miss
            parameters: Cache configuration including warehouse, capacity, and TTL
            name: Name for this cache layer (for logging). Defaults to class name.
        """
        self.impl = impl
        self.parameters = parameters
        self.name = name if name is not None else self.__class__.__name__
        self._cache = OrderedDict()
        self._lock = Lock()

    @property
    def warehouse(self) -> WarehouseModel:
        """The warehouse used for executing cached SQL queries."""
        return self.parameters.warehouse

    @property
    def capacity(self) -> int:
        """Maximum number of SQL queries to cache."""
        return self.parameters.capacity

    @property
    def time_to_live(self) -> timedelta | None:
        """Duration after which cached queries expire. None means never expires."""
        ttl = self.parameters.time_to_live_seconds
        # Negative values are treated the same as None: TTL disabled.
        if ttl is None or ttl < 0:
            return None
        return timedelta(seconds=ttl)

    @staticmethod
    def _normalize_key(question: str, conversation_id: str | None = None) -> str:
        """
        Normalize the question and conversation_id to create a consistent cache key.

        Args:
            question: The question text
            conversation_id: Optional conversation ID to include in the key

        Returns:
            A normalized cache key combining question and conversation_id
        """
        normalized_question = question.strip().lower()
        if conversation_id:
            return f"{conversation_id}::{normalized_question}"
        return normalized_question

    def _is_expired(self, entry: SQLCacheEntry) -> bool:
        """Check if a cache entry has exceeded its TTL. Returns False if TTL is disabled."""
        if self.time_to_live is None:
            return False
        age: timedelta = datetime.now() - entry.created_at
        return age > self.time_to_live

    def _evict_oldest(self) -> None:
        """Remove the oldest (least recently used) entry. Caller must hold the lock."""
        if self._cache:
            oldest_key: str = next(iter(self._cache))
            del self._cache[oldest_key]
            logger.debug(f"[{self.name}] Evicted: {oldest_key[:50]}...")

    def _get(self, key: str) -> SQLCacheEntry | None:
        """Get from cache, returning None if not found or expired. Caller must hold the lock."""
        if key not in self._cache:
            return None

        entry: SQLCacheEntry = self._cache[key]

        # Lazily drop expired entries on access.
        if self._is_expired(entry):
            del self._cache[key]
            logger.debug(f"[{self.name}] Expired: {key[:50]}...")
            return None

        # Mark as most recently used.
        self._cache.move_to_end(key)
        return entry

    def _put(self, key: str, response: GenieResponse) -> None:
        """Store SQL query in cache, evicting if at capacity. Caller must hold the lock."""
        if not response.query:
            # Genie answered without generating SQL, so there is nothing to
            # re-execute on a future hit; caching it would make _execute_sql
            # fail on a None query.
            logger.debug(
                f"[{self.name}] Skipping cache store for key='{key[:50]}...' (no SQL query)"
            )
            return

        if self.capacity <= 0:
            # With a non-positive capacity the eviction loop below would spin
            # forever (evicting from an empty dict is a no-op); treat it as
            # "caching disabled".
            return

        if key in self._cache:
            del self._cache[key]

        while len(self._cache) >= self.capacity:
            self._evict_oldest()

        self._cache[key] = SQLCacheEntry(
            query=response.query,
            description=response.description,
            conversation_id=response.conversation_id,
            created_at=datetime.now(),
        )
        logger.info(
            f"[{self.name}] Stored cache entry: key='{key[:50]}...' "
            f"sql='{response.query[:50]}...' "
            f"(cache_size={len(self._cache)}/{self.capacity})"
        )

    @mlflow.trace(name="execute_cached_sql")
    def _execute_sql(self, sql: str) -> pd.DataFrame | str:
        """
        Execute SQL using the warehouse and return results as DataFrame.

        NOTE: mirrors dao_ai.genie.cache.core.execute_sql_via_warehouse.

        Args:
            sql: The SQL query to execute

        Returns:
            DataFrame with results, or error message string
        """
        # Local import keeps the module's top-level dependency block unchanged.
        import time

        w: WorkspaceClient = self.warehouse.workspace_client
        warehouse_id: str = str(self.warehouse.warehouse_id)

        logger.debug(f"[{self.name}] Executing cached SQL: {sql[:100]}...")

        statement_response: StatementResponse = w.statement_execution.execute_statement(
            statement=sql,
            warehouse_id=warehouse_id,
            wait_timeout="30s",
        )

        # Poll until the statement leaves PENDING/RUNNING. Sleep between polls
        # so we don't busy-wait and hammer the statement-execution API.
        while statement_response.status.state in (
            StatementState.PENDING,
            StatementState.RUNNING,
        ):
            time.sleep(1.0)
            statement_response = w.statement_execution.get_statement(
                statement_response.statement_id
            )

        if statement_response.status.state != StatementState.SUCCEEDED:
            error_msg: str = f"SQL execution failed: {statement_response.status}"
            logger.error(f"[{self.name}] {error_msg}")
            return error_msg

        # Convert the raw row data to a DataFrame, attaching column names when
        # the result manifest provides a schema.
        if statement_response.result and statement_response.result.data_array:
            columns: list[str] = []
            if (
                statement_response.manifest
                and statement_response.manifest.schema
                and statement_response.manifest.schema.columns
            ):
                columns = [
                    col.name for col in statement_response.manifest.schema.columns
                ]

            data: list[list[Any]] = statement_response.result.data_array
            if columns:
                return pd.DataFrame(data, columns=columns)
            return pd.DataFrame(data)

        return pd.DataFrame()

    def ask_question(
        self, question: str, conversation_id: str | None = None
    ) -> CacheResult:
        """
        Ask a question, using cached SQL query if available.

        On cache hit, re-executes the cached SQL to get fresh data.
        Returns CacheResult with cache metadata.
        """
        return self.ask_question_with_cache_info(question, conversation_id)

    @mlflow.trace(name="genie_lru_cache_lookup")
    def ask_question_with_cache_info(
        self,
        question: str,
        conversation_id: str | None = None,
    ) -> CacheResult:
        """
        Ask a question with detailed cache hit information.

        On cache hit, the cached SQL is re-executed to return fresh data.

        Args:
            question: The question to ask
            conversation_id: Optional conversation ID

        Returns:
            CacheResult with fresh response and cache metadata
        """
        key: str = self._normalize_key(question, conversation_id)

        # Check cache (only the lookup itself needs the lock).
        with self._lock:
            cached: SQLCacheEntry | None = self._get(key)

        if cached is not None:
            logger.info(
                f"[{self.name}] Cache HIT: '{question[:50]}...' "
                f"(conversation_id={conversation_id}, cache_size={self.size}/{self.capacity})"
            )

            # Re-execute the cached SQL to get fresh data
            sql_result: pd.DataFrame | str = self._execute_sql(cached.query)

            # Use current conversation_id, not the cached one
            response: GenieResponse = GenieResponse(
                result=sql_result,
                query=cached.query,
                description=cached.description,
                conversation_id=conversation_id
                if conversation_id
                else cached.conversation_id,
            )

            return CacheResult(response=response, cache_hit=True, served_by=self.name)

        # Cache miss - delegate to wrapped service
        logger.info(
            f"[{self.name}] Cache MISS: '{question[:50]}...' "
            f"(conversation_id={conversation_id}, cache_size={self.size}/{self.capacity}, delegating to {type(self.impl).__name__})"
        )

        delegated: CacheResult = self.impl.ask_question(question, conversation_id)
        with self._lock:
            self._put(key, delegated.response)
        # Re-wrap so this layer reports a miss regardless of inner layers.
        return CacheResult(response=delegated.response, cache_hit=False, served_by=None)

    @property
    def space_id(self) -> str:
        """Space ID of the wrapped service (caches have no space of their own)."""
        return self.impl.space_id

    def invalidate(self, question: str, conversation_id: str | None = None) -> bool:
        """
        Remove a specific entry from the cache.

        Args:
            question: The question text
            conversation_id: Optional conversation ID to match

        Returns:
            True if the entry was found and removed, False otherwise
        """
        key: str = self._normalize_key(question, conversation_id)
        with self._lock:
            if key in self._cache:
                del self._cache[key]
                return True
            return False

    def clear(self) -> int:
        """Clear all entries from the cache and return how many were removed."""
        with self._lock:
            count: int = len(self._cache)
            self._cache.clear()
            return count

    @property
    def size(self) -> int:
        """Current number of entries in the cache."""
        with self._lock:
            return len(self._cache)

    def stats(self) -> dict[str, int | float | None]:
        """Return cache statistics (size, capacity, TTL, expired/valid counts)."""
        with self._lock:
            expired: int = sum(1 for e in self._cache.values() if self._is_expired(e))
            ttl = self.time_to_live
            return {
                "size": len(self._cache),
                "capacity": self.capacity,
                "ttl_seconds": ttl.total_seconds() if ttl else None,
                "expired_entries": expired,
                "valid_entries": len(self._cache) - expired,
            }