dao-ai 0.0.28__py3-none-any.whl → 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dao_ai/__init__.py +29 -0
- dao_ai/agent_as_code.py +2 -5
- dao_ai/cli.py +245 -40
- dao_ai/config.py +1491 -370
- dao_ai/genie/__init__.py +38 -0
- dao_ai/genie/cache/__init__.py +43 -0
- dao_ai/genie/cache/base.py +72 -0
- dao_ai/genie/cache/core.py +79 -0
- dao_ai/genie/cache/lru.py +347 -0
- dao_ai/genie/cache/semantic.py +970 -0
- dao_ai/genie/core.py +35 -0
- dao_ai/graph.py +27 -253
- dao_ai/hooks/__init__.py +9 -6
- dao_ai/hooks/core.py +27 -195
- dao_ai/logging.py +56 -0
- dao_ai/memory/__init__.py +10 -0
- dao_ai/memory/core.py +65 -30
- dao_ai/memory/databricks.py +402 -0
- dao_ai/memory/postgres.py +79 -38
- dao_ai/messages.py +6 -4
- dao_ai/middleware/__init__.py +125 -0
- dao_ai/middleware/assertions.py +806 -0
- dao_ai/middleware/base.py +50 -0
- dao_ai/middleware/core.py +67 -0
- dao_ai/middleware/guardrails.py +420 -0
- dao_ai/middleware/human_in_the_loop.py +232 -0
- dao_ai/middleware/message_validation.py +586 -0
- dao_ai/middleware/summarization.py +197 -0
- dao_ai/models.py +1306 -114
- dao_ai/nodes.py +245 -159
- dao_ai/optimization.py +674 -0
- dao_ai/orchestration/__init__.py +52 -0
- dao_ai/orchestration/core.py +294 -0
- dao_ai/orchestration/supervisor.py +278 -0
- dao_ai/orchestration/swarm.py +271 -0
- dao_ai/prompts.py +128 -31
- dao_ai/providers/databricks.py +573 -601
- dao_ai/state.py +157 -21
- dao_ai/tools/__init__.py +13 -5
- dao_ai/tools/agent.py +1 -3
- dao_ai/tools/core.py +64 -11
- dao_ai/tools/email.py +232 -0
- dao_ai/tools/genie.py +144 -294
- dao_ai/tools/mcp.py +223 -155
- dao_ai/tools/memory.py +50 -0
- dao_ai/tools/python.py +9 -14
- dao_ai/tools/search.py +14 -0
- dao_ai/tools/slack.py +22 -10
- dao_ai/tools/sql.py +202 -0
- dao_ai/tools/time.py +30 -7
- dao_ai/tools/unity_catalog.py +165 -88
- dao_ai/tools/vector_search.py +331 -221
- dao_ai/utils.py +166 -20
- dao_ai-0.1.2.dist-info/METADATA +455 -0
- dao_ai-0.1.2.dist-info/RECORD +64 -0
- dao_ai/chat_models.py +0 -204
- dao_ai/guardrails.py +0 -112
- dao_ai/tools/human_in_the_loop.py +0 -100
- dao_ai-0.0.28.dist-info/METADATA +0 -1168
- dao_ai-0.0.28.dist-info/RECORD +0 -41
- {dao_ai-0.0.28.dist-info → dao_ai-0.1.2.dist-info}/WHEEL +0 -0
- {dao_ai-0.0.28.dist-info → dao_ai-0.1.2.dist-info}/entry_points.txt +0 -0
- {dao_ai-0.0.28.dist-info → dao_ai-0.1.2.dist-info}/licenses/LICENSE +0 -0
dao_ai/tools/unity_catalog.py
CHANGED
|
@@ -1,20 +1,21 @@
|
|
|
1
|
-
from typing import Any, Dict, Optional, Sequence,
|
|
1
|
+
from typing import Any, Dict, Optional, Sequence, Set
|
|
2
2
|
|
|
3
3
|
from databricks.sdk import WorkspaceClient
|
|
4
|
-
from databricks.sdk.service.catalog import PermissionsChange, Privilege
|
|
4
|
+
from databricks.sdk.service.catalog import FunctionInfo, PermissionsChange, Privilege
|
|
5
5
|
from databricks_langchain import DatabricksFunctionClient, UCFunctionToolkit
|
|
6
6
|
from langchain_core.runnables.base import RunnableLike
|
|
7
7
|
from langchain_core.tools import StructuredTool
|
|
8
8
|
from loguru import logger
|
|
9
|
+
from pydantic import BaseModel
|
|
10
|
+
from unitycatalog.ai.core.base import FunctionExecutionResult
|
|
9
11
|
|
|
10
12
|
from dao_ai.config import (
|
|
11
13
|
AnyVariable,
|
|
12
14
|
CompositeVariableModel,
|
|
13
|
-
ToolModel,
|
|
14
15
|
UnityCatalogFunctionModel,
|
|
15
16
|
value_of,
|
|
16
17
|
)
|
|
17
|
-
from dao_ai.
|
|
18
|
+
from dao_ai.utils import normalize_host
|
|
18
19
|
|
|
19
20
|
|
|
20
21
|
def create_uc_tools(
|
|
@@ -33,45 +34,48 @@ def create_uc_tools(
|
|
|
33
34
|
Returns:
|
|
34
35
|
A sequence of BaseTool objects that wrap the specified UC functions
|
|
35
36
|
"""
|
|
37
|
+
original_function_model: UnityCatalogFunctionModel | None = None
|
|
38
|
+
workspace_client: WorkspaceClient | None = None
|
|
39
|
+
function_name: str
|
|
36
40
|
|
|
37
|
-
logger.debug(f"create_uc_tools: {function}")
|
|
38
|
-
|
|
39
|
-
original_function_model = None
|
|
40
41
|
if isinstance(function, UnityCatalogFunctionModel):
|
|
41
42
|
original_function_model = function
|
|
42
|
-
function_name = function.full_name
|
|
43
|
+
function_name = function.resource.full_name
|
|
44
|
+
workspace_client = function.resource.workspace_client
|
|
43
45
|
else:
|
|
44
46
|
function_name = function
|
|
45
47
|
|
|
48
|
+
logger.trace("Creating UC tools", function_name=function_name)
|
|
49
|
+
|
|
46
50
|
# Determine which tools to create
|
|
51
|
+
tools: list[RunnableLike]
|
|
47
52
|
if original_function_model and original_function_model.partial_args:
|
|
48
|
-
logger.debug(
|
|
49
|
-
|
|
50
|
-
tool_model = ToolModel(
|
|
51
|
-
name=original_function_model.name, function=original_function_model
|
|
53
|
+
logger.debug(
|
|
54
|
+
"Creating custom tool with partial arguments", function_name=function_name
|
|
52
55
|
)
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
tools = [with_partial_args(tool_model, original_function_model.partial_args)]
|
|
56
|
+
# Use with_partial_args directly with UnityCatalogFunctionModel
|
|
57
|
+
tools = [with_partial_args(original_function_model)]
|
|
56
58
|
else:
|
|
57
59
|
# Fallback to standard UC toolkit approach
|
|
58
|
-
client: DatabricksFunctionClient = DatabricksFunctionClient(
|
|
60
|
+
client: DatabricksFunctionClient = DatabricksFunctionClient(
|
|
61
|
+
client=workspace_client
|
|
62
|
+
)
|
|
59
63
|
|
|
60
64
|
toolkit: UCFunctionToolkit = UCFunctionToolkit(
|
|
61
65
|
function_names=[function_name], client=client
|
|
62
66
|
)
|
|
63
67
|
|
|
64
68
|
tools = toolkit.tools or []
|
|
65
|
-
logger.
|
|
69
|
+
logger.trace("Retrieved tools", tools_count=len(tools))
|
|
66
70
|
|
|
67
|
-
#
|
|
68
|
-
return
|
|
71
|
+
# HITL is now handled at middleware level via HumanInTheLoopMiddleware
|
|
72
|
+
return list(tools)
|
|
69
73
|
|
|
70
74
|
|
|
71
75
|
def _execute_uc_function(
|
|
72
76
|
client: DatabricksFunctionClient,
|
|
73
77
|
function_name: str,
|
|
74
|
-
partial_args: Dict[str, str] = None,
|
|
78
|
+
partial_args: Optional[Dict[str, str]] = None,
|
|
75
79
|
**kwargs: Any,
|
|
76
80
|
) -> str:
|
|
77
81
|
"""Execute Unity Catalog function with partial args and provided parameters."""
|
|
@@ -83,18 +87,30 @@ def _execute_uc_function(
|
|
|
83
87
|
all_params.update(kwargs)
|
|
84
88
|
|
|
85
89
|
logger.debug(
|
|
86
|
-
|
|
90
|
+
"Calling UC function",
|
|
91
|
+
function_name=function_name,
|
|
92
|
+
parameters=list(all_params.keys()),
|
|
87
93
|
)
|
|
88
94
|
|
|
89
|
-
result = client.execute_function(
|
|
95
|
+
result: FunctionExecutionResult = client.execute_function(
|
|
96
|
+
function_name=function_name, parameters=all_params
|
|
97
|
+
)
|
|
90
98
|
|
|
91
99
|
# Handle errors and extract result
|
|
92
|
-
if
|
|
93
|
-
logger.error(
|
|
100
|
+
if result.error:
|
|
101
|
+
logger.error(
|
|
102
|
+
"Unity Catalog function error",
|
|
103
|
+
function_name=function_name,
|
|
104
|
+
error=result.error,
|
|
105
|
+
)
|
|
94
106
|
raise RuntimeError(f"Function execution failed: {result.error}")
|
|
95
107
|
|
|
96
|
-
result_value: str = result.value if
|
|
97
|
-
logger.
|
|
108
|
+
result_value: str = result.value if result.value is not None else str(result)
|
|
109
|
+
logger.trace(
|
|
110
|
+
"UC function result",
|
|
111
|
+
function_name=function_name,
|
|
112
|
+
result_length=len(str(result_value)),
|
|
113
|
+
)
|
|
98
114
|
return result_value
|
|
99
115
|
|
|
100
116
|
|
|
@@ -113,21 +129,29 @@ def _grant_function_permissions(
|
|
|
113
129
|
"""
|
|
114
130
|
try:
|
|
115
131
|
# Initialize workspace client
|
|
116
|
-
workspace_client
|
|
132
|
+
workspace_client: WorkspaceClient = (
|
|
133
|
+
WorkspaceClient(host=host) if host else WorkspaceClient()
|
|
134
|
+
)
|
|
117
135
|
|
|
118
136
|
# Parse the function name to get catalog and schema
|
|
119
|
-
parts = function_name.split(".")
|
|
137
|
+
parts: list[str] = function_name.split(".")
|
|
120
138
|
if len(parts) != 3:
|
|
121
139
|
logger.warning(
|
|
122
|
-
|
|
140
|
+
"Invalid function name format, expected catalog.schema.function",
|
|
141
|
+
function_name=function_name,
|
|
123
142
|
)
|
|
124
143
|
return
|
|
125
144
|
|
|
145
|
+
catalog_name: str
|
|
146
|
+
schema_name: str
|
|
147
|
+
func_name: str
|
|
126
148
|
catalog_name, schema_name, func_name = parts
|
|
127
|
-
schema_full_name = f"{catalog_name}.{schema_name}"
|
|
149
|
+
schema_full_name: str = f"{catalog_name}.{schema_name}"
|
|
128
150
|
|
|
129
151
|
logger.debug(
|
|
130
|
-
|
|
152
|
+
"Granting comprehensive permissions",
|
|
153
|
+
function_name=function_name,
|
|
154
|
+
principal=client_id,
|
|
131
155
|
)
|
|
132
156
|
|
|
133
157
|
# 1. Grant EXECUTE permission on the function
|
|
@@ -139,9 +163,13 @@ def _grant_function_permissions(
|
|
|
139
163
|
PermissionsChange(principal=client_id, add=[Privilege.EXECUTE])
|
|
140
164
|
],
|
|
141
165
|
)
|
|
142
|
-
logger.
|
|
166
|
+
logger.trace("Granted EXECUTE permission", function_name=function_name)
|
|
143
167
|
except Exception as e:
|
|
144
|
-
logger.warning(
|
|
168
|
+
logger.warning(
|
|
169
|
+
"Failed to grant EXECUTE permission",
|
|
170
|
+
function_name=function_name,
|
|
171
|
+
error=str(e),
|
|
172
|
+
)
|
|
145
173
|
|
|
146
174
|
# 2. Grant USE_SCHEMA permission on the schema
|
|
147
175
|
try:
|
|
@@ -155,10 +183,12 @@ def _grant_function_permissions(
|
|
|
155
183
|
)
|
|
156
184
|
],
|
|
157
185
|
)
|
|
158
|
-
logger.
|
|
186
|
+
logger.trace("Granted USE_SCHEMA permission", schema=schema_full_name)
|
|
159
187
|
except Exception as e:
|
|
160
188
|
logger.warning(
|
|
161
|
-
|
|
189
|
+
"Failed to grant USE_SCHEMA permission",
|
|
190
|
+
schema=schema_full_name,
|
|
191
|
+
error=str(e),
|
|
162
192
|
)
|
|
163
193
|
|
|
164
194
|
# 3. Grant USE_CATALOG and BROWSE permissions on the catalog
|
|
@@ -173,25 +203,34 @@ def _grant_function_permissions(
|
|
|
173
203
|
)
|
|
174
204
|
],
|
|
175
205
|
)
|
|
176
|
-
logger.
|
|
206
|
+
logger.trace(
|
|
207
|
+
"Granted USE_CATALOG and BROWSE permissions", catalog=catalog_name
|
|
208
|
+
)
|
|
177
209
|
except Exception as e:
|
|
178
210
|
logger.warning(
|
|
179
|
-
|
|
211
|
+
"Failed to grant catalog permissions",
|
|
212
|
+
catalog=catalog_name,
|
|
213
|
+
error=str(e),
|
|
180
214
|
)
|
|
181
215
|
|
|
182
216
|
logger.debug(
|
|
183
|
-
|
|
217
|
+
"Successfully granted comprehensive permissions",
|
|
218
|
+
function_name=function_name,
|
|
219
|
+
principal=client_id,
|
|
184
220
|
)
|
|
185
221
|
|
|
186
222
|
except Exception as e:
|
|
187
223
|
logger.warning(
|
|
188
|
-
|
|
224
|
+
"Failed to grant permissions",
|
|
225
|
+
function_name=function_name,
|
|
226
|
+
principal=client_id,
|
|
227
|
+
error=str(e),
|
|
189
228
|
)
|
|
190
229
|
# Don't fail the tool creation if permission granting fails
|
|
191
230
|
pass
|
|
192
231
|
|
|
193
232
|
|
|
194
|
-
def _create_filtered_schema(original_schema: type, exclude_fields:
|
|
233
|
+
def _create_filtered_schema(original_schema: type, exclude_fields: Set[str]) -> type:
|
|
195
234
|
"""
|
|
196
235
|
Create a new Pydantic model that excludes specified fields from the original schema.
|
|
197
236
|
|
|
@@ -202,23 +241,27 @@ def _create_filtered_schema(original_schema: type, exclude_fields: set[str]) ->
|
|
|
202
241
|
Returns:
|
|
203
242
|
A new Pydantic model class with the specified fields removed
|
|
204
243
|
"""
|
|
205
|
-
from pydantic import
|
|
206
|
-
from pydantic.fields import PydanticUndefined
|
|
244
|
+
from pydantic import Field, create_model
|
|
245
|
+
from pydantic.fields import FieldInfo, PydanticUndefined
|
|
207
246
|
|
|
208
247
|
try:
|
|
209
248
|
# Get the original model's fields (Pydantic v2)
|
|
210
|
-
original_fields = original_schema.model_fields
|
|
211
|
-
filtered_field_definitions = {}
|
|
249
|
+
original_fields: dict[str, FieldInfo] = original_schema.model_fields
|
|
250
|
+
filtered_field_definitions: dict[str, tuple[type, FieldInfo]] = {}
|
|
212
251
|
|
|
213
|
-
|
|
214
|
-
|
|
252
|
+
field_name: str
|
|
253
|
+
field: FieldInfo
|
|
254
|
+
for field_name, field in original_fields.items():
|
|
255
|
+
if field_name not in exclude_fields:
|
|
215
256
|
# Reconstruct the field definition for create_model
|
|
216
|
-
field_type = field.annotation
|
|
217
|
-
field_default = (
|
|
257
|
+
field_type: type = field.annotation
|
|
258
|
+
field_default: Any = (
|
|
218
259
|
field.default if field.default is not PydanticUndefined else ...
|
|
219
260
|
)
|
|
220
|
-
field_info = Field(
|
|
221
|
-
|
|
261
|
+
field_info: FieldInfo = Field(
|
|
262
|
+
default=field_default, description=field.description
|
|
263
|
+
)
|
|
264
|
+
filtered_field_definitions[field_name] = (field_type, field_info)
|
|
222
265
|
|
|
223
266
|
# If no fields remain after filtering, return a generic empty schema
|
|
224
267
|
if not filtered_field_definitions:
|
|
@@ -231,18 +274,18 @@ def _create_filtered_schema(original_schema: type, exclude_fields: set[str]) ->
|
|
|
231
274
|
return EmptySchema
|
|
232
275
|
|
|
233
276
|
# Create the new model dynamically
|
|
234
|
-
model_name = f"Filtered{original_schema.__name__}"
|
|
235
|
-
docstring = getattr(
|
|
277
|
+
model_name: str = f"Filtered{original_schema.__name__}"
|
|
278
|
+
docstring: str = getattr(
|
|
236
279
|
original_schema, "__doc__", "Filtered Unity Catalog function parameters."
|
|
237
280
|
)
|
|
238
281
|
|
|
239
|
-
filtered_model = create_model(
|
|
282
|
+
filtered_model: type[BaseModel] = create_model(
|
|
240
283
|
model_name, __doc__=docstring, **filtered_field_definitions
|
|
241
284
|
)
|
|
242
285
|
return filtered_model
|
|
243
286
|
|
|
244
287
|
except Exception as e:
|
|
245
|
-
logger.warning(
|
|
288
|
+
logger.warning("Failed to create filtered schema", error=str(e))
|
|
246
289
|
|
|
247
290
|
# Fallback to generic schema
|
|
248
291
|
class GenericFilteredSchema(BaseModel):
|
|
@@ -254,8 +297,7 @@ def _create_filtered_schema(original_schema: type, exclude_fields: set[str]) ->
|
|
|
254
297
|
|
|
255
298
|
|
|
256
299
|
def with_partial_args(
|
|
257
|
-
|
|
258
|
-
partial_args: dict[str, AnyVariable] = {},
|
|
300
|
+
uc_function: UnityCatalogFunctionModel,
|
|
259
301
|
) -> StructuredTool:
|
|
260
302
|
"""
|
|
261
303
|
Create a Unity Catalog tool with partial arguments pre-filled.
|
|
@@ -264,35 +306,64 @@ def with_partial_args(
|
|
|
264
306
|
already resolved, so the caller only needs to provide the remaining parameters.
|
|
265
307
|
|
|
266
308
|
Args:
|
|
267
|
-
|
|
268
|
-
|
|
309
|
+
uc_function: UnityCatalogFunctionModel containing the function configuration
|
|
310
|
+
and partial_args to pre-fill.
|
|
269
311
|
|
|
270
312
|
Returns:
|
|
271
313
|
StructuredTool: A LangChain tool with partial arguments pre-filled
|
|
272
314
|
"""
|
|
273
315
|
from unitycatalog.ai.langchain.toolkit import generate_function_input_params_schema
|
|
274
316
|
|
|
275
|
-
|
|
317
|
+
from dao_ai.config import ServicePrincipalModel
|
|
318
|
+
|
|
319
|
+
partial_args: dict[str, AnyVariable] = uc_function.partial_args or {}
|
|
276
320
|
|
|
277
321
|
# Convert dict-based variables to CompositeVariableModel and resolve their values
|
|
278
|
-
resolved_args = {}
|
|
322
|
+
resolved_args: dict[str, Any] = {}
|
|
323
|
+
k: str
|
|
324
|
+
v: AnyVariable
|
|
279
325
|
for k, v in partial_args.items():
|
|
280
326
|
if isinstance(v, dict):
|
|
281
327
|
resolved_args[k] = value_of(CompositeVariableModel(**v))
|
|
282
328
|
else:
|
|
283
329
|
resolved_args[k] = value_of(v)
|
|
284
330
|
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
331
|
+
# Handle service_principal - expand into client_id and client_secret
|
|
332
|
+
if "service_principal" in resolved_args:
|
|
333
|
+
sp: Any = resolved_args.pop("service_principal")
|
|
334
|
+
if isinstance(sp, dict):
|
|
335
|
+
sp = ServicePrincipalModel(**sp)
|
|
336
|
+
if isinstance(sp, ServicePrincipalModel):
|
|
337
|
+
if "client_id" not in resolved_args:
|
|
338
|
+
resolved_args["client_id"] = value_of(sp.client_id)
|
|
339
|
+
if "client_secret" not in resolved_args:
|
|
340
|
+
resolved_args["client_secret"] = value_of(sp.client_secret)
|
|
341
|
+
|
|
342
|
+
# Normalize host/workspace_host - accept either key, ensure https:// scheme
|
|
343
|
+
if "workspace_host" in resolved_args and "host" not in resolved_args:
|
|
344
|
+
resolved_args["host"] = normalize_host(resolved_args.pop("workspace_host"))
|
|
345
|
+
elif "host" in resolved_args:
|
|
346
|
+
resolved_args["host"] = normalize_host(resolved_args["host"])
|
|
347
|
+
|
|
348
|
+
# Default host from WorkspaceClient if not provided
|
|
349
|
+
if "host" not in resolved_args:
|
|
350
|
+
from dao_ai.utils import get_default_databricks_host
|
|
351
|
+
|
|
352
|
+
host: str | None = get_default_databricks_host()
|
|
353
|
+
if host:
|
|
354
|
+
resolved_args["host"] = host
|
|
355
|
+
|
|
356
|
+
# Get function info from the resource
|
|
357
|
+
function_name: str = uc_function.resource.full_name
|
|
358
|
+
tool_name: str = uc_function.resource.name or function_name.replace(".", "_")
|
|
359
|
+
workspace_client: WorkspaceClient = uc_function.resource.workspace_client
|
|
289
360
|
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
361
|
+
logger.debug(
|
|
362
|
+
"Creating UC tool with partial args",
|
|
363
|
+
function_name=function_name,
|
|
364
|
+
tool_name=tool_name,
|
|
365
|
+
partial_args=list(resolved_args.keys()),
|
|
366
|
+
)
|
|
296
367
|
|
|
297
368
|
# Grant permissions if we have credentials
|
|
298
369
|
if "client_id" in resolved_args:
|
|
@@ -301,36 +372,44 @@ def with_partial_args(
|
|
|
301
372
|
try:
|
|
302
373
|
_grant_function_permissions(function_name, client_id, host)
|
|
303
374
|
except Exception as e:
|
|
304
|
-
logger.warning(
|
|
375
|
+
logger.warning(
|
|
376
|
+
"Failed to grant permissions", function_name=function_name, error=str(e)
|
|
377
|
+
)
|
|
305
378
|
|
|
306
|
-
# Create the client for function execution
|
|
307
|
-
client: DatabricksFunctionClient = DatabricksFunctionClient()
|
|
379
|
+
# Create the client for function execution using the resource's workspace client
|
|
380
|
+
client: DatabricksFunctionClient = DatabricksFunctionClient(client=workspace_client)
|
|
308
381
|
|
|
309
382
|
# Try to get the function schema for better tool definition
|
|
383
|
+
schema_model: type[BaseModel]
|
|
384
|
+
tool_description: str
|
|
310
385
|
try:
|
|
311
|
-
function_info = client.get_function(function_name)
|
|
386
|
+
function_info: FunctionInfo = client.get_function(function_name)
|
|
312
387
|
schema_info = generate_function_input_params_schema(function_info)
|
|
313
388
|
tool_description = (
|
|
314
389
|
function_info.comment or f"Unity Catalog function: {function_name}"
|
|
315
390
|
)
|
|
316
391
|
|
|
317
|
-
logger.
|
|
318
|
-
|
|
392
|
+
logger.trace(
|
|
393
|
+
"Generated function schema",
|
|
394
|
+
function_name=function_name,
|
|
395
|
+
schema=schema_info.pydantic_model.__name__,
|
|
319
396
|
)
|
|
320
|
-
logger.debug(f"Tool description: {tool_description}")
|
|
321
397
|
|
|
322
398
|
# Create a modified schema that excludes partial args
|
|
323
|
-
original_schema = schema_info.pydantic_model
|
|
399
|
+
original_schema: type = schema_info.pydantic_model
|
|
324
400
|
schema_model = _create_filtered_schema(original_schema, resolved_args.keys())
|
|
325
|
-
logger.
|
|
326
|
-
|
|
401
|
+
logger.trace(
|
|
402
|
+
"Filtered schema to exclude partial args",
|
|
403
|
+
function_name=function_name,
|
|
404
|
+
excluded_args=list(resolved_args.keys()),
|
|
327
405
|
)
|
|
328
406
|
|
|
329
407
|
except Exception as e:
|
|
330
|
-
logger.warning(
|
|
331
|
-
|
|
332
|
-
|
|
408
|
+
logger.warning(
|
|
409
|
+
"Could not introspect function", function_name=function_name, error=str(e)
|
|
410
|
+
)
|
|
333
411
|
|
|
412
|
+
# Fallback to a generic schema
|
|
334
413
|
class GenericUCParams(BaseModel):
|
|
335
414
|
"""Generic parameters for Unity Catalog function."""
|
|
336
415
|
|
|
@@ -350,14 +429,12 @@ def with_partial_args(
|
|
|
350
429
|
)
|
|
351
430
|
|
|
352
431
|
# Set the function name for the decorator
|
|
353
|
-
uc_function_wrapper.__name__ =
|
|
432
|
+
uc_function_wrapper.__name__ = tool_name
|
|
354
433
|
|
|
355
434
|
# Create the tool using LangChain's StructuredTool
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
partial_tool = StructuredTool.from_function(
|
|
435
|
+
partial_tool: StructuredTool = StructuredTool.from_function(
|
|
359
436
|
func=uc_function_wrapper,
|
|
360
|
-
name=
|
|
437
|
+
name=tool_name,
|
|
361
438
|
description=tool_description,
|
|
362
439
|
args_schema=schema_model,
|
|
363
440
|
)
|