dao-ai 0.0.18__py3-none-any.whl → 0.0.20__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
dao_ai/config.py CHANGED
@@ -39,6 +39,7 @@ from mlflow.models import ModelConfig
 from mlflow.models.resources import (
     DatabricksFunction,
     DatabricksGenieSpace,
+    DatabricksLakebase,
     DatabricksResource,
     DatabricksServingEndpoint,
     DatabricksSQLWarehouse,
@@ -274,10 +275,8 @@ class TableModel(BaseModel, HasFullName, IsDatabricksResource):
             "_assessment_logs",
             "_request_logs",
         ]
-
-        excluded_prefixes: Sequence[str] = [
-            "trace_logs_"
-        ]
+
+        excluded_prefixes: Sequence[str] = ["trace_logs_"]
 
         if self.name:
             resources.append(
@@ -299,8 +298,12 @@ class TableModel(BaseModel, HasFullName, IsDatabricksResource):
                     on_behalf_of_user=self.on_behalf_of_user,
                 )
                 for table in tables
-                if not any(table.name.endswith(suffix) for suffix in excluded_suffixes)
-                and not any(table.name.startswith(prefix) for prefix in excluded_prefixes)
+                if not any(
+                    table.name.endswith(suffix) for suffix in excluded_suffixes
+                )
+                and not any(
+                    table.name.startswith(prefix) for prefix in excluded_prefixes
+                )
             ]
         )
 
@@ -413,7 +416,7 @@ class GenieRoomModel(BaseModel, IsDatabricksResource):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     name: str
     description: Optional[str] = None
-    space_id: str
+    space_id: AnyVariable
 
     @property
     def api_scopes(self) -> Sequence[str]:
@@ -428,6 +431,11 @@ class GenieRoomModel(BaseModel, IsDatabricksResource):
             )
         ]
 
+    @model_validator(mode="after")
+    def update_space_id(self):
+        self.space_id = value_of(self.space_id)
+        return self
+
 
 class VolumeModel(BaseModel, HasFullName):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
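The `space_id: AnyVariable` change above (mirrored by `warehouse_id` below) means a Genie space id can now be supplied indirectly and is resolved once by the new `update_space_id` validator. A minimal self-contained sketch of that resolve-at-validation pattern, with a hypothetical `{"env": ...}` reference standing in for dao-ai's real variable models:

```python
import os
from typing import Union

from pydantic import BaseModel, model_validator

# Hypothetical stand-ins for dao-ai's AnyVariable/value_of: a value is either
# a literal string or an {"env": "NAME"} reference resolved at validation time.
AnyVariable = Union[str, dict]


def value_of(value: AnyVariable) -> str:
    if isinstance(value, dict) and "env" in value:
        return os.environ[value["env"]]
    return value


class GenieRoom(BaseModel):
    name: str
    space_id: AnyVariable

    @model_validator(mode="after")
    def update_space_id(self):
        # Same shape as the validator added above: resolve once, so
        # downstream code always sees a plain string.
        self.space_id = value_of(self.space_id)
        return self


os.environ["GENIE_SPACE_ID"] = "01f0example"
room = GenieRoom(name="sales", space_id={"env": "GENIE_SPACE_ID"})
assert room.space_id == "01f0example"
```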
@@ -648,6 +656,7 @@ class ConnectionModel(BaseModel, HasFullName, IsDatabricksResource):
     def api_scopes(self) -> Sequence[str]:
         return [
             "catalog.connections",
+            "serving.serving-endpoints",
         ]
 
     def as_resources(self) -> Sequence[DatabricksResource]:
@@ -662,7 +671,7 @@ class WarehouseModel(BaseModel, IsDatabricksResource):
     model_config = ConfigDict()
     name: str
     description: Optional[str] = None
-    warehouse_id: str
+    warehouse_id: AnyVariable
 
     @property
     def api_scopes(self) -> Sequence[str]:
@@ -678,8 +687,13 @@ class WarehouseModel(BaseModel, IsDatabricksResource):
             )
         ]
 
+    @model_validator(mode="after")
+    def update_warehouse_id(self):
+        self.warehouse_id = value_of(self.warehouse_id)
+        return self
+
 
-class DatabaseModel(BaseModel):
+class DatabaseModel(BaseModel, IsDatabricksResource):
     model_config = ConfigDict(frozen=True)
     name: str
     description: Optional[str] = None
@@ -695,6 +709,18 @@ class DatabaseModel(BaseModel):
     client_secret: Optional[AnyVariable] = None
     workspace_host: Optional[AnyVariable] = None
 
+    @property
+    def api_scopes(self) -> Sequence[str]:
+        return []
+
+    def as_resources(self) -> Sequence[DatabricksResource]:
+        return [
+            DatabricksLakebase(
+                database_instance_name=self.name,
+                on_behalf_of_user=self.on_behalf_of_user,
+            )
+        ]
+
     @model_validator(mode="after")
     def validate_auth_methods(self):
        oauth_fields: Sequence[Any] = [
@@ -841,7 +867,7 @@ class PythonFunctionModel(BaseFunctionModel, HasFullName):
 
 class FactoryFunctionModel(BaseFunctionModel, HasFullName):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
-    args: Optional[dict[str, Any]] = Field(default_factory=dict)
+    args: Optional[dict[str, AnyVariable]] = Field(default_factory=dict)
     type: Literal[FunctionType.FACTORY] = FunctionType.FACTORY
 
     @property
@@ -937,6 +963,7 @@ class McpFunctionModel(BaseFunctionModel, HasFullName):
 class UnityCatalogFunctionModel(BaseFunctionModel, HasFullName):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     schema_model: Optional[SchemaModel] = Field(default=None, alias="schema")
+    partial_args: Optional[dict[str, AnyVariable]] = Field(default_factory=dict)
     type: Literal[FunctionType.UNITY_CATALOG] = FunctionType.UNITY_CATALOG
 
     @property
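The new `partial_args` field lets a Unity Catalog function declare arguments that are bound ahead of time rather than exposed to the model. A hypothetical configuration shape (the `schema` contents and the `{"env": ...}` variable syntax are illustrative, not dao-ai's exact format):

```python
# Hypothetical dict form of a UnityCatalogFunctionModel using partial_args:
# client_id and host are pre-bound and resolved server-side, so the
# agent-facing tool only exposes the remaining business parameters to the LLM.
uc_function = {
    "name": "lookup_order",
    "schema": {"catalog_name": "main", "schema_name": "retail"},
    "partial_args": {
        "client_id": {"env": "DATABRICKS_CLIENT_ID"},
        "host": {"env": "DATABRICKS_HOST"},
    },
}
```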
@@ -1190,13 +1217,13 @@ class AppModel(BaseModel):
     endpoint_name: Optional[str] = None
     tags: Optional[dict[str, Any]] = Field(default_factory=dict)
     scale_to_zero: Optional[bool] = True
-    environment_vars: Optional[dict[str, Any]] = Field(default_factory=dict)
+    environment_vars: Optional[dict[str, AnyVariable]] = Field(default_factory=dict)
     budget_policy_id: Optional[str] = None
     workload_size: Optional[WorkloadSize] = "Small"
     permissions: Optional[list[AppPermissionModel]] = Field(default_factory=list)
     agents: list[AgentModel] = Field(default_factory=list)
 
-    orchestration: OrchestrationModel
+    orchestration: Optional[OrchestrationModel] = None
     alias: Optional[str] = None
     initialization_hooks: Optional[FunctionHook | list[FunctionHook]] = Field(
         default_factory=list
@@ -1215,7 +1242,38 @@ class AppModel(BaseModel):
     @model_validator(mode="after")
     def validate_agents_not_empty(self):
         if not self.agents:
-            raise ValueError("agents must contain at least one item")
+            raise ValueError("At least one agent must be specified")
+        return self
+
+    @model_validator(mode="after")
+    def update_environment_vars(self):
+        for key, value in self.environment_vars.items():
+            if isinstance(value, SecretVariableModel):
+                updated_value = str(value)
+            else:
+                updated_value = value_of(value)
+
+            self.environment_vars[key] = updated_value
+        return self
+
+    @model_validator(mode="after")
+    def set_default_orchestration(self):
+        if self.orchestration is None:
+            if len(self.agents) > 1:
+                default_agent: AgentModel = self.agents[0]
+                self.orchestration = OrchestrationModel(
+                    swarm=SupervisorModel(model=default_agent.model)
+                )
+            elif len(self.agents) == 1:
+                default_agent: AgentModel = self.agents[0]
+                self.orchestration = OrchestrationModel(
+                    supervisor=SwarmModel(
+                        model=default_agent.model, default_agent=default_agent
+                    )
+                )
+            else:
+                raise ValueError("At least one agent must be specified")
+
         return self
 
     @model_validator(mode="after")
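`update_environment_vars` draws a useful line: `SecretVariableModel` values are stringified rather than resolved, while everything else is resolved eagerly via `value_of`. A self-contained sketch of that split, assuming (an assumption about dao-ai's implementation) that a secret's string form is a Databricks `{{secrets/<scope>/<key>}}` reference that model serving resolves at deploy time:

```python
import os


# Stand-ins for dao-ai's variable models (shapes assumed, not the real API).
class SecretVariableModel:
    def __init__(self, scope: str, key: str):
        self.scope, self.key = scope, key

    def __str__(self) -> str:
        # Assumed: secrets render to a reference, never to the plaintext value.
        return f"{{{{secrets/{self.scope}/{self.key}}}}}"


def value_of(value):
    # Resolve {"env": NAME} references eagerly; pass literals through.
    if isinstance(value, dict) and "env" in value:
        return os.environ[value["env"]]
    return value


def render_env_var(value) -> str:
    # Mirrors update_environment_vars: secrets keep their reference syntax,
    # everything else becomes a concrete string before deployment.
    if isinstance(value, SecretVariableModel):
        return str(value)
    return str(value_of(value))


assert render_env_var(SecretVariableModel("ai", "token")) == "{{secrets/ai/token}}"
```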
dao_ai/graph.py CHANGED
@@ -136,8 +136,8 @@ def _create_supervisor_graph(config: AppConfig) -> CompiledStateGraph:
 
     workflow: StateGraph = StateGraph(
         SharedState,
-        input=IncomingState,
-        output=OutgoingState,
+        input_schema=IncomingState,
+        output_schema=OutgoingState,
         context_schema=Context,
     )
 
@@ -200,8 +200,8 @@ def _create_swarm_graph(config: AppConfig) -> CompiledStateGraph:
 
     workflow: StateGraph = StateGraph(
         SharedState,
-        input=IncomingState,
-        output=OutgoingState,
+        input_schema=IncomingState,
+        output_schema=OutgoingState,
         context_schema=Context,
     )
 
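These keyword renames track LangGraph 0.6 (note the `langgraph>=0.6.7` bump in this release's metadata), where `StateGraph`'s `input`/`output` parameters became `input_schema`/`output_schema`. A minimal sketch, with one schema standing in for the separate `SharedState`/`IncomingState`/`OutgoingState` types used above:

```python
from typing import TypedDict

from langgraph.graph import END, START, StateGraph


class State(TypedDict):
    text: str


# LangGraph 0.6 keyword names: input_schema/output_schema replace input/output.
workflow = StateGraph(State, input_schema=State, output_schema=State)
workflow.add_node("echo", lambda state: {"text": state["text"]})
workflow.add_edge(START, "echo")
workflow.add_edge("echo", END)

graph = workflow.compile()
print(graph.invoke({"text": "hi"}))  # {'text': 'hi'}
```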
dao_ai/guardrails.py CHANGED
@@ -87,12 +87,12 @@ def judge_node(guardrails: GuardrailModel) -> RunnableLike:
         )
 
         if eval_result["score"]:
-            logger.debug("✅ Response approved by judge")
+            logger.debug("Response approved by judge")
             logger.debug(f"Judge's comment: {eval_result['comment']}")
             return
         else:
             # Otherwise, return the judge's critique as a new user message
-            logger.warning("⚠️ Judge requested improvements")
+            logger.warning("Judge requested improvements")
             comment: str = eval_result["comment"]
             logger.warning(f"Judge's critique: {comment}")
             content: str = "\n".join([human_message.content, comment])
dao_ai/memory/core.py CHANGED
@@ -99,13 +99,13 @@ class CheckpointManager:
                     checkpointer_manager
                 )
             case StorageType.POSTGRES:
-                from dao_ai.memory.postgres import PostgresCheckpointerManager
+                from dao_ai.memory.postgres import AsyncPostgresCheckpointerManager
 
                 checkpointer_manager = cls.checkpoint_managers.get(
                     checkpointer_model.database.name
                 )
                 if checkpointer_manager is None:
-                    checkpointer_manager = PostgresCheckpointerManager(
+                    checkpointer_manager = AsyncPostgresCheckpointerManager(
                         checkpointer_model
                     )
                     cls.checkpoint_managers[checkpointer_model.database.name] = (
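The only substantive change here is the swap to the async checkpointer manager, which matches the async `ainvoke`/`astream` paths in `dao_ai/models.py` below. For reference, the per-database caching the hunk edits reduces to this shape (a sketch with a stand-in class, not the real manager):

```python
# One checkpointer manager per database name, created lazily and then reused.
class CheckpointerManager:
    def __init__(self, database_name: str) -> None:
        self.database_name = database_name


_managers: dict[str, CheckpointerManager] = {}


def get_manager(database_name: str) -> CheckpointerManager:
    manager = _managers.get(database_name)
    if manager is None:
        manager = CheckpointerManager(database_name)
        _managers[database_name] = manager
    return manager


assert get_manager("db") is get_manager("db")
```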
dao_ai/memory/postgres.py CHANGED
@@ -119,8 +119,9 @@ class AsyncPostgresStoreManager(StoreManagerBase):
                 self.store_model.database
             )
 
-            # Create store with the shared pool
+            # Create store with the shared pool (using patched version)
             self._store = AsyncPostgresStore(conn=self.pool)
+
             await self._store.setup()
 
             self._setup_complete = True
dao_ai/models.py CHANGED
@@ -227,14 +227,28 @@ class LanggraphResponsesAgent(ResponsesAgent):
         import asyncio
 
         async def _async_invoke():
-            return await self.graph.ainvoke(
-                {"messages": messages}, context=context, config=custom_inputs
-            )
+            try:
+                return await self.graph.ainvoke(
+                    {"messages": messages}, context=context, config=custom_inputs
+                )
+            except Exception as e:
+                logger.error(f"Error in graph.ainvoke: {e}")
+                raise
 
-        loop = asyncio.get_event_loop()
-        response: dict[str, Sequence[BaseMessage]] = loop.run_until_complete(
-            _async_invoke()
-        )
+        try:
+            loop = asyncio.get_event_loop()
+        except RuntimeError:
+            # Handle case where no event loop exists (common in some deployment scenarios)
+            loop = asyncio.new_event_loop()
+            asyncio.set_event_loop(loop)
+
+        try:
+            response: dict[str, Sequence[BaseMessage]] = loop.run_until_complete(
+                _async_invoke()
+            )
+        except Exception as e:
+            logger.error(f"Error in async execution: {e}")
+            raise
 
         # Convert response to ResponsesAgent format
         last_message: BaseMessage = response["messages"][-1]
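The get-or-create loop logic added here (and repeated in `predict_stream` and `_process_langchain_messages_stream`) isolates to a small helper: serving workers may call `predict()` from a thread that has no event loop, where `asyncio.get_event_loop()` raises `RuntimeError` and a fresh loop must be installed. A self-contained sketch:

```python
import asyncio


def run_sync(coro):
    # Reuse the thread's loop when one exists; otherwise create and install
    # one (get_event_loop() raises RuntimeError in loop-less worker threads).
    try:
        loop = asyncio.get_event_loop()
    except RuntimeError:
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
    return loop.run_until_complete(coro)


async def _demo() -> str:
    return "ok"


print(run_sync(_demo()))  # ok
```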
@@ -243,8 +257,9 @@ class LanggraphResponsesAgent(ResponsesAgent):
             text=last_message.content, id=f"msg_{uuid.uuid4().hex[:8]}"
         )
 
+        custom_outputs = custom_inputs
         return ResponsesAgentResponse(
-            output=[output_item], custom_outputs=request.custom_inputs
+            output=[output_item], custom_outputs=custom_outputs
         )
 
     def predict_stream(
@@ -271,46 +286,59 @@ class LanggraphResponsesAgent(ResponsesAgent):
             item_id = f"msg_{uuid.uuid4().hex[:8]}"
             accumulated_content = ""
 
-            async for nodes, stream_mode, messages_batch in self.graph.astream(
-                {"messages": messages},
-                context=context,
-                config=custom_inputs,
-                stream_mode=["messages", "custom"],
-                subgraphs=True,
-            ):
-                nodes: tuple[str, ...]
-                stream_mode: str
-                messages_batch: Sequence[BaseMessage]
-
-                for message in messages_batch:
-                    if (
-                        isinstance(
-                            message,
-                            (
-                                AIMessageChunk,
-                                AIMessage,
-                            ),
-                        )
-                        and message.content
-                        and "summarization" not in nodes
-                    ):
-                        content = message.content
-                        accumulated_content += content
-
-                        # Yield streaming delta
-                        yield ResponsesAgentStreamEvent(
-                            **self.create_text_delta(delta=content, item_id=item_id)
-                        )
-
-            # Yield final output item
-            yield ResponsesAgentStreamEvent(
-                type="response.output_item.done",
-                item=self.create_text_output_item(text=accumulated_content, id=item_id),
-                custom_outputs=request.custom_inputs,
-            )
+            try:
+                async for nodes, stream_mode, messages_batch in self.graph.astream(
+                    {"messages": messages},
+                    context=context,
+                    config=custom_inputs,
+                    stream_mode=["messages", "custom"],
+                    subgraphs=True,
+                ):
+                    nodes: tuple[str, ...]
+                    stream_mode: str
+                    messages_batch: Sequence[BaseMessage]
+
+                    for message in messages_batch:
+                        if (
+                            isinstance(
+                                message,
+                                (
+                                    AIMessageChunk,
+                                    AIMessage,
+                                ),
+                            )
+                            and message.content
+                            and "summarization" not in nodes
+                        ):
+                            content = message.content
+                            accumulated_content += content
+
+                            # Yield streaming delta
+                            yield ResponsesAgentStreamEvent(
+                                **self.create_text_delta(delta=content, item_id=item_id)
+                            )
+
+                custom_outputs = custom_inputs
+                # Yield final output item
+                yield ResponsesAgentStreamEvent(
+                    type="response.output_item.done",
+                    item=self.create_text_output_item(
+                        text=accumulated_content, id=item_id
+                    ),
+                    custom_outputs=custom_outputs,
+                )
+            except Exception as e:
+                logger.error(f"Error in graph.astream: {e}")
+                raise
 
         # Convert async generator to sync generator
-        loop = asyncio.get_event_loop()
+        try:
+            loop = asyncio.get_event_loop()
+        except RuntimeError:
+            # Handle case where no event loop exists (common in some deployment scenarios)
+            loop = asyncio.new_event_loop()
+            asyncio.set_event_loop(loop)
+
         async_gen = _async_stream()
 
         try:
@@ -320,8 +348,14 @@ class LanggraphResponsesAgent(ResponsesAgent):
                     yield item
                 except StopAsyncIteration:
                     break
+                except Exception as e:
+                    logger.error(f"Error in streaming: {e}")
+                    raise
         finally:
-            loop.run_until_complete(async_gen.aclose())
+            try:
+                loop.run_until_complete(async_gen.aclose())
+            except Exception as e:
+                logger.warning(f"Error closing async generator: {e}")
 
     def _extract_text_from_content(
         self,
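The streaming path wraps the same idea around an async generator: each item is pulled with `run_until_complete`, and `aclose()` now gets its own guard so generator cleanup can't mask the original error. A self-contained sketch of the bridge:

```python
import asyncio


def iter_sync(async_gen, loop: asyncio.AbstractEventLoop):
    # Drive the async generator one item at a time on the given loop,
    # always closing it so its cleanup (finally blocks) runs.
    try:
        while True:
            try:
                item = loop.run_until_complete(anext(async_gen))
            except StopAsyncIteration:
                break
            yield item
    finally:
        loop.run_until_complete(async_gen.aclose())


async def _numbers():
    for i in range(3):
        yield i


loop = asyncio.new_event_loop()
print(list(iter_sync(_numbers(), loop)))  # [0, 1, 2]
loop.close()
```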
@@ -522,7 +556,14 @@ def _process_langchain_messages_stream(
             yield message
 
     # Convert async generator to sync generator
-    loop = asyncio.get_event_loop()
+
+    try:
+        loop = asyncio.get_event_loop()
+    except RuntimeError:
+        # Handle case where no event loop exists (common in some deployment scenarios)
+        loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(loop)
+
     async_gen = _async_stream()
 
     try:
dao_ai/providers/databricks.py CHANGED
@@ -42,6 +42,7 @@ import dao_ai
 from dao_ai.config import (
     AppConfig,
     ConnectionModel,
+    DatabaseModel,
     DatasetModel,
     FunctionModel,
     GenieRoomModel,
@@ -224,6 +225,7 @@ class DatabricksProvider(ServiceProvider):
         connections: Sequence[ConnectionModel] = list(
             config.resources.connections.values()
         )
+        databases: Sequence[DatabaseModel] = list(config.resources.databases.values())
 
         resources: Sequence[IsDatabricksResource] = (
             llms
@@ -233,6 +235,7 @@ class DatabricksProvider(ServiceProvider):
             + functions
             + tables
             + connections
+            + databases
         )
 
         # Flatten all resources from all models into a single list
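With `DatabaseModel` now implementing `IsDatabricksResource`, databases flow through the same aggregation as every other resource type. The contract reduces to a tiny protocol (a sketch; the real interface lives in `dao_ai.config`):

```python
from typing import Protocol, Sequence


class SupportsResources(Protocol):
    # Anything exposing as_resources() can be concatenated into the flat
    # resource list attached to the logged model.
    def as_resources(self) -> Sequence[object]: ...


def flatten_resources(models: Sequence[SupportsResources]) -> list[object]:
    flat: list[object] = []
    for model in models:
        flat.extend(model.as_resources())
    return flat


class Database:
    def as_resources(self) -> Sequence[object]:
        return ["lakebase:my-instance"]  # placeholder for DatabricksLakebase


print(flatten_resources([Database()]))  # ['lakebase:my-instance']
```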
dao_ai/tools/human_in_the_loop.py CHANGED
@@ -87,7 +87,11 @@ def as_human_in_the_loop(
     if isinstance(function, BaseFunctionModel):
         human_in_the_loop: HumanInTheLoopModel | None = function.human_in_the_loop
         if human_in_the_loop:
-            logger.debug(f"Adding human-in-the-loop to tool: {tool.name}")
+            # Get tool name safely - handle RunnableBinding objects
+            tool_name = getattr(tool, "name", None) or getattr(
+                getattr(tool, "bound", None), "name", "unknown_tool"
+            )
+            logger.debug(f"Adding human-in-the-loop to tool: {tool_name}")
             tool = add_human_in_the_loop(
                 tool=tool,
                 interrupt_config=human_in_the_loop.interupt_config,
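The fix matters because the incoming `tool` may be a `RunnableBinding` rather than a bare tool, and bindings don't expose `.name` directly; the wrapped tool sits on `.bound`. The lookup in isolation (stand-in classes, not the real LangChain types):

```python
class Tool:
    name = "lookup_order"


class Binding:
    bound = Tool()  # RunnableBinding-style wrapper: name lives on .bound


def tool_name(tool) -> str:
    return getattr(tool, "name", None) or getattr(
        getattr(tool, "bound", None), "name", "unknown_tool"
    )


assert tool_name(Tool()) == "lookup_order"
assert tool_name(Binding()) == "lookup_order"
assert tool_name(object()) == "unknown_tool"
```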
dao_ai/tools/unity_catalog.py CHANGED
@@ -1,14 +1,18 @@
-from typing import Sequence
+from typing import Any, Dict, Optional, Sequence, Union
 
-from databricks_langchain import (
-    DatabricksFunctionClient,
-    UCFunctionToolkit,
-)
+from databricks.sdk import WorkspaceClient
+from databricks.sdk.service.catalog import PermissionsChange, Privilege
+from databricks_langchain import DatabricksFunctionClient, UCFunctionToolkit
 from langchain_core.runnables.base import RunnableLike
+from langchain_core.tools import StructuredTool
 from loguru import logger
 
 from dao_ai.config import (
+    AnyVariable,
+    CompositeVariableModel,
+    ToolModel,
     UnityCatalogFunctionModel,
+    value_of,
 )
 from dao_ai.tools.human_in_the_loop import as_human_in_the_loop
 
@@ -32,19 +36,330 @@ def create_uc_tools(
 
     logger.debug(f"create_uc_tools: {function}")
 
+    original_function_model = None
     if isinstance(function, UnityCatalogFunctionModel):
-        function = function.full_name
+        original_function_model = function
+        function_name = function.full_name
+    else:
+        function_name = function
 
-    client: DatabricksFunctionClient = DatabricksFunctionClient()
+    # Determine which tools to create
+    if original_function_model and original_function_model.partial_args:
+        logger.debug("Found partial_args, creating custom tool with partial arguments")
+        # Create a ToolModel wrapper for the with_partial_args function
+        tool_model = ToolModel(
+            name=original_function_model.name, function=original_function_model
+        )
+
+        # Use with_partial_args to create the authenticated tool
+        tools = [with_partial_args(tool_model, original_function_model.partial_args)]
+    else:
+        # Fallback to standard UC toolkit approach
+        client: DatabricksFunctionClient = DatabricksFunctionClient()
+
+        toolkit: UCFunctionToolkit = UCFunctionToolkit(
+            function_names=[function_name], client=client
+        )
+
+        tools = toolkit.tools or []
+        logger.debug(f"Retrieved tools: {tools}")
+
+    # Apply human-in-the-loop wrapper to all tools and return
+    return [as_human_in_the_loop(tool=tool, function=function_name) for tool in tools]
+
+
+def _execute_uc_function(
+    client: DatabricksFunctionClient,
+    function_name: str,
+    partial_args: Dict[str, str] = None,
+    **kwargs: Any,
+) -> str:
+    """Execute Unity Catalog function with partial args and provided parameters."""
+
+    # Start with partial args if provided
+    all_params: Dict[str, Any] = dict(partial_args) if partial_args else {}
+
+    # Add any additional kwargs
+    all_params.update(kwargs)
 
-    toolkit: UCFunctionToolkit = UCFunctionToolkit(
-        function_names=[function], client=client
+    logger.debug(
+        f"Calling UC function {function_name} with parameters: {list(all_params.keys())}"
     )
 
-    tools = toolkit.tools or []
+    result = client.execute_function(function_name=function_name, parameters=all_params)
+
+    # Handle errors and extract result
+    if hasattr(result, "error") and result.error:
+        logger.error(f"Unity Catalog function error: {result.error}")
+        raise RuntimeError(f"Function execution failed: {result.error}")
+
+    result_value: str = result.value if hasattr(result, "value") else str(result)
+    logger.debug(f"UC function result: {result_value}")
+    return result_value
+
+
+def _grant_function_permissions(
+    function_name: str,
+    client_id: str,
+    host: Optional[str] = None,
+) -> None:
+    """
+    Grant comprehensive permissions to the service principal for Unity Catalog function execution.
+
+    This includes:
+    - EXECUTE permission on the function itself
+    - USE permission on the containing schema
+    - USE permission on the containing catalog
+    """
+    try:
+        # Initialize workspace client
+        workspace_client = WorkspaceClient(host=host) if host else WorkspaceClient()
+
+        # Parse the function name to get catalog and schema
+        parts = function_name.split(".")
+        if len(parts) != 3:
+            logger.warning(
+                f"Invalid function name format: {function_name}. Expected catalog.schema.function"
+            )
+            return
+
+        catalog_name, schema_name, func_name = parts
+        schema_full_name = f"{catalog_name}.{schema_name}"
+
+        logger.debug(
+            f"Granting comprehensive permissions on function {function_name} to principal {client_id}"
+        )
+
+        # 1. Grant EXECUTE permission on the function
+        try:
+            workspace_client.grants.update(
+                securable_type="function",
+                full_name=function_name,
+                changes=[
+                    PermissionsChange(principal=client_id, add=[Privilege.EXECUTE])
+                ],
+            )
+            logger.debug(f"Granted EXECUTE on function {function_name}")
+        except Exception as e:
+            logger.warning(f"Failed to grant EXECUTE on function {function_name}: {e}")
+
+        # 2. Grant USE_SCHEMA permission on the schema
+        try:
+            workspace_client.grants.update(
+                securable_type="schema",
+                full_name=schema_full_name,
+                changes=[
+                    PermissionsChange(
+                        principal=client_id,
+                        add=[Privilege.USE_SCHEMA],
+                    )
+                ],
+            )
+            logger.debug(f"Granted USE_SCHEMA on schema {schema_full_name}")
+        except Exception as e:
+            logger.warning(
+                f"Failed to grant USE_SCHEMA on schema {schema_full_name}: {e}"
+            )
+
+        # 3. Grant USE_CATALOG and BROWSE permissions on the catalog
+        try:
+            workspace_client.grants.update(
+                securable_type="catalog",
+                full_name=catalog_name,
+                changes=[
+                    PermissionsChange(
+                        principal=client_id,
+                        add=[Privilege.USE_CATALOG, Privilege.BROWSE],
+                    )
+                ],
+            )
+            logger.debug(f"Granted USE_CATALOG and BROWSE on catalog {catalog_name}")
+        except Exception as e:
+            logger.warning(
+                f"Failed to grant catalog permissions on {catalog_name}: {e}"
+            )
+
+        logger.debug(
+            f"Successfully granted comprehensive permissions on {function_name} to {client_id}"
+        )
+
+    except Exception as e:
+        logger.warning(
+            f"Failed to grant permissions on function {function_name} to {client_id}: {e}"
+        )
+        # Don't fail the tool creation if permission granting fails
+        pass
+
+
+def _create_filtered_schema(original_schema: type, exclude_fields: set[str]) -> type:
+    """
+    Create a new Pydantic model that excludes specified fields from the original schema.
+
+    Args:
+        original_schema: The original Pydantic model class
+        exclude_fields: Set of field names to exclude from the schema
+
+    Returns:
+        A new Pydantic model class with the specified fields removed
+    """
+    from pydantic import BaseModel, Field, create_model
+    from pydantic.fields import PydanticUndefined
+
+    try:
+        # Get the original model's fields (Pydantic v2)
+        original_fields = original_schema.model_fields
+        filtered_field_definitions = {}
+
+        for name, field in original_fields.items():
+            if name not in exclude_fields:
+                # Reconstruct the field definition for create_model
+                field_type = field.annotation
+                field_default = (
+                    field.default if field.default is not PydanticUndefined else ...
+                )
+                field_info = Field(default=field_default, description=field.description)
+                filtered_field_definitions[name] = (field_type, field_info)
+
+        # If no fields remain after filtering, return a generic empty schema
+        if not filtered_field_definitions:
+
+            class EmptySchema(BaseModel):
+                """Unity Catalog function with all parameters provided via partial args."""
 
-    logger.debug(f"Retrieved tools: {tools}")
+                pass
 
-    tools = [as_human_in_the_loop(tool=tool, function=function) for tool in tools]
+            return EmptySchema
+
+        # Create the new model dynamically
+        model_name = f"Filtered{original_schema.__name__}"
+        docstring = getattr(
+            original_schema, "__doc__", "Filtered Unity Catalog function parameters."
+        )
+
+        filtered_model = create_model(
+            model_name, __doc__=docstring, **filtered_field_definitions
+        )
+        return filtered_model
+
+    except Exception as e:
+        logger.warning(f"Failed to create filtered schema: {e}")
+
+        # Fallback to generic schema
+        class GenericFilteredSchema(BaseModel):
+            """Generic filtered schema for Unity Catalog function."""
+
+            pass
+
+        return GenericFilteredSchema
+
+
+def with_partial_args(
+    tool: Union[ToolModel, Dict[str, Any]],
+    partial_args: dict[str, AnyVariable] = {},
+) -> StructuredTool:
+    """
+    Create a Unity Catalog tool with partial arguments pre-filled.
+
+    This function creates a wrapper tool that calls the UC function with partial arguments
+    already resolved, so the caller only needs to provide the remaining parameters.
+
+    Args:
+        tool: ToolModel containing the Unity Catalog function configuration
+        partial_args: Dictionary of arguments to pre-fill in the tool
+
+    Returns:
+        StructuredTool: A LangChain tool with partial arguments pre-filled
+    """
+    from unitycatalog.ai.langchain.toolkit import generate_function_input_params_schema
+
+    logger.debug(f"with_partial_args: {tool}")
+
+    # Convert dict-based variables to CompositeVariableModel and resolve their values
+    resolved_args = {}
+    for k, v in partial_args.items():
+        if isinstance(v, dict):
+            resolved_args[k] = value_of(CompositeVariableModel(**v))
+        else:
+            resolved_args[k] = value_of(v)
+
+    logger.debug(f"Resolved partial args: {resolved_args.keys()}")
+
+    if isinstance(tool, dict):
+        tool = ToolModel(**tool)
+
+    unity_catalog_function = tool.function
+    if isinstance(unity_catalog_function, dict):
+        unity_catalog_function = UnityCatalogFunctionModel(**unity_catalog_function)
+
+    function_name: str = unity_catalog_function.full_name
+    logger.debug(f"Creating UC tool with partial args for: {function_name}")
+
+    # Grant permissions if we have credentials
+    if "client_id" in resolved_args:
+        client_id: str = resolved_args["client_id"]
+        host: Optional[str] = resolved_args.get("host")
+        try:
+            _grant_function_permissions(function_name, client_id, host)
+        except Exception as e:
+            logger.warning(f"Failed to grant permissions: {e}")
+
+    # Create the client for function execution
+    client: DatabricksFunctionClient = DatabricksFunctionClient()
+
+    # Try to get the function schema for better tool definition
+    try:
+        function_info = client.get_function(function_name)
+        schema_info = generate_function_input_params_schema(function_info)
+        tool_description = (
+            function_info.comment or f"Unity Catalog function: {function_name}"
+        )
+
+        logger.debug(
+            f"Generated schema for function {function_name}: {schema_info.pydantic_model}"
+        )
+        logger.debug(f"Tool description: {tool_description}")
+
+        # Create a modified schema that excludes partial args
+        original_schema = schema_info.pydantic_model
+        schema_model = _create_filtered_schema(original_schema, resolved_args.keys())
+        logger.debug(
+            f"Filtered schema excludes partial args: {list(resolved_args.keys())}"
+        )
+
+    except Exception as e:
+        logger.warning(f"Could not introspect function {function_name}: {e}")
+        # Fallback to a generic schema
+        from pydantic import BaseModel
+
+        class GenericUCParams(BaseModel):
+            """Generic parameters for Unity Catalog function."""
+
+            pass
+
+        schema_model = GenericUCParams
+        tool_description = f"Unity Catalog function: {function_name}"
+
+    # Create a wrapper function that calls _execute_uc_function with partial args
+    def uc_function_wrapper(**kwargs) -> str:
+        """Wrapper function that executes Unity Catalog function with partial args."""
+        return _execute_uc_function(
+            client=client,
+            function_name=function_name,
+            partial_args=resolved_args,
+            **kwargs,
+        )
+
+    # Set the function name for the decorator
+    uc_function_wrapper.__name__ = tool.name or function_name.replace(".", "_")
+
+    # Create the tool using LangChain's StructuredTool
+    from langchain_core.tools import StructuredTool
+
+    partial_tool = StructuredTool.from_function(
+        func=uc_function_wrapper,
+        name=tool.name or function_name.replace(".", "_"),
+        description=tool_description,
+        args_schema=schema_model,
+    )
 
-    return tools
+    return partial_tool
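The most intricate part of the new code is `_create_filtered_schema`, which hides pre-bound parameters from the LLM-facing tool schema. Stripped of its description handling and fallback paths, the core pydantic v2 rebuild looks like this (a simplified sketch, not the shipped function):

```python
from pydantic import BaseModel, create_model
from pydantic.fields import PydanticUndefined


class OrderParams(BaseModel):
    client_id: str
    host: str
    order_id: str


def filter_schema(schema: type[BaseModel], exclude: set[str]) -> type[BaseModel]:
    fields = {}
    for name, field in schema.model_fields.items():
        if name in exclude:
            continue
        # Required fields keep `...` as their default sentinel.
        default = field.default if field.default is not PydanticUndefined else ...
        fields[name] = (field.annotation, default)
    return create_model(f"Filtered{schema.__name__}", **fields)


Remaining = filter_schema(OrderParams, {"client_id", "host"})
print(list(Remaining.model_fields))  # ['order_id']
```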
dao_ai-0.0.18.dist-info/METADATA → dao_ai-0.0.20.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: dao-ai
-Version: 0.0.18
+Version: 0.0.20
 Summary: DAO AI: A modular, multi-agent orchestration framework for complex AI workflows. Supports agent handoff, tool integration, and dynamic configuration via YAML.
 Project-URL: Homepage, https://github.com/natefleming/dao-ai
 Project-URL: Documentation, https://natefleming.github.io/dao-ai
@@ -24,22 +24,22 @@ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
 Classifier: Topic :: Software Development :: Libraries :: Python Modules
 Classifier: Topic :: System :: Distributed Computing
 Requires-Python: >=3.12
-Requires-Dist: databricks-agents>=1.2.0
-Requires-Dist: databricks-langchain>=0.4.2
-Requires-Dist: databricks-sdk[openai]>=0.55.0
+Requires-Dist: databricks-agents>=1.6.0
+Requires-Dist: databricks-langchain>=0.8.0
+Requires-Dist: databricks-sdk[openai]>=0.66.0
 Requires-Dist: duckduckgo-search>=8.0.2
 Requires-Dist: grandalf>=0.8
-Requires-Dist: langchain-mcp-adapters>=0.1.9
+Requires-Dist: langchain-mcp-adapters>=0.1.10
 Requires-Dist: langchain-tavily>=0.2.11
 Requires-Dist: langchain>=0.3.27
 Requires-Dist: langgraph-checkpoint-postgres>=2.0.23
 Requires-Dist: langgraph-supervisor>=0.0.29
 Requires-Dist: langgraph-swarm>=0.0.14
-Requires-Dist: langgraph>=0.6.5
+Requires-Dist: langgraph>=0.6.7
 Requires-Dist: langmem>=0.0.29
 Requires-Dist: loguru>=0.7.3
-Requires-Dist: mcp>=1.9.1
-Requires-Dist: mlflow>=3.3.2
+Requires-Dist: mcp>=1.14.1
+Requires-Dist: mlflow>=3.4.0
 Requires-Dist: nest-asyncio>=1.6.0
 Requires-Dist: openevals>=0.0.19
 Requires-Dist: openpyxl>=3.1.5
dao_ai-0.0.18.dist-info/RECORD → dao_ai-0.0.20.dist-info/RECORD CHANGED
@@ -3,11 +3,11 @@ dao_ai/agent_as_code.py,sha256=kPSeDz2-1jRaed1TMs4LA3VECoyqe9_Ed2beRLB9gXQ,472
 dao_ai/catalog.py,sha256=sPZpHTD3lPx4EZUtIWeQV7VQM89WJ6YH__wluk1v2lE,4947
 dao_ai/chat_models.py,sha256=uhwwOTeLyHWqoTTgHrs4n5iSyTwe4EQcLKnh3jRxPWI,8626
 dao_ai/cli.py,sha256=Aez2TQW3Q8Ho1IaIkRggt0NevDxAAVPjXkePC5GPJF0,20429
-dao_ai/config.py,sha256=JlYC8N_7UL8VVkdSepiCUnR9NA5OsCVAigLjse7dMFM,49922
-dao_ai/graph.py,sha256=kXaGLGFVekDWqm-AHzti6LmrXnyi99VQ-AdCGuNb_xM,7831
-dao_ai/guardrails.py,sha256=-Qh0f_2Db9t4Nbrrx9FM7tnpqShjMoyxepZ0HByItfU,4027
+dao_ai/config.py,sha256=ZO5ei45gnhqg1BtD0R9aekJz4ClmiTw2GHhOk4Idil4,51958
+dao_ai/graph.py,sha256=gmD9mxODfXuvn9xWeBfewm1FiuVAWMLEdnZz7DNmSH0,7859
+dao_ai/guardrails.py,sha256=4TKArDONRy8RwHzOT1plZ1rhy3x9GF_aeGpPCRl6wYA,4016
 dao_ai/messages.py,sha256=xl_3-WcFqZKCFCiov8sZOPljTdM3gX3fCHhxq-xFg2U,7005
-dao_ai/models.py,sha256=h_xFMK5FHQwPApEAYhvrt69y7ZUljmqThHTjp-yde_o,25368
+dao_ai/models.py,sha256=Xb23U-lhDG8KyNRIijcJ4InluadlaGNy4rrYx7Cjgfg,26939
 dao_ai/nodes.py,sha256=SSuFNTXOdFaKg_aX-yUkQO7fM9wvNGu14lPXKDapU1U,8461
 dao_ai/prompts.py,sha256=vpmIbWs_szXUgNNDs5Gh2LcxKZti5pHDKSfoClUcgX0,1289
 dao_ai/state.py,sha256=GwbMbd1TWZx1T5iQrEOX6_rpxOitlmyeJ8dMr2o_pag,1031
@@ -18,23 +18,23 @@ dao_ai/hooks/__init__.py,sha256=LlHGIuiZt6vGW8K5AQo1XJEkBP5vDVtMhq0IdjcLrD4,417
 dao_ai/hooks/core.py,sha256=ZShHctUSoauhBgdf1cecy9-D7J6-sGn-pKjuRMumW5U,6663
 dao_ai/memory/__init__.py,sha256=1kHx_p9abKYFQ6EYD05nuc1GS5HXVEpufmjBGw_7Uho,260
 dao_ai/memory/base.py,sha256=99nfr2UZJ4jmfTL_KrqUlRSCoRxzkZyWyx5WqeUoMdQ,338
-dao_ai/memory/core.py,sha256=K45iCEFbqJCVxMi4m3vmBJi4c6TQ-UtKGzyugDTkPP0,4141
-dao_ai/memory/postgres.py,sha256=YILzA7xtqawPAOLFaGG_i17zW7cQxXTzTD8yd-ipe8k,12480
+dao_ai/memory/core.py,sha256=g7chjBgVgx3iKjR2hghl0QL1j3802uIM_e7mgszur9M,4151
+dao_ai/memory/postgres.py,sha256=pxxMjGotgqjrKhx0lVR3EAjSZTQgBpiPZOB0-cyjprc,12505
 dao_ai/providers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 dao_ai/providers/base.py,sha256=-fjKypCOk28h6vioPfMj9YZSw_3Kcbi2nMuAyY7vX9k,1383
-dao_ai/providers/databricks.py,sha256=KLYrLccOA3Uws9nWJcJUZTbMz-MdR_onhlQeztbplCM,28073
+dao_ai/providers/databricks.py,sha256=fZ8mGotfA3W3t5yUej2xGmGHSybjBFYr895mOctT418,28203
 dao_ai/tools/__init__.py,sha256=ye6MHaJY7tUnJ8336YJiLxuZr55zDPNdOw6gm7j5jlc,1103
 dao_ai/tools/agent.py,sha256=WbQnyziiT12TLMrA7xK0VuOU029tdmUBXbUl-R1VZ0Q,1886
 dao_ai/tools/core.py,sha256=Kei33S8vrmvPOAyrFNekaWmV2jqZ-IPS1QDSvU7RZF0,1984
 dao_ai/tools/genie.py,sha256=GzV5lfDYKmzW_lSLxAsPaTwnzX6GxQOB1UcLaTDqpfY,2787
-dao_ai/tools/human_in_the_loop.py,sha256=IBmQJmpxkdDxnBNyABc_-dZhhsQlTNTkPyUXgkHKIgY,3466
+dao_ai/tools/human_in_the_loop.py,sha256=yk35MO9eNETnYFH-sqlgR-G24TrEgXpJlnZUustsLkI,3681
 dao_ai/tools/mcp.py,sha256=auEt_dwv4J26fr5AgLmwmnAsI894-cyuvkvjItzAUxs,4419
 dao_ai/tools/python.py,sha256=XcQiTMshZyLUTVR5peB3vqsoUoAAy8gol9_pcrhddfI,1831
 dao_ai/tools/time.py,sha256=Y-23qdnNHzwjvnfkWvYsE7PoWS1hfeKy44tA7sCnNac,8759
-dao_ai/tools/unity_catalog.py,sha256=PXfLj2EgyQgaXq4Qq3t25AmTC4KyVCF_-sCtg6enens,1404
+dao_ai/tools/unity_catalog.py,sha256=uX_h52BuBAr4c9UeqSMI7DNz3BPRLeai5tBVW4sJqRI,13113
 dao_ai/tools/vector_search.py,sha256=EDYQs51zIPaAP0ma1D81wJT77GQ-v-cjb2XrFVWfWdg,2621
-dao_ai-0.0.18.dist-info/METADATA,sha256=9lTAXjEqQHxl6dmRMyiqUnYT1Nh_wJpSeJXRG8bGZGg,41378
-dao_ai-0.0.18.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-dao_ai-0.0.18.dist-info/entry_points.txt,sha256=Xa-UFyc6gWGwMqMJOt06ZOog2vAfygV_DSwg1AiP46g,43
-dao_ai-0.0.18.dist-info/licenses/LICENSE,sha256=YZt3W32LtPYruuvHE9lGk2bw6ZPMMJD8yLrjgHybyz4,1069
-dao_ai-0.0.18.dist-info/RECORD,,
+dao_ai-0.0.20.dist-info/METADATA,sha256=gWNRLhswz5sCe1vxbBQ6dGlgiObI9nI829Q5DQRqRRY,41380
+dao_ai-0.0.20.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+dao_ai-0.0.20.dist-info/entry_points.txt,sha256=Xa-UFyc6gWGwMqMJOt06ZOog2vAfygV_DSwg1AiP46g,43
+dao_ai-0.0.20.dist-info/licenses/LICENSE,sha256=YZt3W32LtPYruuvHE9lGk2bw6ZPMMJD8yLrjgHybyz4,1069
+dao_ai-0.0.20.dist-info/RECORD,,