dao-ai 0.0.36__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dao_ai/__init__.py +29 -0
- dao_ai/cli.py +195 -30
- dao_ai/config.py +770 -244
- dao_ai/genie/__init__.py +1 -22
- dao_ai/genie/cache/__init__.py +1 -2
- dao_ai/genie/cache/base.py +20 -70
- dao_ai/genie/cache/core.py +75 -0
- dao_ai/genie/cache/lru.py +44 -21
- dao_ai/genie/cache/semantic.py +390 -109
- dao_ai/genie/core.py +35 -0
- dao_ai/graph.py +27 -253
- dao_ai/hooks/__init__.py +9 -6
- dao_ai/hooks/core.py +22 -190
- dao_ai/memory/__init__.py +10 -0
- dao_ai/memory/core.py +23 -5
- dao_ai/memory/databricks.py +389 -0
- dao_ai/memory/postgres.py +2 -2
- dao_ai/messages.py +6 -4
- dao_ai/middleware/__init__.py +125 -0
- dao_ai/middleware/assertions.py +778 -0
- dao_ai/middleware/base.py +50 -0
- dao_ai/middleware/core.py +61 -0
- dao_ai/middleware/guardrails.py +415 -0
- dao_ai/middleware/human_in_the_loop.py +228 -0
- dao_ai/middleware/message_validation.py +554 -0
- dao_ai/middleware/summarization.py +192 -0
- dao_ai/models.py +1177 -108
- dao_ai/nodes.py +118 -161
- dao_ai/optimization.py +664 -0
- dao_ai/orchestration/__init__.py +52 -0
- dao_ai/orchestration/core.py +287 -0
- dao_ai/orchestration/supervisor.py +264 -0
- dao_ai/orchestration/swarm.py +226 -0
- dao_ai/prompts.py +126 -29
- dao_ai/providers/databricks.py +126 -381
- dao_ai/state.py +139 -21
- dao_ai/tools/__init__.py +8 -5
- dao_ai/tools/core.py +57 -4
- dao_ai/tools/email.py +280 -0
- dao_ai/tools/genie.py +47 -24
- dao_ai/tools/mcp.py +4 -3
- dao_ai/tools/memory.py +50 -0
- dao_ai/tools/python.py +4 -12
- dao_ai/tools/search.py +14 -0
- dao_ai/tools/slack.py +1 -1
- dao_ai/tools/unity_catalog.py +8 -6
- dao_ai/tools/vector_search.py +16 -9
- dao_ai/utils.py +72 -8
- dao_ai-0.1.0.dist-info/METADATA +1878 -0
- dao_ai-0.1.0.dist-info/RECORD +62 -0
- dao_ai/chat_models.py +0 -204
- dao_ai/guardrails.py +0 -112
- dao_ai/tools/genie/__init__.py +0 -236
- dao_ai/tools/human_in_the_loop.py +0 -100
- dao_ai-0.0.36.dist-info/METADATA +0 -951
- dao_ai-0.0.36.dist-info/RECORD +0 -47
- {dao_ai-0.0.36.dist-info → dao_ai-0.1.0.dist-info}/WHEEL +0 -0
- {dao_ai-0.0.36.dist-info → dao_ai-0.1.0.dist-info}/entry_points.txt +0 -0
- {dao_ai-0.0.36.dist-info → dao_ai-0.1.0.dist-info}/licenses/LICENSE +0 -0
dao_ai/genie/__init__.py
CHANGED
@@ -17,9 +17,6 @@ Example usage:
     from dao_ai.genie.cache import LRUCacheService, SemanticCacheService
 """
 
-import mlflow
-from databricks_ai_bridge.genie import Genie, GenieResponse
-
 from dao_ai.genie.cache import (
     CacheResult,
     GenieServiceBase,
@@ -27,25 +24,7 @@ from dao_ai.genie.cache import (
     SemanticCacheService,
     SQLCacheEntry,
 )
-
-
-class GenieService(GenieServiceBase):
-    """Concrete implementation of GenieServiceBase using the Genie SDK."""
-
-    genie: Genie
-
-    def __init__(self, genie: Genie) -> None:
-        self.genie = genie
-
-    @mlflow.trace(name="genie_ask_question")
-    def ask_question(
-        self, question: str, conversation_id: str | None = None
-    ) -> GenieResponse:
-        response: GenieResponse = self.genie.ask_question(
-            question, conversation_id=conversation_id
-        )
-        return response
-
+from dao_ai.genie.core import GenieService
 
 __all__ = [
     # Service classes
dao_ai/genie/cache/__init__.py
CHANGED
@@ -15,7 +15,6 @@ Example usage:
     genie_service = SemanticCacheService(
         impl=GenieService(genie),
         parameters=semantic_params,
-        genie_space_id=space_id,
     )
     genie_service = LRUCacheService(
         impl=genie_service,
@@ -27,8 +26,8 @@ from dao_ai.genie.cache.base import (
     CacheResult,
     GenieServiceBase,
     SQLCacheEntry,
-    execute_sql_via_warehouse,
 )
+from dao_ai.genie.cache.core import execute_sql_via_warehouse
 from dao_ai.genie.cache.lru import LRUCacheService
 from dao_ai.genie.cache.semantic import SemanticCacheService
 
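Two things to note in this file: the docstring example no longer passes genie_space_id to SemanticCacheService (the space ID is now obtained from the wrapped service; see the new abstract space_id property in base.py below), and execute_sql_via_warehouse now lives in cache/core.py while remaining re-exported here. A quick sketch of the unchanged public import path (assumes the package is installed):

    from dao_ai.genie.cache import execute_sql_via_warehouse  # re-exported
    from dao_ai.genie.cache.core import execute_sql_via_warehouse as core_fn

    # The re-export should make both names point at the same function object.
    assert execute_sql_via_warehouse is core_fn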
dao_ai/genie/cache/base.py
CHANGED
@@ -2,21 +2,21 @@
 Base classes and types for Genie cache implementations.
 
 This module provides the foundational types used across different cache
-implementations (LRU, Semantic, etc.).
+implementations (LRU, Semantic, etc.). It contains only abstract base classes
+and data structures - no concrete implementations.
 """
 
+from __future__ import annotations
+
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from datetime import datetime
-from typing import Any
+from typing import TYPE_CHECKING
 
-import pandas as pd
-from databricks.sdk import WorkspaceClient
-from databricks.sdk.service.sql import StatementResponse, StatementState
 from databricks_ai_bridge.genie import GenieResponse
-from loguru import logger
 
-
+if TYPE_CHECKING:
+    from dao_ai.genie.cache.base import CacheResult
 
 
 class GenieServiceBase(ABC):
@@ -25,8 +25,19 @@ class GenieServiceBase(ABC):
     @abstractmethod
     def ask_question(
         self, question: str, conversation_id: str | None = None
-    ) -> GenieResponse:
-        """
+    ) -> "CacheResult":
+        """
+        Ask a question to Genie and return the response with cache metadata.
+
+        All implementations return CacheResult to provide consistent cache information,
+        even when caching is disabled (cache_hit=False, served_by=None).
+        """
+        pass
+
+    @property
+    @abstractmethod
+    def space_id(self) -> str:
+        """The space ID for the Genie service."""
         pass
 
 
@@ -59,64 +70,3 @@ class CacheResult:
     response: GenieResponse
     cache_hit: bool
     served_by: str | None = None
-
-
-def execute_sql_via_warehouse(
-    warehouse: WarehouseModel,
-    sql: str,
-    layer_name: str = "cache",
-) -> pd.DataFrame | str:
-    """
-    Execute SQL using a Databricks warehouse and return results as DataFrame.
-
-    This is a shared utility for cache implementations that need to re-execute
-    cached SQL queries.
-
-    Args:
-        warehouse: The warehouse configuration for SQL execution
-        sql: The SQL query to execute
-        layer_name: Name of the cache layer (for logging)
-
-    Returns:
-        DataFrame with results, or error message string
-    """
-    w: WorkspaceClient = warehouse.workspace_client
-    warehouse_id: str = str(warehouse.warehouse_id)
-
-    logger.debug(f"[{layer_name}] Executing cached SQL: {sql[:100]}...")
-
-    statement_response: StatementResponse = w.statement_execution.execute_statement(
-        statement=sql,
-        warehouse_id=warehouse_id,
-        wait_timeout="30s",
-    )
-
-    # Poll for completion if still running
-    while statement_response.status.state in [
-        StatementState.PENDING,
-        StatementState.RUNNING,
-    ]:
-        statement_response = w.statement_execution.get_statement(
-            statement_response.statement_id
-        )
-
-    if statement_response.status.state != StatementState.SUCCEEDED:
-        error_msg: str = f"SQL execution failed: {statement_response.status}"
-        logger.error(f"[{layer_name}] {error_msg}")
-        return error_msg
-
-    # Convert to DataFrame
-    if statement_response.result and statement_response.result.data_array:
-        columns: list[str] = []
-        if statement_response.manifest and statement_response.manifest.schema:
-            columns = [col.name for col in statement_response.manifest.schema.columns]
-        elif hasattr(statement_response.result, "schema"):
-            columns = [col.name for col in statement_response.result.schema.columns]
-
-        data: list[list[Any]] = statement_response.result.data_array
-        if columns:
-            return pd.DataFrame(data, columns=columns)
-        else:
-            return pd.DataFrame(data)
-
-    return pd.DataFrame()
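The interface change is the significant part of this file: ask_question now returns CacheResult in every implementation (with cache_hit=False and served_by=None when no cache is involved), and space_id becomes an abstract property. A minimal sketch of a conforming subclass; StaticGenieService is a toy stub invented for illustration, not anything shipped in the package, and it assumes GenieResponse accepts the keyword fields shown in the lru.py diff below:

    from databricks_ai_bridge.genie import GenieResponse

    from dao_ai.genie.cache import CacheResult, GenieServiceBase


    class StaticGenieService(GenieServiceBase):
        """Toy implementation that returns a canned response without calling Genie."""

        def __init__(self, space_id: str) -> None:
            self._space_id = space_id

        def ask_question(
            self, question: str, conversation_id: str | None = None
        ) -> CacheResult:
            response = GenieResponse(
                result="no data",
                query=None,
                description=f"echo: {question}",
                conversation_id=conversation_id,
            )
            # No caching here, so per the contract: cache_hit=False, served_by=None.
            return CacheResult(response=response, cache_hit=False, served_by=None)

        @property
        def space_id(self) -> str:
            return self._space_id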
dao_ai/genie/cache/core.py
ADDED
@@ -0,0 +1,75 @@
+"""
+Core utilities for Genie cache implementations.
+
+This module provides shared utility functions used by different cache
+implementations (LRU, Semantic, etc.). These are concrete implementations
+of common operations needed across cache types.
+"""
+
+from typing import Any
+
+import pandas as pd
+from databricks.sdk import WorkspaceClient
+from databricks.sdk.service.sql import StatementResponse, StatementState
+from loguru import logger
+
+from dao_ai.config import WarehouseModel
+
+
+def execute_sql_via_warehouse(
+    warehouse: WarehouseModel,
+    sql: str,
+    layer_name: str = "cache",
+) -> pd.DataFrame | str:
+    """
+    Execute SQL using a Databricks warehouse and return results as DataFrame.
+
+    This is a shared utility for cache implementations that need to re-execute
+    cached SQL queries.
+
+    Args:
+        warehouse: The warehouse configuration for SQL execution
+        sql: The SQL query to execute
+        layer_name: Name of the cache layer (for logging)
+
+    Returns:
+        DataFrame with results, or error message string
+    """
+    w: WorkspaceClient = warehouse.workspace_client
+    warehouse_id: str = str(warehouse.warehouse_id)
+
+    logger.debug(f"[{layer_name}] Executing cached SQL: {sql[:100]}...")
+
+    statement_response: StatementResponse = w.statement_execution.execute_statement(
+        statement=sql,
+        warehouse_id=warehouse_id,
+        wait_timeout="30s",
+    )
+
+    # Poll for completion if still running
+    while statement_response.status.state in [
+        StatementState.PENDING,
+        StatementState.RUNNING,
+    ]:
+        statement_response = w.statement_execution.get_statement(
+            statement_response.statement_id
+        )
+
+    if statement_response.status.state != StatementState.SUCCEEDED:
+        error_msg: str = f"SQL execution failed: {statement_response.status}"
+        logger.error(f"[{layer_name}] {error_msg}")
+        return error_msg
+
+    # Convert to DataFrame
+    if statement_response.result and statement_response.result.data_array:
+        columns: list[str] = []
+        if statement_response.manifest and statement_response.manifest.schema:
+            columns = [col.name for col in statement_response.manifest.schema.columns]
+
+        data: list[list[Any]] = statement_response.result.data_array
+        if columns:
+            return pd.DataFrame(data, columns=columns)
+        else:
+            return pd.DataFrame(data)
+
+    return pd.DataFrame()
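Moving the helper into its own module keeps base.py free of pandas, SDK, and loguru dependencies and gives it a proper import for WarehouseModel from dao_ai.config. A usage sketch; the warehouse object is left as a placeholder since its construction comes from dao_ai configuration not shown in this diff:

    import pandas as pd

    from dao_ai.config import WarehouseModel
    from dao_ai.genie.cache.core import execute_sql_via_warehouse

    warehouse: WarehouseModel = ...  # obtain from your dao_ai configuration

    result = execute_sql_via_warehouse(warehouse, "SELECT 1 AS probe", layer_name="lru")
    if isinstance(result, pd.DataFrame):
        print(result.head())
    else:
        # On failure the helper returns the error message as a string.
        print(f"SQL failed: {result}")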
dao_ai/genie/cache/lru.py
CHANGED
@@ -96,9 +96,21 @@ class LRUCacheService(GenieServiceBase):
         return timedelta(seconds=ttl)
 
     @staticmethod
-    def _normalize_key(question: str) -> str:
-        """
-
+    def _normalize_key(question: str, conversation_id: str | None = None) -> str:
+        """
+        Normalize the question and conversation_id to create a consistent cache key.
+
+        Args:
+            question: The question text
+            conversation_id: Optional conversation ID to include in the key
+
+        Returns:
+            A normalized cache key combining question and conversation_id
+        """
+        normalized_question = question.strip().lower()
+        if conversation_id:
+            return f"{conversation_id}::{normalized_question}"
+        return normalized_question
 
     def _is_expired(self, entry: SQLCacheEntry) -> bool:
         """Check if a cache entry has exceeded its TTL. Returns False if TTL is disabled."""
@@ -192,8 +204,6 @@ class LRUCacheService(GenieServiceBase):
                 columns = [
                     col.name for col in statement_response.manifest.schema.columns
                 ]
-            elif hasattr(statement_response.result, "schema"):
-                columns = [col.name for col in statement_response.result.schema.columns]
 
             data: list[list[Any]] = statement_response.result.data_array
             if columns:
@@ -205,17 +215,14 @@ class LRUCacheService(GenieServiceBase):
 
     def ask_question(
         self, question: str, conversation_id: str | None = None
-    ) -> GenieResponse:
+    ) -> CacheResult:
         """
         Ask a question, using cached SQL query if available.
 
         On cache hit, re-executes the cached SQL to get fresh data.
-
+        Returns CacheResult with cache metadata.
         """
-        result: CacheResult = self.ask_question_with_cache_info(
-            question, conversation_id
-        )
-        return result.response
+        return self.ask_question_with_cache_info(question, conversation_id)
 
     @mlflow.trace(name="genie_lru_cache_lookup")
     def ask_question_with_cache_info(
@@ -235,7 +242,7 @@ class LRUCacheService(GenieServiceBase):
         Returns:
             CacheResult with fresh response and cache metadata
         """
-        key: str = self._normalize_key(question)
+        key: str = self._normalize_key(question, conversation_id)
 
         # Check cache
         with self._lock:
@@ -244,17 +251,20 @@ class LRUCacheService(GenieServiceBase):
             if cached is not None:
                 logger.info(
                     f"[{self.name}] Cache HIT: '{question[:50]}...' "
-                    f"(cache_size={self.size}/{self.capacity})"
+                    f"(conversation_id={conversation_id}, cache_size={self.size}/{self.capacity})"
                 )
 
                 # Re-execute the cached SQL to get fresh data
                 result: pd.DataFrame | str = self._execute_sql(cached.query)
 
+                # Use current conversation_id, not the cached one
                 response: GenieResponse = GenieResponse(
                     result=result,
                     query=cached.query,
                     description=cached.description,
-                    conversation_id=cached.conversation_id,
+                    conversation_id=conversation_id
+                    if conversation_id
+                    else cached.conversation_id,
                 )
 
                 return CacheResult(response=response, cache_hit=True, served_by=self.name)
@@ -262,17 +272,30 @@ class LRUCacheService(GenieServiceBase):
         # Cache miss - delegate to wrapped service
         logger.info(
             f"[{self.name}] Cache MISS: '{question[:50]}...' "
-            f"(cache_size={self.size}/{self.capacity}, delegating to {type(self.impl).__name__})"
+            f"(conversation_id={conversation_id}, cache_size={self.size}/{self.capacity}, delegating to {type(self.impl).__name__})"
         )
 
-        response: GenieResponse = self.impl.ask_question(question, conversation_id)
+        result: CacheResult = self.impl.ask_question(question, conversation_id)
         with self._lock:
-            self._put(key, response)
-            return CacheResult(response=response, cache_hit=False, served_by=None)
+            self._put(key, result.response)
+            return CacheResult(response=result.response, cache_hit=False, served_by=None)
 
-
-
-
+    @property
+    def space_id(self) -> str:
+        return self.impl.space_id
+
+    def invalidate(self, question: str, conversation_id: str | None = None) -> bool:
+        """
+        Remove a specific entry from the cache.
+
+        Args:
+            question: The question text
+            conversation_id: Optional conversation ID to match
+
+        Returns:
+            True if the entry was found and removed, False otherwise
+        """
+        key: str = self._normalize_key(question, conversation_id)
         with self._lock:
             if key in self._cache:
                 del self._cache[key]