dao-ai 0.0.28__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63):
  1. dao_ai/__init__.py +29 -0
  2. dao_ai/agent_as_code.py +2 -5
  3. dao_ai/cli.py +245 -40
  4. dao_ai/config.py +1491 -370
  5. dao_ai/genie/__init__.py +38 -0
  6. dao_ai/genie/cache/__init__.py +43 -0
  7. dao_ai/genie/cache/base.py +72 -0
  8. dao_ai/genie/cache/core.py +79 -0
  9. dao_ai/genie/cache/lru.py +347 -0
  10. dao_ai/genie/cache/semantic.py +970 -0
  11. dao_ai/genie/core.py +35 -0
  12. dao_ai/graph.py +27 -253
  13. dao_ai/hooks/__init__.py +9 -6
  14. dao_ai/hooks/core.py +27 -195
  15. dao_ai/logging.py +56 -0
  16. dao_ai/memory/__init__.py +10 -0
  17. dao_ai/memory/core.py +65 -30
  18. dao_ai/memory/databricks.py +402 -0
  19. dao_ai/memory/postgres.py +79 -38
  20. dao_ai/messages.py +6 -4
  21. dao_ai/middleware/__init__.py +125 -0
  22. dao_ai/middleware/assertions.py +806 -0
  23. dao_ai/middleware/base.py +50 -0
  24. dao_ai/middleware/core.py +67 -0
  25. dao_ai/middleware/guardrails.py +420 -0
  26. dao_ai/middleware/human_in_the_loop.py +232 -0
  27. dao_ai/middleware/message_validation.py +586 -0
  28. dao_ai/middleware/summarization.py +197 -0
  29. dao_ai/models.py +1306 -114
  30. dao_ai/nodes.py +245 -159
  31. dao_ai/optimization.py +674 -0
  32. dao_ai/orchestration/__init__.py +52 -0
  33. dao_ai/orchestration/core.py +294 -0
  34. dao_ai/orchestration/supervisor.py +278 -0
  35. dao_ai/orchestration/swarm.py +271 -0
  36. dao_ai/prompts.py +128 -31
  37. dao_ai/providers/databricks.py +573 -601
  38. dao_ai/state.py +157 -21
  39. dao_ai/tools/__init__.py +13 -5
  40. dao_ai/tools/agent.py +1 -3
  41. dao_ai/tools/core.py +64 -11
  42. dao_ai/tools/email.py +232 -0
  43. dao_ai/tools/genie.py +144 -294
  44. dao_ai/tools/mcp.py +223 -155
  45. dao_ai/tools/memory.py +50 -0
  46. dao_ai/tools/python.py +9 -14
  47. dao_ai/tools/search.py +14 -0
  48. dao_ai/tools/slack.py +22 -10
  49. dao_ai/tools/sql.py +202 -0
  50. dao_ai/tools/time.py +30 -7
  51. dao_ai/tools/unity_catalog.py +165 -88
  52. dao_ai/tools/vector_search.py +331 -221
  53. dao_ai/utils.py +166 -20
  54. dao_ai-0.1.2.dist-info/METADATA +455 -0
  55. dao_ai-0.1.2.dist-info/RECORD +64 -0
  56. dao_ai/chat_models.py +0 -204
  57. dao_ai/guardrails.py +0 -112
  58. dao_ai/tools/human_in_the_loop.py +0 -100
  59. dao_ai-0.0.28.dist-info/METADATA +0 -1168
  60. dao_ai-0.0.28.dist-info/RECORD +0 -41
  61. {dao_ai-0.0.28.dist-info → dao_ai-0.1.2.dist-info}/WHEEL +0 -0
  62. {dao_ai-0.0.28.dist-info → dao_ai-0.1.2.dist-info}/entry_points.txt +0 -0
  63. {dao_ai-0.0.28.dist-info → dao_ai-0.1.2.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,38 @@
1
+ """
2
+ Genie service implementations and caching layers.
3
+
4
+ This package provides core Genie functionality that can be used across
5
+ different contexts (tools, direct integration, etc.).
6
+
7
+ Main exports:
8
+ - GenieService: Core service implementation wrapping Databricks Genie SDK
9
+ - GenieServiceBase: Abstract base class for service implementations
10
+
11
+ Cache implementations are available in the cache subpackage:
12
+ - dao_ai.genie.cache.lru: LRU (Least Recently Used) cache
13
+ - dao_ai.genie.cache.semantic: Semantic similarity cache using pg_vector
14
+
15
+ Example usage:
16
+ from dao_ai.genie import GenieService
17
+ from dao_ai.genie.cache import LRUCacheService, SemanticCacheService
18
+ """
19
+
20
+ from dao_ai.genie.cache import (
21
+ CacheResult,
22
+ GenieServiceBase,
23
+ LRUCacheService,
24
+ SemanticCacheService,
25
+ SQLCacheEntry,
26
+ )
27
+ from dao_ai.genie.core import GenieService
28
+
29
+ __all__ = [
30
+ # Service classes
31
+ "GenieService",
32
+ "GenieServiceBase",
33
+ # Cache types (from cache subpackage)
34
+ "CacheResult",
35
+ "LRUCacheService",
36
+ "SemanticCacheService",
37
+ "SQLCacheEntry",
38
+ ]
@@ -0,0 +1,43 @@
1
+ """
2
+ Genie cache implementations.
3
+
4
+ This package provides caching layers for Genie SQL queries that can be
5
+ chained together using the decorator pattern.
6
+
7
+ Available cache implementations:
8
+ - LRUCacheService: In-memory LRU cache with O(1) exact match lookup
9
+ - SemanticCacheService: PostgreSQL pg_vector-based semantic similarity cache
10
+
11
+ Example usage:
12
+ from dao_ai.genie.cache import LRUCacheService, SemanticCacheService
13
+
14
+ # Chain caches: LRU (checked first) -> Semantic (checked second) -> Genie
15
+ genie_service = SemanticCacheService(
16
+ impl=GenieService(genie),
17
+ parameters=semantic_params,
18
+ )
19
+ genie_service = LRUCacheService(
20
+ impl=genie_service,
21
+ parameters=lru_params,
22
+ )
23
+ """
24
+
25
+ from dao_ai.genie.cache.base import (
26
+ CacheResult,
27
+ GenieServiceBase,
28
+ SQLCacheEntry,
29
+ )
30
+ from dao_ai.genie.cache.core import execute_sql_via_warehouse
31
+ from dao_ai.genie.cache.lru import LRUCacheService
32
+ from dao_ai.genie.cache.semantic import SemanticCacheService
33
+
34
+ __all__ = [
35
+ # Base types
36
+ "CacheResult",
37
+ "GenieServiceBase",
38
+ "SQLCacheEntry",
39
+ "execute_sql_via_warehouse",
40
+ # Cache implementations
41
+ "LRUCacheService",
42
+ "SemanticCacheService",
43
+ ]
@@ -0,0 +1,72 @@
1
+ """
2
+ Base classes and types for Genie cache implementations.
3
+
4
+ This module provides the foundational types used across different cache
5
+ implementations (LRU, Semantic, etc.). It contains only abstract base classes
6
+ and data structures - no concrete implementations.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ from abc import ABC, abstractmethod
12
+ from dataclasses import dataclass
13
+ from datetime import datetime
14
+ from typing import TYPE_CHECKING
15
+
16
+ from databricks_ai_bridge.genie import GenieResponse
17
+
18
+ if TYPE_CHECKING:
19
+ from dao_ai.genie.cache.base import CacheResult
20
+
21
+
22
+ class GenieServiceBase(ABC):
23
+ """Abstract base class for Genie service implementations."""
24
+
25
+ @abstractmethod
26
+ def ask_question(
27
+ self, question: str, conversation_id: str | None = None
28
+ ) -> "CacheResult":
29
+ """
30
+ Ask a question to Genie and return the response with cache metadata.
31
+
32
+ All implementations return CacheResult to provide consistent cache information,
33
+ even when caching is disabled (cache_hit=False, served_by=None).
34
+ """
35
+ pass
36
+
37
+ @property
38
+ @abstractmethod
39
+ def space_id(self) -> str:
40
+ """The space ID for the Genie service."""
41
+ pass
42
+
43
+
44
@dataclass
class SQLCacheEntry:
    """
    Cached metadata for a single Genie-generated SQL query.

    Only the SQL text (plus bookkeeping fields) is stored, not result rows:
    a cache hit re-executes the query so callers always see fresh data.
    """

    # Genie-generated SQL statement to re-execute on a hit.
    query: str
    # Human-readable description Genie attached to the query.
    description: str
    # Conversation that originally produced this SQL.
    conversation_id: str
    # Insertion timestamp, consulted for TTL expiry checks.
    created_at: datetime
57
+
58
+
59
@dataclass
class CacheResult:
    """
    Outcome of a cache-aware Genie query, bundled with cache metadata.

    Attributes:
        response: The GenieResponse (always fresh data, even when the SQL
            itself came from cache)
        cache_hit: True when the SQL query was served from a cache layer
        served_by: Name of the serving cache layer; None when the query
            went to the origin service
    """

    response: GenieResponse
    cache_hit: bool
    served_by: str | None = None
@@ -0,0 +1,79 @@
1
+ """
2
+ Core utilities for Genie cache implementations.
3
+
4
+ This module provides shared utility functions used by different cache
5
+ implementations (LRU, Semantic, etc.). These are concrete implementations
6
+ of common operations needed across cache types.
7
+ """
8
+
9
+ from typing import Any
10
+
11
+ import pandas as pd
12
+ from databricks.sdk import WorkspaceClient
13
+ from databricks.sdk.service.sql import StatementResponse, StatementState
14
+ from loguru import logger
15
+
16
+ from dao_ai.config import WarehouseModel
17
+
18
+
19
def execute_sql_via_warehouse(
    warehouse: WarehouseModel,
    sql: str,
    layer_name: str = "cache",
) -> pd.DataFrame | str:
    """
    Execute SQL using a Databricks warehouse and return results as DataFrame.

    This is a shared utility for cache implementations that need to re-execute
    cached SQL queries.

    Args:
        warehouse: The warehouse configuration for SQL execution
        sql: The SQL query to execute
        layer_name: Name of the cache layer (for logging)

    Returns:
        DataFrame with results (empty DataFrame when the statement produced
        no rows), or an error message string when execution did not succeed
    """
    import time  # local import: only needed for the polling backoff below

    w: WorkspaceClient = warehouse.workspace_client
    warehouse_id: str = str(warehouse.warehouse_id)

    logger.trace("Executing cached SQL", layer=layer_name, sql_prefix=sql[:100])

    # execute_statement blocks for up to 30s; if the statement is queued or
    # long-running we fall through to the polling loop below.
    statement_response: StatementResponse = w.statement_execution.execute_statement(
        statement=sql,
        warehouse_id=warehouse_id,
        wait_timeout="30s",
    )

    # Poll for completion. Sleep between polls so we do not busy-wait and
    # hammer the statement-execution API (the previous loop polled with no
    # delay). NOTE(review): polling has no overall deadline, matching prior
    # behavior — consider adding one.
    while statement_response.status.state in (
        StatementState.PENDING,
        StatementState.RUNNING,
    ):
        time.sleep(1.0)
        statement_response = w.statement_execution.get_statement(
            statement_response.statement_id
        )

    if statement_response.status.state != StatementState.SUCCEEDED:
        error_msg: str = f"SQL execution failed: {statement_response.status}"
        logger.error(
            "SQL execution failed",
            layer=layer_name,
            status=str(statement_response.status),
        )
        return error_msg

    # Convert the raw row data to a DataFrame, attaching column names when the
    # result manifest provides a schema.
    if statement_response.result and statement_response.result.data_array:
        columns: list[str] = []
        if statement_response.manifest and statement_response.manifest.schema:
            columns = [col.name for col in statement_response.manifest.schema.columns]

        data: list[list[Any]] = statement_response.result.data_array
        if columns:
            return pd.DataFrame(data, columns=columns)
        return pd.DataFrame(data)

    return pd.DataFrame()
@@ -0,0 +1,347 @@
1
+ """
2
+ LRU (Least Recently Used) cache implementation for Genie SQL queries.
3
+
4
+ This module provides an in-memory LRU cache that stores SQL queries generated
5
+ by Genie. On cache hit, the cached SQL is re-executed against the warehouse
6
+ to return fresh data while avoiding the Genie NL-to-SQL translation cost.
7
+ """
8
+
9
+ from collections import OrderedDict
10
+ from datetime import datetime, timedelta
11
+ from threading import Lock
12
+ from typing import Any
13
+
14
+ import mlflow
15
+ import pandas as pd
16
+ from databricks.sdk import WorkspaceClient
17
+ from databricks.sdk.service.sql import StatementResponse, StatementState
18
+ from databricks_ai_bridge.genie import GenieResponse
19
+ from loguru import logger
20
+
21
+ from dao_ai.config import GenieLRUCacheParametersModel, WarehouseModel
22
+ from dao_ai.genie.cache.base import (
23
+ CacheResult,
24
+ GenieServiceBase,
25
+ SQLCacheEntry,
26
+ )
27
+
28
+
29
class LRUCacheService(GenieServiceBase):
    """
    LRU caching decorator that caches SQL queries and re-executes them.

    This service caches the SQL query generated by Genie (not the result data).
    On cache hit, it re-executes the cached SQL using the provided warehouse
    to return fresh data while avoiding the Genie NL-to-SQL translation cost.

    Example:
        from dao_ai.config import GenieLRUCacheParametersModel, WarehouseModel
        from dao_ai.genie.cache import LRUCacheService

        cache_params = GenieLRUCacheParametersModel(
            warehouse=warehouse_model,
            capacity=100,
            time_to_live_seconds=86400  # 24 hours
        )
        genie = LRUCacheService(
            impl=GenieService(Genie(space_id="my-space")),
            parameters=cache_params
        )

    Thread-safe: a single non-reentrant Lock guards all cache mutations.
    Helpers that acquire the lock themselves (e.g. the ``size`` property)
    must never be called while the lock is already held, or the calling
    thread deadlocks.
    """

    impl: GenieServiceBase
    parameters: GenieLRUCacheParametersModel
    name: str
    _cache: OrderedDict[str, SQLCacheEntry]
    _lock: Lock

    def __init__(
        self,
        impl: GenieServiceBase,
        parameters: GenieLRUCacheParametersModel,
        name: str | None = None,
    ) -> None:
        """
        Initialize the SQL cache service.

        Args:
            impl: The underlying GenieServiceBase to delegate to on cache miss
            parameters: Cache configuration including warehouse, capacity, and TTL
            name: Name for this cache layer (for logging). Defaults to class name.
        """
        self.impl = impl
        self.parameters = parameters
        self.name = name if name is not None else self.__class__.__name__
        self._cache = OrderedDict()
        self._lock = Lock()

    @property
    def warehouse(self) -> WarehouseModel:
        """The warehouse used for executing cached SQL queries."""
        return self.parameters.warehouse

    @property
    def capacity(self) -> int:
        """Maximum number of SQL queries to cache."""
        return self.parameters.capacity

    @property
    def time_to_live(self) -> timedelta | None:
        """Duration after which cached queries expire. None means never expires."""
        ttl = self.parameters.time_to_live_seconds
        # Negative TTLs disable expiry (same as None). A TTL of exactly 0
        # expires entries immediately.
        if ttl is None or ttl < 0:
            return None
        return timedelta(seconds=ttl)

    @staticmethod
    def _normalize_key(question: str, conversation_id: str | None = None) -> str:
        """
        Normalize the question and conversation_id to create a consistent cache key.

        Args:
            question: The question text
            conversation_id: Optional conversation ID to include in the key

        Returns:
            A normalized cache key combining question and conversation_id
        """
        normalized_question = question.strip().lower()
        if conversation_id:
            return f"{conversation_id}::{normalized_question}"
        return normalized_question

    def _is_expired(self, entry: SQLCacheEntry) -> bool:
        """Check if a cache entry has exceeded its TTL. Returns False if TTL is disabled."""
        # Hoist the property access: time_to_live rebuilds a timedelta each call.
        ttl: timedelta | None = self.time_to_live
        if ttl is None:
            return False
        # NOTE(review): naive local time; consistent with _put, which stamps
        # entries with datetime.now().
        return datetime.now() - entry.created_at > ttl

    def _evict_oldest(self) -> None:
        """Remove the oldest (least recently used) entry. Caller must hold the lock."""
        if self._cache:
            oldest_key: str = next(iter(self._cache))
            del self._cache[oldest_key]
            logger.trace(
                "Evicted cache entry", layer=self.name, key_prefix=oldest_key[:50]
            )

    def _get(self, key: str) -> SQLCacheEntry | None:
        """Get from cache, returning None if not found or expired. Caller must hold the lock."""
        if key not in self._cache:
            return None

        entry: SQLCacheEntry = self._cache[key]

        # Expired entries are removed eagerly on lookup.
        if self._is_expired(entry):
            del self._cache[key]
            logger.trace("Expired cache entry", layer=self.name, key_prefix=key[:50])
            return None

        # Mark as most recently used.
        self._cache.move_to_end(key)
        return entry

    def _put(self, key: str, response: GenieResponse) -> None:
        """Store SQL query in cache, evicting if at capacity. Caller must hold the lock."""
        if key in self._cache:
            del self._cache[key]

        # Guard on a non-empty cache so a capacity <= 0 cannot loop forever
        # (_evict_oldest is a no-op on an empty cache, which previously made
        # this loop spin without terminating).
        while self._cache and len(self._cache) >= self.capacity:
            self._evict_oldest()

        self._cache[key] = SQLCacheEntry(
            query=response.query,
            description=response.description,
            conversation_id=response.conversation_id,
            created_at=datetime.now(),
        )
        logger.info(
            "Stored cache entry",
            layer=self.name,
            key_prefix=key[:50],
            sql_prefix=response.query[:50] if response.query else None,
            cache_size=len(self._cache),
            capacity=self.capacity,
        )

    @mlflow.trace(name="execute_cached_sql")
    def _execute_sql(self, sql: str) -> pd.DataFrame | str:
        """
        Execute SQL using the warehouse and return results as DataFrame.

        NOTE(review): this duplicates
        dao_ai.genie.cache.core.execute_sql_via_warehouse almost verbatim;
        consider delegating to that shared helper.

        Args:
            sql: The SQL query to execute

        Returns:
            DataFrame with results, or error message string
        """
        import time  # local import: only needed for the polling backoff below

        w: WorkspaceClient = self.warehouse.workspace_client
        warehouse_id: str = str(self.warehouse.warehouse_id)

        logger.trace("Executing cached SQL", layer=self.name, sql_prefix=sql[:100])

        statement_response: StatementResponse = w.statement_execution.execute_statement(
            statement=sql,
            warehouse_id=warehouse_id,
            wait_timeout="30s",
        )

        # Poll for completion; sleep between polls so we do not busy-wait
        # against the statement-execution API (previous loop had no delay).
        while statement_response.status.state in [
            StatementState.PENDING,
            StatementState.RUNNING,
        ]:
            time.sleep(1.0)
            statement_response = w.statement_execution.get_statement(
                statement_response.statement_id
            )

        if statement_response.status.state != StatementState.SUCCEEDED:
            error_msg: str = f"SQL execution failed: {statement_response.status}"
            logger.error(
                "SQL execution failed",
                layer=self.name,
                status=str(statement_response.status),
            )
            return error_msg

        # Convert to DataFrame, attaching column names when the manifest
        # provides a schema.
        if statement_response.result and statement_response.result.data_array:
            columns: list[str] = []
            if statement_response.manifest and statement_response.manifest.schema:
                columns = [
                    col.name for col in statement_response.manifest.schema.columns
                ]

            data: list[list[Any]] = statement_response.result.data_array
            if columns:
                return pd.DataFrame(data, columns=columns)
            return pd.DataFrame(data)

        return pd.DataFrame()

    def ask_question(
        self, question: str, conversation_id: str | None = None
    ) -> CacheResult:
        """
        Ask a question, using cached SQL query if available.

        On cache hit, re-executes the cached SQL to get fresh data.
        Returns CacheResult with cache metadata.
        """
        return self.ask_question_with_cache_info(question, conversation_id)

    @mlflow.trace(name="genie_lru_cache_lookup")
    def ask_question_with_cache_info(
        self,
        question: str,
        conversation_id: str | None = None,
    ) -> CacheResult:
        """
        Ask a question with detailed cache hit information.

        On cache hit, the cached SQL is re-executed to return fresh data.

        Args:
            question: The question to ask
            conversation_id: Optional conversation ID

        Returns:
            CacheResult with fresh response and cache metadata
        """
        key: str = self._normalize_key(question, conversation_id)

        # Look up under the lock, then release it before any logging or SQL
        # work: the `size` property used below re-acquires the non-reentrant
        # lock and would deadlock if called while it is held.
        with self._lock:
            cached: SQLCacheEntry | None = self._get(key)

        if cached is not None:
            logger.info(
                "Cache HIT",
                layer=self.name,
                question_prefix=question[:50],
                conversation_id=conversation_id,
                cache_size=self.size,
                capacity=self.capacity,
            )

            # Re-execute the cached SQL to get fresh data
            fresh_result: pd.DataFrame | str = self._execute_sql(cached.query)

            # Use current conversation_id, not the cached one
            response: GenieResponse = GenieResponse(
                result=fresh_result,
                query=cached.query,
                description=cached.description,
                conversation_id=conversation_id
                if conversation_id
                else cached.conversation_id,
            )

            return CacheResult(response=response, cache_hit=True, served_by=self.name)

        # Cache miss - delegate to wrapped service
        logger.info(
            "Cache MISS",
            layer=self.name,
            question_prefix=question[:50],
            conversation_id=conversation_id,
            cache_size=self.size,
            capacity=self.capacity,
            delegating_to=type(self.impl).__name__,
        )

        delegated: CacheResult = self.impl.ask_question(question, conversation_id)
        with self._lock:
            self._put(key, delegated.response)
        # Re-wrap the response: this layer reports a miss even when a deeper
        # layer hit. NOTE(review): deeper-layer cache_hit/served_by metadata
        # is discarded here — confirm this masking is intended.
        return CacheResult(response=delegated.response, cache_hit=False, served_by=None)

    @property
    def space_id(self) -> str:
        """The space ID of the wrapped Genie service."""
        return self.impl.space_id

    def invalidate(self, question: str, conversation_id: str | None = None) -> bool:
        """
        Remove a specific entry from the cache.

        Args:
            question: The question text
            conversation_id: Optional conversation ID to match

        Returns:
            True if the entry was found and removed, False otherwise
        """
        key: str = self._normalize_key(question, conversation_id)
        with self._lock:
            if key in self._cache:
                del self._cache[key]
                return True
            return False

    def clear(self) -> int:
        """Clear all entries from the cache. Returns the number removed."""
        with self._lock:
            count: int = len(self._cache)
            self._cache.clear()
            return count

    @property
    def size(self) -> int:
        """Current number of entries in the cache. Acquires the lock."""
        with self._lock:
            return len(self._cache)

    def stats(self) -> dict[str, int | float | None]:
        """Return cache statistics (size, capacity, TTL, expired/valid counts)."""
        with self._lock:
            expired: int = sum(1 for e in self._cache.values() if self._is_expired(e))
            ttl = self.time_to_live
            return {
                "size": len(self._cache),
                "capacity": self.capacity,
                "ttl_seconds": ttl.total_seconds() if ttl else None,
                "expired_entries": expired,
                "valid_entries": len(self._cache) - expired,
            }