dao-ai 0.0.19__py3-none-any.whl → 0.0.20__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dao_ai/config.py +2 -1
- dao_ai/guardrails.py +2 -2
- dao_ai/memory/postgres.py +1 -132
- dao_ai/tools/human_in_the_loop.py +5 -1
- dao_ai/tools/unity_catalog.py +328 -13
- {dao_ai-0.0.19.dist-info → dao_ai-0.0.20.dist-info}/METADATA +1 -1
- {dao_ai-0.0.19.dist-info → dao_ai-0.0.20.dist-info}/RECORD +10 -10
- {dao_ai-0.0.19.dist-info → dao_ai-0.0.20.dist-info}/WHEEL +0 -0
- {dao_ai-0.0.19.dist-info → dao_ai-0.0.20.dist-info}/entry_points.txt +0 -0
- {dao_ai-0.0.19.dist-info → dao_ai-0.0.20.dist-info}/licenses/LICENSE +0 -0
dao_ai/config.py
CHANGED
@@ -867,7 +867,7 @@ class PythonFunctionModel(BaseFunctionModel, HasFullName):
 
 class FactoryFunctionModel(BaseFunctionModel, HasFullName):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
-    args: Optional[dict[str,
+    args: Optional[dict[str, AnyVariable]] = Field(default_factory=dict)
     type: Literal[FunctionType.FACTORY] = FunctionType.FACTORY
 
     @property
@@ -963,6 +963,7 @@ class McpFunctionModel(BaseFunctionModel, HasFullName):
 class UnityCatalogFunctionModel(BaseFunctionModel, HasFullName):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     schema_model: Optional[SchemaModel] = Field(default=None, alias="schema")
+    partial_args: Optional[dict[str, AnyVariable]] = Field(default_factory=dict)
     type: Literal[FunctionType.UNITY_CATALOG] = FunctionType.UNITY_CATALOG
 
     @property
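The notable change is the new `partial_args` field on `UnityCatalogFunctionModel` (and the completed `args` on `FactoryFunctionModel`), both keyed by `AnyVariable`. A minimal sketch of how a function definition might populate it, based only on what this diff shows; the `name` keyword and the literal values are assumptions:

    from dao_ai.config import UnityCatalogFunctionModel

    # Hypothetical configuration: every value below is illustrative.
    uc_function = UnityCatalogFunctionModel(
        name="process_order",  # assumed to be inherited from BaseFunctionModel
        partial_args={
            # Plain values are resolved with value_of(); dict values are parsed
            # as CompositeVariableModel (see with_partial_args further down).
            "client_id": "my-service-principal-id",
            "host": "https://example.cloud.databricks.com",
        },
    )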
dao_ai/guardrails.py
CHANGED
@@ -87,12 +87,12 @@ def judge_node(guardrails: GuardrailModel) -> RunnableLike:
     )
 
     if eval_result["score"]:
-        logger.debug("
+        logger.debug("Response approved by judge")
         logger.debug(f"Judge's comment: {eval_result['comment']}")
         return
     else:
         # Otherwise, return the judge's critique as a new user message
-        logger.warning("
+        logger.warning("Judge requested improvements")
         comment: str = eval_result["comment"]
         logger.warning(f"Judge's critique: {comment}")
         content: str = "\n".join([human_message.content, comment])
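The hunk shows the judge pattern: on approval the node simply returns, otherwise the critique is folded into the original user message so the agent can retry. A standalone sketch of that retry-message construction, assuming `eval_result` carries the `score` and `comment` keys used above:

    def fold_critique(human_content: str, eval_result: dict) -> str | None:
        """Return an amended user message when the judge rejects, else None."""
        if eval_result["score"]:
            # Approved: nothing to append, the response can pass through.
            return None
        # Rejected: append the judge's critique so the next attempt can
        # address it, mirroring the "\n".join([...]) line in the diff.
        return "\n".join([human_content, eval_result["comment"]])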
dao_ai/memory/postgres.py
CHANGED
@@ -20,137 +20,6 @@ from dao_ai.memory.base import (
 )
 
 
-class PatchedAsyncPostgresStore(AsyncPostgresStore):
-    """
-    Patched version of AsyncPostgresStore that properly handles event loop initialization
-    and task lifecycle management.
-
-    The issues occur because:
-    1. AsyncBatchedBaseStore.__init__ calls asyncio.get_running_loop() and fails if no event loop is running
-    2. The background _task can complete/fail, causing assertions in asearch/other methods to fail
-    3. Destructor tries to access _task even when it doesn't exist
-
-    This patch ensures proper initialization and handles task lifecycle robustly.
-    """
-
-    def __init__(self, *args, **kwargs):
-        # Ensure we have a running event loop before calling super().__init__()
-        loop = None
-        try:
-            loop = asyncio.get_running_loop()
-        except RuntimeError:
-            # No running loop - create one temporarily for initialization
-            loop = asyncio.new_event_loop()
-            asyncio.set_event_loop(loop)
-
-        try:
-            super().__init__(*args, **kwargs)
-        except Exception as e:
-            # If parent initialization fails, ensure _task is at least defined
-            if not hasattr(self, "_task"):
-                self._task = None
-            logger.warning(f"AsyncPostgresStore initialization failed: {e}")
-            raise
-
-    def _ensure_task_running(self):
-        """
-        Ensure the background task is running. Recreate it if necessary.
-        """
-        if not hasattr(self, "_task") or self._task is None:
-            logger.error("AsyncPostgresStore task not initialized")
-            raise RuntimeError("Store task not properly initialized")
-
-        if self._task.done():
-            logger.warning(
-                "AsyncPostgresStore background task completed, attempting to restart"
-            )
-            # Try to get the task exception for debugging
-            try:
-                exception = self._task.exception()
-                if exception:
-                    logger.error(f"Background task failed with: {exception}")
-                else:
-                    logger.info("Background task completed normally")
-            except Exception as e:
-                logger.warning(f"Could not determine task completion reason: {e}")
-
-            # Try to restart the task
-            try:
-                import weakref
-
-                from langgraph.store.base.batch import _run
-
-                self._task = self._loop.create_task(
-                    _run(self._aqueue, weakref.ref(self))
-                )
-                logger.info("Successfully restarted AsyncPostgresStore background task")
-            except Exception as e:
-                logger.error(f"Failed to restart background task: {e}")
-                raise RuntimeError(
-                    f"Store background task failed and could not be restarted: {e}"
-                )
-
-    async def asearch(
-        self,
-        namespace_prefix,
-        /,
-        *,
-        query=None,
-        filter=None,
-        limit=10,
-        offset=0,
-        refresh_ttl=None,
-    ):
-        """
-        Override asearch to handle task lifecycle issues gracefully.
-        """
-        self._ensure_task_running()
-
-        # Call parent implementation if task is healthy
-        return await super().asearch(
-            namespace_prefix,
-            query=query,
-            filter=filter,
-            limit=limit,
-            offset=offset,
-            refresh_ttl=refresh_ttl,
-        )
-
-    async def aget(self, namespace, key, /, *, refresh_ttl=None):
-        """Override aget with task lifecycle management."""
-        self._ensure_task_running()
-        return await super().aget(namespace, key, refresh_ttl=refresh_ttl)
-
-    async def aput(self, namespace, key, value, /, *, refresh_ttl=None):
-        """Override aput with task lifecycle management."""
-        self._ensure_task_running()
-        return await super().aput(namespace, key, value, refresh_ttl=refresh_ttl)
-
-    async def adelete(self, namespace, key):
-        """Override adelete with task lifecycle management."""
-        self._ensure_task_running()
-        return await super().adelete(namespace, key)
-
-    async def alist_namespaces(self, *, prefix=None):
-        """Override alist_namespaces with task lifecycle management."""
-        self._ensure_task_running()
-        return await super().alist_namespaces(prefix=prefix)
-
-    def __del__(self):
-        """
-        Override destructor to handle missing _task attribute gracefully.
-        """
-        try:
-            # Only try to cancel if _task exists and is not None
-            if hasattr(self, "_task") and self._task is not None:
-                if not self._task.done():
-                    self._task.cancel()
-        except Exception as e:
-            # Log but don't raise - destructors should not raise exceptions
-            logger.debug(f"AsyncPostgresStore destructor cleanup: {e}")
-            pass
-
-
 class AsyncPostgresPoolManager:
     _pools: dict[str, AsyncConnectionPool] = {}
     _lock: asyncio.Lock = asyncio.Lock()
@@ -251,7 +120,7 @@ class AsyncPostgresStoreManager(StoreManagerBase):
         )
 
         # Create store with the shared pool (using patched version)
-        self._store =
+        self._store = AsyncPostgresStore(conn=self.pool)
 
         await self._store.setup()
 
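With the patched subclass deleted, the store manager now hands the shared pool straight to langgraph's stock `AsyncPostgresStore`. A minimal sketch of that simplified lifecycle; `conn=`, `setup()`, `aput`, and `aget` match the diff and the public store API, while the DSN and the pool kwargs follow the commonly documented langgraph/psycopg pattern and are assumptions:

    import asyncio

    from langgraph.store.postgres.aio import AsyncPostgresStore
    from psycopg.rows import dict_row
    from psycopg_pool import AsyncConnectionPool


    async def main() -> None:
        # Shared pool, handed straight to the store: no subclass, no manual
        # event-loop or background-task babysitting.
        async with AsyncConnectionPool(
            "postgresql://user:pass@localhost:5432/dao",  # hypothetical DSN
            kwargs={"autocommit": True, "row_factory": dict_row},
        ) as pool:
            store = AsyncPostgresStore(conn=pool)
            await store.setup()  # one-time migrations, as in the manager above
            await store.aput(("users", "42"), "prefs", {"theme": "dark"})
            print(await store.aget(("users", "42"), "prefs"))


    asyncio.run(main())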
dao_ai/tools/human_in_the_loop.py
CHANGED

@@ -87,7 +87,11 @@ def as_human_in_the_loop(
     if isinstance(function, BaseFunctionModel):
         human_in_the_loop: HumanInTheLoopModel | None = function.human_in_the_loop
         if human_in_the_loop:
-
+            # Get tool name safely - handle RunnableBinding objects
+            tool_name = getattr(tool, "name", None) or getattr(
+                getattr(tool, "bound", None), "name", "unknown_tool"
+            )
+            logger.debug(f"Adding human-in-the-loop to tool: {tool_name}")
             tool = add_human_in_the_loop(
                 tool=tool,
                 interrupt_config=human_in_the_loop.interupt_config,
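`RunnableBinding` wrappers produced by `.bind(...)` or `.with_config(...)` keep the underlying tool on a `bound` attribute, so reading `tool.name` directly can fail or return None. A standalone sketch of the fallback chain this hunk adds, restructured slightly so a wrapped tool whose `bound.name` is also unset still gets the placeholder:

    def safe_tool_name(tool: object) -> str:
        """Resolve a tool's name even when it is wrapped in a RunnableBinding."""
        # Try the wrapper itself, then the wrapped runnable, then a placeholder.
        return (
            getattr(tool, "name", None)
            or getattr(getattr(tool, "bound", None), "name", None)
            or "unknown_tool"
        )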
dao_ai/tools/unity_catalog.py
CHANGED
@@ -1,14 +1,18 @@
-from typing import Sequence
+from typing import Any, Dict, Optional, Sequence, Union
 
-from
-
-
-)
+from databricks.sdk import WorkspaceClient
+from databricks.sdk.service.catalog import PermissionsChange, Privilege
+from databricks_langchain import DatabricksFunctionClient, UCFunctionToolkit
 from langchain_core.runnables.base import RunnableLike
+from langchain_core.tools import StructuredTool
 from loguru import logger
 
 from dao_ai.config import (
+    AnyVariable,
+    CompositeVariableModel,
+    ToolModel,
     UnityCatalogFunctionModel,
+    value_of,
 )
 from dao_ai.tools.human_in_the_loop import as_human_in_the_loop
 
@@ -32,19 +36,330 @@ def create_uc_tools(
 
     logger.debug(f"create_uc_tools: {function}")
 
+    original_function_model = None
     if isinstance(function, UnityCatalogFunctionModel):
-
+        original_function_model = function
+        function_name = function.full_name
+    else:
+        function_name = function
 
-
+    # Determine which tools to create
+    if original_function_model and original_function_model.partial_args:
+        logger.debug("Found partial_args, creating custom tool with partial arguments")
+        # Create a ToolModel wrapper for the with_partial_args function
+        tool_model = ToolModel(
+            name=original_function_model.name, function=original_function_model
+        )
+
+        # Use with_partial_args to create the authenticated tool
+        tools = [with_partial_args(tool_model, original_function_model.partial_args)]
+    else:
+        # Fallback to standard UC toolkit approach
+        client: DatabricksFunctionClient = DatabricksFunctionClient()
+
+        toolkit: UCFunctionToolkit = UCFunctionToolkit(
+            function_names=[function_name], client=client
+        )
+
+        tools = toolkit.tools or []
+    logger.debug(f"Retrieved tools: {tools}")
+
+    # Apply human-in-the-loop wrapper to all tools and return
+    return [as_human_in_the_loop(tool=tool, function=function_name) for tool in tools]
+
+
+def _execute_uc_function(
+    client: DatabricksFunctionClient,
+    function_name: str,
+    partial_args: Dict[str, str] = None,
+    **kwargs: Any,
+) -> str:
+    """Execute Unity Catalog function with partial args and provided parameters."""
+
+    # Start with partial args if provided
+    all_params: Dict[str, Any] = dict(partial_args) if partial_args else {}
+
+    # Add any additional kwargs
+    all_params.update(kwargs)
 
-
-
+    logger.debug(
+        f"Calling UC function {function_name} with parameters: {list(all_params.keys())}"
     )
 
-
+    result = client.execute_function(function_name=function_name, parameters=all_params)
+
+    # Handle errors and extract result
+    if hasattr(result, "error") and result.error:
+        logger.error(f"Unity Catalog function error: {result.error}")
+        raise RuntimeError(f"Function execution failed: {result.error}")
+
+    result_value: str = result.value if hasattr(result, "value") else str(result)
+    logger.debug(f"UC function result: {result_value}")
+    return result_value
+
+
+def _grant_function_permissions(
+    function_name: str,
+    client_id: str,
+    host: Optional[str] = None,
+) -> None:
+    """
+    Grant comprehensive permissions to the service principal for Unity Catalog function execution.
+
+    This includes:
+    - EXECUTE permission on the function itself
+    - USE permission on the containing schema
+    - USE permission on the containing catalog
+    """
+    try:
+        # Initialize workspace client
+        workspace_client = WorkspaceClient(host=host) if host else WorkspaceClient()
+
+        # Parse the function name to get catalog and schema
+        parts = function_name.split(".")
+        if len(parts) != 3:
+            logger.warning(
+                f"Invalid function name format: {function_name}. Expected catalog.schema.function"
+            )
+            return
+
+        catalog_name, schema_name, func_name = parts
+        schema_full_name = f"{catalog_name}.{schema_name}"
+
+        logger.debug(
+            f"Granting comprehensive permissions on function {function_name} to principal {client_id}"
+        )
+
+        # 1. Grant EXECUTE permission on the function
+        try:
+            workspace_client.grants.update(
+                securable_type="function",
+                full_name=function_name,
+                changes=[
+                    PermissionsChange(principal=client_id, add=[Privilege.EXECUTE])
+                ],
+            )
+            logger.debug(f"Granted EXECUTE on function {function_name}")
+        except Exception as e:
+            logger.warning(f"Failed to grant EXECUTE on function {function_name}: {e}")
+
+        # 2. Grant USE_SCHEMA permission on the schema
+        try:
+            workspace_client.grants.update(
+                securable_type="schema",
+                full_name=schema_full_name,
+                changes=[
+                    PermissionsChange(
+                        principal=client_id,
+                        add=[Privilege.USE_SCHEMA],
+                    )
+                ],
+            )
+            logger.debug(f"Granted USE_SCHEMA on schema {schema_full_name}")
+        except Exception as e:
+            logger.warning(
+                f"Failed to grant USE_SCHEMA on schema {schema_full_name}: {e}"
+            )
+
+        # 3. Grant USE_CATALOG and BROWSE permissions on the catalog
+        try:
+            workspace_client.grants.update(
+                securable_type="catalog",
+                full_name=catalog_name,
+                changes=[
+                    PermissionsChange(
+                        principal=client_id,
+                        add=[Privilege.USE_CATALOG, Privilege.BROWSE],
+                    )
+                ],
+            )
+            logger.debug(f"Granted USE_CATALOG and BROWSE on catalog {catalog_name}")
+        except Exception as e:
+            logger.warning(
+                f"Failed to grant catalog permissions on {catalog_name}: {e}"
+            )
+
+        logger.debug(
+            f"Successfully granted comprehensive permissions on {function_name} to {client_id}"
+        )
+
+    except Exception as e:
+        logger.warning(
+            f"Failed to grant permissions on function {function_name} to {client_id}: {e}"
+        )
+        # Don't fail the tool creation if permission granting fails
+        pass
+
+
+def _create_filtered_schema(original_schema: type, exclude_fields: set[str]) -> type:
+    """
+    Create a new Pydantic model that excludes specified fields from the original schema.
+
+    Args:
+        original_schema: The original Pydantic model class
+        exclude_fields: Set of field names to exclude from the schema
+
+    Returns:
+        A new Pydantic model class with the specified fields removed
+    """
+    from pydantic import BaseModel, Field, create_model
+    from pydantic.fields import PydanticUndefined
+
+    try:
+        # Get the original model's fields (Pydantic v2)
+        original_fields = original_schema.model_fields
+        filtered_field_definitions = {}
+
+        for name, field in original_fields.items():
+            if name not in exclude_fields:
+                # Reconstruct the field definition for create_model
+                field_type = field.annotation
+                field_default = (
+                    field.default if field.default is not PydanticUndefined else ...
+                )
+                field_info = Field(default=field_default, description=field.description)
+                filtered_field_definitions[name] = (field_type, field_info)
+
+        # If no fields remain after filtering, return a generic empty schema
+        if not filtered_field_definitions:
+
+            class EmptySchema(BaseModel):
+                """Unity Catalog function with all parameters provided via partial args."""
 
-
+                pass
 
-
+            return EmptySchema
+
+        # Create the new model dynamically
+        model_name = f"Filtered{original_schema.__name__}"
+        docstring = getattr(
+            original_schema, "__doc__", "Filtered Unity Catalog function parameters."
+        )
+
+        filtered_model = create_model(
+            model_name, __doc__=docstring, **filtered_field_definitions
+        )
+        return filtered_model
+
+    except Exception as e:
+        logger.warning(f"Failed to create filtered schema: {e}")
+
+        # Fallback to generic schema
+        class GenericFilteredSchema(BaseModel):
+            """Generic filtered schema for Unity Catalog function."""
+
+            pass
+
+        return GenericFilteredSchema
+
+
+def with_partial_args(
+    tool: Union[ToolModel, Dict[str, Any]],
+    partial_args: dict[str, AnyVariable] = {},
+) -> StructuredTool:
+    """
+    Create a Unity Catalog tool with partial arguments pre-filled.
+
+    This function creates a wrapper tool that calls the UC function with partial arguments
+    already resolved, so the caller only needs to provide the remaining parameters.
+
+    Args:
+        tool: ToolModel containing the Unity Catalog function configuration
+        partial_args: Dictionary of arguments to pre-fill in the tool
+
+    Returns:
+        StructuredTool: A LangChain tool with partial arguments pre-filled
+    """
+    from unitycatalog.ai.langchain.toolkit import generate_function_input_params_schema
+
+    logger.debug(f"with_partial_args: {tool}")
+
+    # Convert dict-based variables to CompositeVariableModel and resolve their values
+    resolved_args = {}
+    for k, v in partial_args.items():
+        if isinstance(v, dict):
+            resolved_args[k] = value_of(CompositeVariableModel(**v))
+        else:
+            resolved_args[k] = value_of(v)
+
+    logger.debug(f"Resolved partial args: {resolved_args.keys()}")
+
+    if isinstance(tool, dict):
+        tool = ToolModel(**tool)
+
+    unity_catalog_function = tool.function
+    if isinstance(unity_catalog_function, dict):
+        unity_catalog_function = UnityCatalogFunctionModel(**unity_catalog_function)
+
+    function_name: str = unity_catalog_function.full_name
+    logger.debug(f"Creating UC tool with partial args for: {function_name}")
+
+    # Grant permissions if we have credentials
+    if "client_id" in resolved_args:
+        client_id: str = resolved_args["client_id"]
+        host: Optional[str] = resolved_args.get("host")
+        try:
+            _grant_function_permissions(function_name, client_id, host)
+        except Exception as e:
+            logger.warning(f"Failed to grant permissions: {e}")
+
+    # Create the client for function execution
+    client: DatabricksFunctionClient = DatabricksFunctionClient()
+
+    # Try to get the function schema for better tool definition
+    try:
+        function_info = client.get_function(function_name)
+        schema_info = generate_function_input_params_schema(function_info)
+        tool_description = (
+            function_info.comment or f"Unity Catalog function: {function_name}"
+        )
+
+        logger.debug(
+            f"Generated schema for function {function_name}: {schema_info.pydantic_model}"
+        )
+        logger.debug(f"Tool description: {tool_description}")
+
+        # Create a modified schema that excludes partial args
+        original_schema = schema_info.pydantic_model
+        schema_model = _create_filtered_schema(original_schema, resolved_args.keys())
+        logger.debug(
+            f"Filtered schema excludes partial args: {list(resolved_args.keys())}"
+        )
+
+    except Exception as e:
+        logger.warning(f"Could not introspect function {function_name}: {e}")
+        # Fallback to a generic schema
+        from pydantic import BaseModel
+
+        class GenericUCParams(BaseModel):
+            """Generic parameters for Unity Catalog function."""
+
+            pass
+
+        schema_model = GenericUCParams
+        tool_description = f"Unity Catalog function: {function_name}"
+
+    # Create a wrapper function that calls _execute_uc_function with partial args
+    def uc_function_wrapper(**kwargs) -> str:
+        """Wrapper function that executes Unity Catalog function with partial args."""
+        return _execute_uc_function(
+            client=client,
+            function_name=function_name,
+            partial_args=resolved_args,
+            **kwargs,
+        )
+
+    # Set the function name for the decorator
+    uc_function_wrapper.__name__ = tool.name or function_name.replace(".", "_")
+
+    # Create the tool using LangChain's StructuredTool
+    from langchain_core.tools import StructuredTool
+
+    partial_tool = StructuredTool.from_function(
+        func=uc_function_wrapper,
+        name=tool.name or function_name.replace(".", "_"),
+        description=tool_description,
+        args_schema=schema_model,
+    )
 
-    return
+    return partial_tool
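Most of the new module revolves around two ideas: `with_partial_args` pre-binds sensitive values such as `client_id` into the tool closure, and `_create_filtered_schema` removes those same parameters from the LLM-facing signature. A self-contained sketch of the filtering idea using only pydantic v2's public `create_model` API; the `OrderParams` model and `filter_fields` helper are hypothetical, the pydantic calls are real:

    from pydantic import BaseModel, create_model
    from pydantic_core import PydanticUndefined


    class OrderParams(BaseModel):
        """Hypothetical UC function signature."""

        order_id: str
        client_id: str  # pre-filled via partial_args, so the LLM should not see it


    def filter_fields(schema: type[BaseModel], exclude: set[str]) -> type[BaseModel]:
        """Rebuild a model without the excluded fields, keeping types and defaults."""
        kept = {
            name: (f.annotation, f.default if f.default is not PydanticUndefined else ...)
            for name, f in schema.model_fields.items()
            if name not in exclude
        }
        return create_model(f"Filtered{schema.__name__}", **kept)


    VisibleParams = filter_fields(OrderParams, {"client_id"})
    print(list(VisibleParams.model_fields))  # ['order_id']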
{dao_ai-0.0.19.dist-info → dao_ai-0.0.20.dist-info}/METADATA
CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: dao-ai
-Version: 0.0.19
+Version: 0.0.20
 Summary: DAO AI: A modular, multi-agent orchestration framework for complex AI workflows. Supports agent handoff, tool integration, and dynamic configuration via YAML.
 Project-URL: Homepage, https://github.com/natefleming/dao-ai
 Project-URL: Documentation, https://natefleming.github.io/dao-ai
{dao_ai-0.0.19.dist-info → dao_ai-0.0.20.dist-info}/RECORD
CHANGED

@@ -3,9 +3,9 @@ dao_ai/agent_as_code.py,sha256=kPSeDz2-1jRaed1TMs4LA3VECoyqe9_Ed2beRLB9gXQ,472
 dao_ai/catalog.py,sha256=sPZpHTD3lPx4EZUtIWeQV7VQM89WJ6YH__wluk1v2lE,4947
 dao_ai/chat_models.py,sha256=uhwwOTeLyHWqoTTgHrs4n5iSyTwe4EQcLKnh3jRxPWI,8626
 dao_ai/cli.py,sha256=Aez2TQW3Q8Ho1IaIkRggt0NevDxAAVPjXkePC5GPJF0,20429
-dao_ai/config.py,sha256=
+dao_ai/config.py,sha256=ZO5ei45gnhqg1BtD0R9aekJz4ClmiTw2GHhOk4Idil4,51958
 dao_ai/graph.py,sha256=gmD9mxODfXuvn9xWeBfewm1FiuVAWMLEdnZz7DNmSH0,7859
-dao_ai/guardrails.py,sha256=
+dao_ai/guardrails.py,sha256=4TKArDONRy8RwHzOT1plZ1rhy3x9GF_aeGpPCRl6wYA,4016
 dao_ai/messages.py,sha256=xl_3-WcFqZKCFCiov8sZOPljTdM3gX3fCHhxq-xFg2U,7005
 dao_ai/models.py,sha256=Xb23U-lhDG8KyNRIijcJ4InluadlaGNy4rrYx7Cjgfg,26939
 dao_ai/nodes.py,sha256=SSuFNTXOdFaKg_aX-yUkQO7fM9wvNGu14lPXKDapU1U,8461
@@ -19,7 +19,7 @@ dao_ai/hooks/core.py,sha256=ZShHctUSoauhBgdf1cecy9-D7J6-sGn-pKjuRMumW5U,6663
 dao_ai/memory/__init__.py,sha256=1kHx_p9abKYFQ6EYD05nuc1GS5HXVEpufmjBGw_7Uho,260
 dao_ai/memory/base.py,sha256=99nfr2UZJ4jmfTL_KrqUlRSCoRxzkZyWyx5WqeUoMdQ,338
 dao_ai/memory/core.py,sha256=g7chjBgVgx3iKjR2hghl0QL1j3802uIM_e7mgszur9M,4151
-dao_ai/memory/postgres.py,sha256=
+dao_ai/memory/postgres.py,sha256=pxxMjGotgqjrKhx0lVR3EAjSZTQgBpiPZOB0-cyjprc,12505
 dao_ai/providers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 dao_ai/providers/base.py,sha256=-fjKypCOk28h6vioPfMj9YZSw_3Kcbi2nMuAyY7vX9k,1383
 dao_ai/providers/databricks.py,sha256=fZ8mGotfA3W3t5yUej2xGmGHSybjBFYr895mOctT418,28203
@@ -27,14 +27,14 @@ dao_ai/tools/__init__.py,sha256=ye6MHaJY7tUnJ8336YJiLxuZr55zDPNdOw6gm7j5jlc,1103
 dao_ai/tools/agent.py,sha256=WbQnyziiT12TLMrA7xK0VuOU029tdmUBXbUl-R1VZ0Q,1886
 dao_ai/tools/core.py,sha256=Kei33S8vrmvPOAyrFNekaWmV2jqZ-IPS1QDSvU7RZF0,1984
 dao_ai/tools/genie.py,sha256=GzV5lfDYKmzW_lSLxAsPaTwnzX6GxQOB1UcLaTDqpfY,2787
-dao_ai/tools/human_in_the_loop.py,sha256=
+dao_ai/tools/human_in_the_loop.py,sha256=yk35MO9eNETnYFH-sqlgR-G24TrEgXpJlnZUustsLkI,3681
 dao_ai/tools/mcp.py,sha256=auEt_dwv4J26fr5AgLmwmnAsI894-cyuvkvjItzAUxs,4419
 dao_ai/tools/python.py,sha256=XcQiTMshZyLUTVR5peB3vqsoUoAAy8gol9_pcrhddfI,1831
 dao_ai/tools/time.py,sha256=Y-23qdnNHzwjvnfkWvYsE7PoWS1hfeKy44tA7sCnNac,8759
-dao_ai/tools/unity_catalog.py,sha256=
+dao_ai/tools/unity_catalog.py,sha256=uX_h52BuBAr4c9UeqSMI7DNz3BPRLeai5tBVW4sJqRI,13113
 dao_ai/tools/vector_search.py,sha256=EDYQs51zIPaAP0ma1D81wJT77GQ-v-cjb2XrFVWfWdg,2621
-dao_ai-0.0.19.dist-info/METADATA,sha256=
-dao_ai-0.0.19.dist-info/WHEEL,sha256=
-dao_ai-0.0.19.dist-info/entry_points.txt,sha256=
-dao_ai-0.0.19.dist-info/licenses/LICENSE,sha256=
-dao_ai-0.0.19.dist-info/RECORD,,
+dao_ai-0.0.20.dist-info/METADATA,sha256=gWNRLhswz5sCe1vxbBQ6dGlgiObI9nI829Q5DQRqRRY,41380
+dao_ai-0.0.20.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+dao_ai-0.0.20.dist-info/entry_points.txt,sha256=Xa-UFyc6gWGwMqMJOt06ZOog2vAfygV_DSwg1AiP46g,43
+dao_ai-0.0.20.dist-info/licenses/LICENSE,sha256=YZt3W32LtPYruuvHE9lGk2bw6ZPMMJD8yLrjgHybyz4,1069
+dao_ai-0.0.20.dist-info/RECORD,,
{dao_ai-0.0.19.dist-info → dao_ai-0.0.20.dist-info}/WHEEL
File without changes

{dao_ai-0.0.19.dist-info → dao_ai-0.0.20.dist-info}/entry_points.txt
File without changes

{dao_ai-0.0.19.dist-info → dao_ai-0.0.20.dist-info}/licenses/LICENSE
File without changes