dao-ai 0.0.19__py3-none-any.whl → 0.0.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
dao_ai/config.py CHANGED
@@ -427,7 +427,8 @@ class GenieRoomModel(BaseModel, IsDatabricksResource):
     def as_resources(self) -> Sequence[DatabricksResource]:
         return [
             DatabricksGenieSpace(
-                genie_space_id=self.space_id, on_behalf_of_user=self.on_behalf_of_user
+                genie_space_id=value_of(self.space_id),
+                on_behalf_of_user=self.on_behalf_of_user,
             )
         ]
@@ -437,7 +438,7 @@ class GenieRoomModel(BaseModel, IsDatabricksResource):
         return self


-class VolumeModel(BaseModel, HasFullName):
+class VolumeModel(BaseModel, HasFullName, IsDatabricksResource):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     schema_model: Optional[SchemaModel] = Field(default=None, alias="schema")
     name: str
@@ -455,6 +456,13 @@ class VolumeModel(BaseModel, HasFullName):
         provider: ServiceProvider = DatabricksProvider(w=w)
         provider.create_volume(self)

+    @property
+    def api_scopes(self) -> Sequence[str]:
+        return ["files.files", "catalog.volumes"]
+
+    def as_resources(self) -> Sequence[DatabricksResource]:
+        return []
+

 class VolumePathModel(BaseModel, HasFullName):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
@@ -683,7 +691,8 @@ class WarehouseModel(BaseModel, IsDatabricksResource):
     def as_resources(self) -> Sequence[DatabricksResource]:
         return [
             DatabricksSQLWarehouse(
-                warehouse_id=self.warehouse_id, on_behalf_of_user=self.on_behalf_of_user
+                warehouse_id=value_of(self.warehouse_id),
+                on_behalf_of_user=self.on_behalf_of_user,
             )
         ]
@@ -879,6 +888,12 @@ class FactoryFunctionModel(BaseFunctionModel, HasFullName):

         return [create_factory_tool(self, **kwargs)]

+    @model_validator(mode="after")
+    def update_args(self):
+        for key, value in self.args.items():
+            self.args[key] = value_of(value)
+        return self
+

 class TransportType(str, Enum):
     STREAMABLE_HTTP = "streamable_http"
@@ -963,6 +978,7 @@ class McpFunctionModel(BaseFunctionModel, HasFullName):
 class UnityCatalogFunctionModel(BaseFunctionModel, HasFullName):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     schema_model: Optional[SchemaModel] = Field(default=None, alias="schema")
+    partial_args: Optional[dict[str, AnyVariable]] = Field(default_factory=dict)
     type: Literal[FunctionType.UNITY_CATALOG] = FunctionType.UNITY_CATALOG

     @property
@@ -1261,12 +1277,12 @@ class AppModel(BaseModel):
         if len(self.agents) > 1:
             default_agent: AgentModel = self.agents[0]
             self.orchestration = OrchestrationModel(
-                swarm=SupervisorModel(model=default_agent.model)
+                supervisor=SupervisorModel(model=default_agent.model)
             )
         elif len(self.agents) == 1:
             default_agent: AgentModel = self.agents[0]
             self.orchestration = OrchestrationModel(
-                supervisor=SwarmModel(
+                swarm=SwarmModel(
                     model=default_agent.model, default_agent=default_agent
                 )
            )
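
Most of the changes in this file wrap raw configuration fields in `value_of(...)` before they reach the MLflow resource declarations, so `space_id` and `warehouse_id` can be variable references (for example, environment lookups) rather than plain strings. `value_of` itself is not shown in this diff; the sketch below illustrates the assumed contract, with `EnvVariable` as a hypothetical stand-in for dao_ai's actual variable models:

```python
# Minimal sketch of the assumed value_of() contract; EnvVariable is a
# hypothetical stand-in, not a class from dao_ai.
import os
from typing import Optional, Union


class EnvVariable:
    """Illustrative variable model that resolves from the environment."""

    def __init__(self, env_key: str, default: Optional[str] = None):
        self.env_key = env_key
        self.default = default


def value_of(value: Union[str, EnvVariable, None]) -> Optional[str]:
    # Literals pass through unchanged; variable models resolve lazily,
    # at the moment a resource (Genie space, warehouse, ...) is declared.
    if isinstance(value, EnvVariable):
        return os.environ.get(value.env_key, value.default)
    return value


os.environ["GENIE_SPACE_ID"] = "01ef0000000000000000000000000000"
assert value_of("literal-id") == "literal-id"
assert value_of(EnvVariable("GENIE_SPACE_ID")) == "01ef0000000000000000000000000000"
```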
dao_ai/guardrails.py CHANGED
@@ -87,12 +87,12 @@ def judge_node(guardrails: GuardrailModel) -> RunnableLike:
         )

         if eval_result["score"]:
-            logger.debug("Response approved by judge")
+            logger.debug("Response approved by judge")
             logger.debug(f"Judge's comment: {eval_result['comment']}")
             return
         else:
             # Otherwise, return the judge's critique as a new user message
-            logger.warning("⚠️ Judge requested improvements")
+            logger.warning("Judge requested improvements")
             comment: str = eval_result["comment"]
             logger.warning(f"Judge's critique: {comment}")
             content: str = "\n".join([human_message.content, comment])
dao_ai/memory/postgres.py CHANGED
@@ -20,137 +20,6 @@ from dao_ai.memory.base import (
 )


-class PatchedAsyncPostgresStore(AsyncPostgresStore):
-    """
-    Patched version of AsyncPostgresStore that properly handles event loop initialization
-    and task lifecycle management.
-
-    The issues occur because:
-    1. AsyncBatchedBaseStore.__init__ calls asyncio.get_running_loop() and fails if no event loop is running
-    2. The background _task can complete/fail, causing assertions in asearch/other methods to fail
-    3. Destructor tries to access _task even when it doesn't exist
-
-    This patch ensures proper initialization and handles task lifecycle robustly.
-    """
-
-    def __init__(self, *args, **kwargs):
-        # Ensure we have a running event loop before calling super().__init__()
-        loop = None
-        try:
-            loop = asyncio.get_running_loop()
-        except RuntimeError:
-            # No running loop - create one temporarily for initialization
-            loop = asyncio.new_event_loop()
-            asyncio.set_event_loop(loop)
-
-        try:
-            super().__init__(*args, **kwargs)
-        except Exception as e:
-            # If parent initialization fails, ensure _task is at least defined
-            if not hasattr(self, "_task"):
-                self._task = None
-            logger.warning(f"AsyncPostgresStore initialization failed: {e}")
-            raise
-
-    def _ensure_task_running(self):
-        """
-        Ensure the background task is running. Recreate it if necessary.
-        """
-        if not hasattr(self, "_task") or self._task is None:
-            logger.error("AsyncPostgresStore task not initialized")
-            raise RuntimeError("Store task not properly initialized")
-
-        if self._task.done():
-            logger.warning(
-                "AsyncPostgresStore background task completed, attempting to restart"
-            )
-            # Try to get the task exception for debugging
-            try:
-                exception = self._task.exception()
-                if exception:
-                    logger.error(f"Background task failed with: {exception}")
-                else:
-                    logger.info("Background task completed normally")
-            except Exception as e:
-                logger.warning(f"Could not determine task completion reason: {e}")
-
-            # Try to restart the task
-            try:
-                import weakref
-
-                from langgraph.store.base.batch import _run
-
-                self._task = self._loop.create_task(
-                    _run(self._aqueue, weakref.ref(self))
-                )
-                logger.info("Successfully restarted AsyncPostgresStore background task")
-            except Exception as e:
-                logger.error(f"Failed to restart background task: {e}")
-                raise RuntimeError(
-                    f"Store background task failed and could not be restarted: {e}"
-                )
-
-    async def asearch(
-        self,
-        namespace_prefix,
-        /,
-        *,
-        query=None,
-        filter=None,
-        limit=10,
-        offset=0,
-        refresh_ttl=None,
-    ):
-        """
-        Override asearch to handle task lifecycle issues gracefully.
-        """
-        self._ensure_task_running()
-
-        # Call parent implementation if task is healthy
-        return await super().asearch(
-            namespace_prefix,
-            query=query,
-            filter=filter,
-            limit=limit,
-            offset=offset,
-            refresh_ttl=refresh_ttl,
-        )
-
-    async def aget(self, namespace, key, /, *, refresh_ttl=None):
-        """Override aget with task lifecycle management."""
-        self._ensure_task_running()
-        return await super().aget(namespace, key, refresh_ttl=refresh_ttl)
-
-    async def aput(self, namespace, key, value, /, *, refresh_ttl=None):
-        """Override aput with task lifecycle management."""
-        self._ensure_task_running()
-        return await super().aput(namespace, key, value, refresh_ttl=refresh_ttl)
-
-    async def adelete(self, namespace, key):
-        """Override adelete with task lifecycle management."""
-        self._ensure_task_running()
-        return await super().adelete(namespace, key)
-
-    async def alist_namespaces(self, *, prefix=None):
-        """Override alist_namespaces with task lifecycle management."""
-        self._ensure_task_running()
-        return await super().alist_namespaces(prefix=prefix)
-
-    def __del__(self):
-        """
-        Override destructor to handle missing _task attribute gracefully.
-        """
-        try:
-            # Only try to cancel if _task exists and is not None
-            if hasattr(self, "_task") and self._task is not None:
-                if not self._task.done():
-                    self._task.cancel()
-        except Exception as e:
-            # Log but don't raise - destructors should not raise exceptions
-            logger.debug(f"AsyncPostgresStore destructor cleanup: {e}")
-            pass
-
-
 class AsyncPostgresPoolManager:
     _pools: dict[str, AsyncConnectionPool] = {}
     _lock: asyncio.Lock = asyncio.Lock()
@@ -205,8 +74,17 @@ class AsyncPostgresPoolManager:
         async with cls._lock:
             for connection_key, pool in cls._pools.items():
                 try:
-                    await pool.close()
+                    # Use a short timeout to avoid blocking on pool closure
+                    await asyncio.wait_for(pool.close(), timeout=2.0)
                     logger.debug(f"Closed PostgreSQL pool: {connection_key}")
+                except asyncio.TimeoutError:
+                    logger.warning(
+                        f"Timeout closing pool {connection_key}, forcing closure"
+                    )
+                except asyncio.CancelledError:
+                    logger.warning(
+                        f"Pool closure cancelled for {connection_key} (shutdown in progress)"
+                    )
                 except Exception as e:
                     logger.error(f"Error closing pool {connection_key}: {e}")
             cls._pools.clear()
@@ -251,7 +129,7 @@ class AsyncPostgresStoreManager(StoreManagerBase):
             )

         # Create store with the shared pool (using patched version)
-        self._store = PatchedAsyncPostgresStore(conn=self.pool)
+        self._store = AsyncPostgresStore(conn=self.pool)

         await self._store.setup()

@@ -500,8 +378,27 @@ def _shutdown_pools():

 def _shutdown_async_pools():
     try:
-        asyncio.run(AsyncPostgresPoolManager.close_all_pools())
-        logger.debug("Successfully closed all asynchronous PostgreSQL pools")
+        # Try to get the current event loop first
+        try:
+            loop = asyncio.get_running_loop()
+            # If we're already in an event loop, create a task
+            loop.create_task(AsyncPostgresPoolManager.close_all_pools())
+            logger.debug("Scheduled async pool closure in running event loop")
+        except RuntimeError:
+            # No running loop, try to get or create one
+            try:
+                loop = asyncio.get_event_loop()
+                if loop.is_closed():
+                    # Loop is closed, create a new one
+                    loop = asyncio.new_event_loop()
+                    asyncio.set_event_loop(loop)
+                loop.run_until_complete(AsyncPostgresPoolManager.close_all_pools())
+                logger.debug("Successfully closed all asynchronous PostgreSQL pools")
+            except Exception as inner_e:
+                # If all else fails, just log the error
+                logger.warning(
+                    f"Could not close async pools cleanly during shutdown: {inner_e}"
+                )
     except Exception as e:
         logger.error(
             f"Error closing asynchronous PostgreSQL pools during shutdown: {e}"
dao_ai/models.py CHANGED
@@ -5,6 +5,7 @@ from typing import Any, Generator, Optional, Sequence, Union

 from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
 from langgraph.graph.state import CompiledStateGraph
+from langgraph.types import StateSnapshot
 from loguru import logger
 from mlflow import MlflowClient
 from mlflow.pyfunc import ChatAgent, ChatModel, ResponsesAgent
@@ -59,6 +60,113 @@ def get_latest_model_version(model_name: str) -> int:
     return latest_version


+async def get_state_snapshot_async(
+    graph: CompiledStateGraph, thread_id: str
+) -> Optional[StateSnapshot]:
+    """
+    Retrieve the state snapshot from the graph for a given thread_id asynchronously.
+
+    This utility function accesses the graph's checkpointer to retrieve the current
+    state snapshot, which contains the full state values and metadata.
+
+    Args:
+        graph: The compiled LangGraph state machine
+        thread_id: The thread/conversation ID to retrieve state for
+
+    Returns:
+        StateSnapshot if found, None otherwise
+    """
+    logger.debug(f"Retrieving state snapshot for thread_id: {thread_id}")
+    try:
+        # Check if graph has a checkpointer
+        if graph.checkpointer is None:
+            logger.debug("No checkpointer available in graph")
+            return None
+
+        # Get the current state from the checkpointer (use async version)
+        config: dict[str, Any] = {"configurable": {"thread_id": thread_id}}
+        state_snapshot: Optional[StateSnapshot] = await graph.aget_state(config)
+
+        if state_snapshot is None:
+            logger.debug(f"No state found for thread_id: {thread_id}")
+            return None
+
+        return state_snapshot
+
+    except Exception as e:
+        logger.warning(f"Error retrieving state snapshot for thread {thread_id}: {e}")
+        return None
+
+
+def get_state_snapshot(
+    graph: CompiledStateGraph, thread_id: str
+) -> Optional[StateSnapshot]:
+    """
+    Retrieve the state snapshot from the graph for a given thread_id.
+
+    This is a synchronous wrapper around get_state_snapshot_async.
+    Use this for backward compatibility in synchronous contexts.
+
+    Args:
+        graph: The compiled LangGraph state machine
+        thread_id: The thread/conversation ID to retrieve state for
+
+    Returns:
+        StateSnapshot if found, None otherwise
+    """
+    import asyncio
+
+    try:
+        loop = asyncio.get_event_loop()
+    except RuntimeError:
+        loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(loop)
+
+    try:
+        return loop.run_until_complete(get_state_snapshot_async(graph, thread_id))
+    except Exception as e:
+        logger.warning(f"Error in synchronous state snapshot retrieval: {e}")
+        return None
+
+
+def get_genie_conversation_ids_from_state(
+    state_snapshot: Optional[StateSnapshot],
+) -> dict[str, str]:
+    """
+    Extract genie_conversation_ids from a state snapshot.
+
+    This function extracts the genie_conversation_ids dictionary from the state
+    snapshot values if present.
+
+    Args:
+        state_snapshot: The state snapshot to extract conversation IDs from
+
+    Returns:
+        A dictionary mapping genie space_id to conversation_id, or empty dict if not found
+    """
+    if state_snapshot is None:
+        return {}
+
+    try:
+        # Extract state values - these contain the actual state data
+        state_values: dict[str, Any] = state_snapshot.values
+
+        # Extract genie_conversation_ids from state values
+        genie_conversation_ids: dict[str, str] = state_values.get(
+            "genie_conversation_ids", {}
+        )
+
+        if genie_conversation_ids:
+            logger.debug(f"Retrieved genie_conversation_ids: {genie_conversation_ids}")
+            return genie_conversation_ids
+
+        return {}
+
+    except Exception as e:
+        logger.warning(f"Error extracting genie_conversation_ids from state: {e}")
+        return {}
+
+
 class LanggraphChatModel(ChatModel):
     """
     ChatModel that delegates requests to a LangGraph CompiledStateGraph.
@@ -257,7 +365,19 @@ class LanggraphResponsesAgent(ResponsesAgent):
             text=last_message.content, id=f"msg_{uuid.uuid4().hex[:8]}"
         )

-        custom_outputs = custom_inputs
+        # Retrieve genie_conversation_ids from state if available
+        custom_outputs: dict[str, Any] = custom_inputs.copy()
+        thread_id: Optional[str] = context.thread_id
+        if thread_id:
+            state_snapshot: Optional[StateSnapshot] = loop.run_until_complete(
+                get_state_snapshot_async(self.graph, thread_id)
+            )
+            genie_conversation_ids: dict[str, str] = (
+                get_genie_conversation_ids_from_state(state_snapshot)
+            )
+            if genie_conversation_ids:
+                custom_outputs["genie_conversation_ids"] = genie_conversation_ids
+
         return ResponsesAgentResponse(
             output=[output_item], custom_outputs=custom_outputs
         )
@@ -318,7 +438,22 @@ class LanggraphResponsesAgent(ResponsesAgent):
                         **self.create_text_delta(delta=content, item_id=item_id)
                     )

-            custom_outputs = custom_inputs
+            # Retrieve genie_conversation_ids from state if available
+            custom_outputs: dict[str, Any] = custom_inputs.copy()
+            thread_id: Optional[str] = context.thread_id
+
+            if thread_id:
+                state_snapshot: Optional[
+                    StateSnapshot
+                ] = await get_state_snapshot_async(self.graph, thread_id)
+                genie_conversation_ids: dict[str, str] = (
+                    get_genie_conversation_ids_from_state(state_snapshot)
+                )
+                if genie_conversation_ids:
+                    custom_outputs["genie_conversation_ids"] = (
+                        genie_conversation_ids
+                    )
+
             # Yield final output item
             yield ResponsesAgentStreamEvent(
                 type="response.output_item.done",
dao_ai/providers/databricks.py CHANGED
@@ -226,6 +226,7 @@ class DatabricksProvider(ServiceProvider):
             config.resources.connections.values()
         )
         databases: Sequence[DatabaseModel] = list(config.resources.databases.values())
+        volumes: Sequence[VolumeModel] = list(config.resources.volumes.values())

         resources: Sequence[IsDatabricksResource] = (
             llms
@@ -236,6 +237,7 @@ class DatabricksProvider(ServiceProvider):
             + tables
             + connections
             + databases
+            + volumes
         )

         # Flatten all resources from all models into a single list
dao_ai/state.py CHANGED
@@ -31,6 +31,9 @@ class SharedState(MessagesState):
     is_valid: bool  # message validation node
     message_error: str

+    # A mapping of genie space_id to conversation_id
+    genie_conversation_ids: dict[str, str]  # Genie
+

 class Context(BaseModel):
     user_id: str | None = None
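
`genie_conversation_ids` is written by tools rather than by a graph node: a LangGraph tool can return a `Command` whose `update` dict merges new keys into this state. A stripped-down illustration (the tool body is hypothetical; the real update happens in `create_genie_tool` below):

```python
# Illustrative: how a tool persists the new state key via Command(update=...).
from langchain_core.messages import ToolMessage
from langgraph.types import Command


def record_genie_conversation(
    space_id: str, conversation_id: str, tool_call_id: str
) -> Command:
    # LangGraph merges these keys into graph state when the tool returns.
    return Command(
        update={
            "genie_conversation_ids": {space_id: conversation_id},
            "messages": [ToolMessage("recorded", tool_call_id=tool_call_id)],
        }
    )
```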
dao_ai/tools/genie.py CHANGED
@@ -1,20 +1,284 @@
+import bisect
+import json
+import logging
 import os
+import time
+from dataclasses import asdict, dataclass
+from datetime import datetime
 from textwrap import dedent
-from typing import Any, Callable, Optional
+from typing import Annotated, Any, Callable, Optional, Union

-from databricks_ai_bridge.genie import GenieResponse
-from databricks_langchain.genie import Genie
-from langchain_core.tools import StructuredTool
+import mlflow
+import pandas as pd
+from databricks.sdk import WorkspaceClient
+from langchain_core.messages import ToolMessage
+from langchain_core.tools import InjectedToolCallId, tool
+from langgraph.prebuilt import InjectedState
+from langgraph.types import Command
+from loguru import logger
+from pydantic import BaseModel, Field

-from dao_ai.config import (
-    GenieRoomModel,
-)
+from dao_ai.config import AnyVariable, CompositeVariableModel, GenieRoomModel, value_of
+
+MAX_TOKENS_OF_DATA: int = 20000
+MAX_ITERATIONS: int = 50
+DEFAULT_POLLING_INTERVAL_SECS: int = 2
+
+
+def _count_tokens(text):
+    import tiktoken
+
+    encoding = tiktoken.encoding_for_model("gpt-4o")
+    return len(encoding.encode(text))
+
+
+@dataclass
+class GenieResponse:
+    conversation_id: str
+    result: Union[str, pd.DataFrame]
+    query: Optional[str] = ""
+    description: Optional[str] = ""
+
+    def to_json(self):
+        return json.dumps(asdict(self))
+
+
+class GenieToolInput(BaseModel):
+    """Input schema for the Genie tool."""
+
+    question: str = Field(
+        description="The question to ask Genie about your data. Ask simple, clear questions about your tabular data. For complex analysis, ask multiple simple questions rather than one complex question."
+    )
+
+
+def _truncate_result(dataframe: pd.DataFrame) -> str:
+    query_result = dataframe.to_markdown()
+    tokens_used = _count_tokens(query_result)
+
+    # If the full result fits, return it
+    if tokens_used <= MAX_TOKENS_OF_DATA:
+        return query_result.strip()
+
+    def is_too_big(n):
+        return _count_tokens(dataframe.iloc[:n].to_markdown()) > MAX_TOKENS_OF_DATA
+
+    # Use bisect_left to find the cutoff point of rows within the max token data limit in a O(log n) complexity
+    # Passing True, as this is the target value we are looking for when _is_too_big returns
+    cutoff = bisect.bisect_left(range(len(dataframe) + 1), True, key=is_too_big)
+
+    # Slice to the found limit
+    truncated_df = dataframe.iloc[:cutoff]
+
+    # Edge case: Cannot return any rows because of tokens so return an empty string
+    if len(truncated_df) == 0:
+        return ""
+
+    truncated_result = truncated_df.to_markdown()
+
+    # Double-check edge case if we overshot by one
+    if _count_tokens(truncated_result) > MAX_TOKENS_OF_DATA:
+        truncated_result = truncated_df.iloc[:-1].to_markdown()
+    return truncated_result
+
+
+@mlflow.trace(span_type="PARSER")
+def _parse_query_result(resp, truncate_results) -> Union[str, pd.DataFrame]:
+    output = resp["result"]
+    if not output:
+        return "EMPTY"
+
+    columns = resp["manifest"]["schema"]["columns"]
+    header = [str(col["name"]) for col in columns]
+    rows = []
+
+    for item in output["data_array"]:
+        row = []
+        for column, value in zip(columns, item):
+            type_name = column["type_name"]
+            if value is None:
+                row.append(None)
+                continue
+
+            if type_name in ["INT", "LONG", "SHORT", "BYTE"]:
+                row.append(int(value))
+            elif type_name in ["FLOAT", "DOUBLE", "DECIMAL"]:
+                row.append(float(value))
+            elif type_name == "BOOLEAN":
+                row.append(value.lower() == "true")
+            elif type_name == "DATE" or type_name == "TIMESTAMP":
+                row.append(datetime.strptime(value[:10], "%Y-%m-%d").date())
+            elif type_name == "BINARY":
+                row.append(bytes(value, "utf-8"))
+            else:
+                row.append(value)
+
+        rows.append(row)
+
+    dataframe = pd.DataFrame(rows, columns=header)
+
+    if truncate_results:
+        query_result = _truncate_result(dataframe)
+    else:
+        query_result = dataframe.to_markdown()
+
+    return query_result.strip()
+
+
+class Genie:
+    def __init__(
+        self,
+        space_id,
+        client: WorkspaceClient | None = None,
+        truncate_results: bool = False,
+        polling_interval: int = DEFAULT_POLLING_INTERVAL_SECS,
+    ):
+        self.space_id = space_id
+        workspace_client = client or WorkspaceClient()
+        self.genie = workspace_client.genie
+        self.description = self.genie.get_space(space_id).description
+        self.headers = {
+            "Accept": "application/json",
+            "Content-Type": "application/json",
+        }
+        self.truncate_results = truncate_results
+        if polling_interval < 1 or polling_interval > 30:
+            raise ValueError("poll_interval must be between 1 and 30 seconds")
+        self.poll_interval = polling_interval
+
+    @mlflow.trace()
+    def start_conversation(self, content):
+        resp = self.genie._api.do(
+            "POST",
+            f"/api/2.0/genie/spaces/{self.space_id}/start-conversation",
+            body={"content": content},
+            headers=self.headers,
+        )
+        return resp
+
+    @mlflow.trace()
+    def create_message(self, conversation_id, content):
+        resp = self.genie._api.do(
+            "POST",
+            f"/api/2.0/genie/spaces/{self.space_id}/conversations/{conversation_id}/messages",
+            body={"content": content},
+            headers=self.headers,
+        )
+        return resp
+
+    @mlflow.trace()
+    def poll_for_result(self, conversation_id, message_id):
+        @mlflow.trace()
+        def poll_query_results(attachment_id, query_str, description):
+            iteration_count = 0
+            while iteration_count < MAX_ITERATIONS:
+                iteration_count += 1
+                resp = self.genie._api.do(
+                    "GET",
+                    f"/api/2.0/genie/spaces/{self.space_id}/conversations/{conversation_id}/messages/{message_id}/attachments/{attachment_id}/query-result",
+                    headers=self.headers,
+                )["statement_response"]
+                state = resp["status"]["state"]
+                if state == "SUCCEEDED":
+                    result = _parse_query_result(resp, self.truncate_results)
+                    return GenieResponse(
+                        conversation_id, result, query_str, description
+                    )
+                elif state in ["RUNNING", "PENDING"]:
+                    logging.debug("Waiting for query result...")
+                    time.sleep(self.poll_interval)
+                else:
+                    return GenieResponse(
+                        conversation_id,
+                        f"No query result: {resp['state']}",
+                        query_str,
+                        description,
+                    )
+            return GenieResponse(
+                conversation_id,
+                f"Genie query for result timed out after {MAX_ITERATIONS} iterations of {self.poll_interval} seconds",
+                query_str,
+                description,
+            )
+
+        @mlflow.trace()
+        def poll_result():
+            iteration_count = 0
+            while iteration_count < MAX_ITERATIONS:
+                iteration_count += 1
+                resp = self.genie._api.do(
+                    "GET",
+                    f"/api/2.0/genie/spaces/{self.space_id}/conversations/{conversation_id}/messages/{message_id}",
+                    headers=self.headers,
+                )
+                if resp["status"] == "COMPLETED":
+                    # Check if attachments key exists in response
+                    attachments = resp.get("attachments", [])
+                    if not attachments:
+                        # Handle case where response has no attachments
+                        return GenieResponse(
+                            conversation_id,
+                            result=f"Genie query completed but no attachments found. Response: {resp}",
+                        )
+
+                    attachment = next((r for r in attachments if "query" in r), None)
+                    if attachment:
+                        query_obj = attachment["query"]
+                        description = query_obj.get("description", "")
+                        query_str = query_obj.get("query", "")
+                        attachment_id = attachment["attachment_id"]
+                        return poll_query_results(attachment_id, query_str, description)
+                    if resp["status"] == "COMPLETED":
+                        text_content = next(
+                            (r for r in attachments if "text" in r), None
+                        )
+                        if text_content:
+                            return GenieResponse(
+                                conversation_id, result=text_content["text"]["content"]
+                            )
+                        return GenieResponse(
+                            conversation_id,
+                            result="Genie query completed but no text content found in attachments.",
+                        )
+                elif resp["status"] in {"CANCELLED", "QUERY_RESULT_EXPIRED"}:
+                    return GenieResponse(
+                        conversation_id, result=f"Genie query {resp['status'].lower()}."
+                    )
+                elif resp["status"] == "FAILED":
+                    return GenieResponse(
+                        conversation_id,
+                        result=f"Genie query failed with error: {resp.get('error', 'Unknown error')}",
+                    )
+                # includes EXECUTING_QUERY, Genie can retry after this status
+                else:
+                    logging.debug(f"Waiting...: {resp['status']}")
+                    time.sleep(self.poll_interval)
+            return GenieResponse(
+                conversation_id,
+                f"Genie query timed out after {MAX_ITERATIONS} iterations of {self.poll_interval} seconds",
+            )
+
+        return poll_result()
+
+    @mlflow.trace()
+    def ask_question(self, question: str, conversation_id: str | None = None):
+        logger.debug(
+            f"ask_question called with question: {question}, conversation_id: {conversation_id}"
+        )
+        if conversation_id:
+            resp = self.create_message(conversation_id, question)
+        else:
+            resp = self.start_conversation(question)
+        logger.debug(f"ask_question response: {resp}")
+        return self.poll_for_result(resp["conversation_id"], resp["message_id"])


 def create_genie_tool(
     genie_room: GenieRoomModel | dict[str, Any],
     name: Optional[str] = None,
     description: Optional[str] = None,
+    persist_conversation: bool = True,
+    truncate_results: bool = False,
+    poll_interval: int = DEFAULT_POLLING_INTERVAL_SECS,
 ) -> Callable[[str], GenieResponse]:
     """
     Create a tool for interacting with Databricks Genie for natural language queries to databases.
@@ -24,22 +288,33 @@ def create_genie_tool(
     answering questions about inventory, sales, and other structured retail data.

     Args:
-        space_id: Databricks workspace ID where Genie is configured. If None, tries to
-            get it from DATABRICKS_GENIE_SPACE_ID environment variable.
+        genie_room: GenieRoomModel or dict containing Genie configuration
+        name: Optional custom name for the tool. If None, uses default "genie_tool"
+        description: Optional custom description for the tool. If None, uses default description

     Returns:
-        A callable tool function that processes natural language queries through Genie
+        A LangGraph tool that processes natural language queries through Genie
     """

     if isinstance(genie_room, dict):
         genie_room = GenieRoomModel(**genie_room)

-    space_id: str = genie_room.space_id or os.environ.get("DATABRICKS_GENIE_SPACE_ID")
-
-    genie: Genie = Genie(
-        space_id=space_id,
-        client=genie_room.workspace_client,
+    space_id: AnyVariable = genie_room.space_id or os.environ.get(
+        "DATABRICKS_GENIE_SPACE_ID"
+    )
+    space_id: AnyVariable = genie_room.space_id or os.environ.get(
+        "DATABRICKS_GENIE_SPACE_ID"
     )
+    if isinstance(space_id, dict):
+        space_id = CompositeVariableModel(**space_id)
+    space_id = value_of(space_id)
+
+    # genie: Genie = Genie(
+    #     space_id=space_id,
+    #     client=genie_room.workspace_client,
+    #     truncate_results=truncate_results,
+    #     polling_interval=poll_interval,
+    # )

     default_description: str = dedent("""
@@ -49,29 +324,66 @@ def create_genie_tool(
     Prefer to call this tool multiple times rather than asking a complex question.
     """)

-    if description is None:
-        description = default_description
+    tool_description: str = (
+        description if description is not None else default_description
+    )
+    tool_name: str = name if name is not None else "genie_tool"

-    doc_signature: str = dedent("""
-        Args:
-            question (str): The question to ask to ask Genie
+    function_docs = """

-        Returns:
-            response (GenieResponse): An object containing the Genie response
-    """)
+    Args:
+        question (str): The question to ask to ask Genie about your data. Ask simple, clear questions about your tabular data. For complex analysis, ask multiple simple questions rather than one complex question.

-    doc: str = description + "\n" + doc_signature
+    Returns:
+        GenieResponse: A response object containing the conversation ID and result from Genie."""
+    tool_description = tool_description + function_docs

-    async def genie_tool(question: str) -> GenieResponse:
-        # Use sync API for now since Genie doesn't support async yet
-        # Can be easily updated to await when Genie gets async support
-        response: GenieResponse = genie.ask_question(question)
-        return response
+    @tool(
+        name_or_callable=tool_name,
+        description=tool_description,
+    )
+    def genie_tool(
+        question: Annotated[str, "The question to ask Genie about your data"],
+        state: Annotated[dict, InjectedState],
+        tool_call_id: Annotated[str, InjectedToolCallId],
+    ) -> Command:
+        genie: Genie = Genie(
+            space_id=space_id,
+            client=genie_room.workspace_client,
+            truncate_results=truncate_results,
+            polling_interval=poll_interval,
+        )

-    name: str = name if name else genie_tool.__name__
+        """Process a natural language question through Databricks Genie."""
+        # Get existing conversation mapping and retrieve conversation ID for this space
+        conversation_ids: dict[str, str] = state.get("genie_conversation_ids", {})
+        existing_conversation_id: str | None = conversation_ids.get(space_id)
+        logger.debug(
+            f"Existing conversation ID for space {space_id}: {existing_conversation_id}"
+        )

-    structured_tool: StructuredTool = StructuredTool.from_function(
-        coroutine=genie_tool, name=name, description=doc, parse_docstring=False
-    )
+        response: GenieResponse = genie.ask_question(
+            question, conversation_id=existing_conversation_id
+        )
+
+        current_conversation_id: str = response.conversation_id
+        logger.debug(
+            f"Current conversation ID for space {space_id}: {current_conversation_id}"
+        )
+
+        # Update the conversation mapping with the new conversation ID for this space
+
+        update: dict[str, Any] = {
+            "messages": [ToolMessage(response.to_json(), tool_call_id=tool_call_id)],
+        }
+
+        if persist_conversation:
+            updated_conversation_ids: dict[str, str] = conversation_ids.copy()
+            updated_conversation_ids[space_id] = current_conversation_id
+            update["genie_conversation_ids"] = updated_conversation_ids
+
+        logger.debug(f"State update: {update}")
+
+        return Command(update=update)

-    return structured_tool
+    return genie_tool
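
The rewritten tool threads a per-space conversation ID through graph state, so follow-up questions continue the same Genie conversation. A hypothetical wiring into a LangGraph agent (the space ID, model name, and `GenieRoomModel` fields below are illustrative, not taken from this package's docs):

```python
# Hypothetical usage of create_genie_tool; identifiers below are assumptions.
from langgraph.prebuilt import create_react_agent

from dao_ai.tools.genie import create_genie_tool

genie_tool = create_genie_tool(
    genie_room={"name": "sales", "space_id": "01ef0000000000000000000000000000"},
    persist_conversation=True,  # reuse one Genie conversation per space across turns
    truncate_results=True,      # cap tabular results at MAX_TOKENS_OF_DATA tokens
)

# The tool reads genie_conversation_ids from graph state (InjectedState) and
# returns a Command, so it must run inside a LangGraph tool node.
agent = create_react_agent(
    model="databricks:databricks-meta-llama-3-3-70b-instruct",
    tools=[genie_tool],
)
```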
dao_ai/tools/human_in_the_loop.py CHANGED
@@ -87,7 +87,11 @@ def as_human_in_the_loop(
     if isinstance(function, BaseFunctionModel):
         human_in_the_loop: HumanInTheLoopModel | None = function.human_in_the_loop
         if human_in_the_loop:
-            logger.debug(f"Adding human-in-the-loop to tool: {tool.name}")
+            # Get tool name safely - handle RunnableBinding objects
+            tool_name = getattr(tool, "name", None) or getattr(
+                getattr(tool, "bound", None), "name", "unknown_tool"
+            )
+            logger.debug(f"Adding human-in-the-loop to tool: {tool_name}")
             tool = add_human_in_the_loop(
                 tool=tool,
                 interrupt_config=human_in_the_loop.interupt_config,
dao_ai/tools/unity_catalog.py CHANGED
@@ -1,14 +1,18 @@
-from typing import Sequence
+from typing import Any, Dict, Optional, Sequence, Union

-from databricks_langchain import (
-    DatabricksFunctionClient,
-    UCFunctionToolkit,
-)
+from databricks.sdk import WorkspaceClient
+from databricks.sdk.service.catalog import PermissionsChange, Privilege
+from databricks_langchain import DatabricksFunctionClient, UCFunctionToolkit
 from langchain_core.runnables.base import RunnableLike
+from langchain_core.tools import StructuredTool
 from loguru import logger

 from dao_ai.config import (
+    AnyVariable,
+    CompositeVariableModel,
+    ToolModel,
     UnityCatalogFunctionModel,
+    value_of,
 )
 from dao_ai.tools.human_in_the_loop import as_human_in_the_loop
@@ -32,19 +36,330 @@ def create_uc_tools(

     logger.debug(f"create_uc_tools: {function}")

+    original_function_model = None
     if isinstance(function, UnityCatalogFunctionModel):
-        function = function.full_name
+        original_function_model = function
+        function_name = function.full_name
+    else:
+        function_name = function

-    client: DatabricksFunctionClient = DatabricksFunctionClient()
+    # Determine which tools to create
+    if original_function_model and original_function_model.partial_args:
+        logger.debug("Found partial_args, creating custom tool with partial arguments")
+        # Create a ToolModel wrapper for the with_partial_args function
+        tool_model = ToolModel(
+            name=original_function_model.name, function=original_function_model
+        )
+
+        # Use with_partial_args to create the authenticated tool
+        tools = [with_partial_args(tool_model, original_function_model.partial_args)]
+    else:
+        # Fallback to standard UC toolkit approach
+        client: DatabricksFunctionClient = DatabricksFunctionClient()
+
+        toolkit: UCFunctionToolkit = UCFunctionToolkit(
+            function_names=[function_name], client=client
+        )
+
+        tools = toolkit.tools or []
+        logger.debug(f"Retrieved tools: {tools}")
+
+    # Apply human-in-the-loop wrapper to all tools and return
+    return [as_human_in_the_loop(tool=tool, function=function_name) for tool in tools]
+
+
+def _execute_uc_function(
+    client: DatabricksFunctionClient,
+    function_name: str,
+    partial_args: Dict[str, str] = None,
+    **kwargs: Any,
+) -> str:
+    """Execute Unity Catalog function with partial args and provided parameters."""
+
+    # Start with partial args if provided
+    all_params: Dict[str, Any] = dict(partial_args) if partial_args else {}
+
+    # Add any additional kwargs
+    all_params.update(kwargs)

-    toolkit: UCFunctionToolkit = UCFunctionToolkit(
-        function_names=[function], client=client
+    logger.debug(
+        f"Calling UC function {function_name} with parameters: {list(all_params.keys())}"
     )

-    tools = toolkit.tools or []
+    result = client.execute_function(function_name=function_name, parameters=all_params)
+
+    # Handle errors and extract result
+    if hasattr(result, "error") and result.error:
+        logger.error(f"Unity Catalog function error: {result.error}")
+        raise RuntimeError(f"Function execution failed: {result.error}")
+
+    result_value: str = result.value if hasattr(result, "value") else str(result)
+    logger.debug(f"UC function result: {result_value}")
+    return result_value
+
+
+def _grant_function_permissions(
+    function_name: str,
+    client_id: str,
+    host: Optional[str] = None,
+) -> None:
+    """
+    Grant comprehensive permissions to the service principal for Unity Catalog function execution.
+
+    This includes:
+    - EXECUTE permission on the function itself
+    - USE permission on the containing schema
+    - USE permission on the containing catalog
+    """
+    try:
+        # Initialize workspace client
+        workspace_client = WorkspaceClient(host=host) if host else WorkspaceClient()
+
+        # Parse the function name to get catalog and schema
+        parts = function_name.split(".")
+        if len(parts) != 3:
+            logger.warning(
+                f"Invalid function name format: {function_name}. Expected catalog.schema.function"
+            )
+            return
+
+        catalog_name, schema_name, func_name = parts
+        schema_full_name = f"{catalog_name}.{schema_name}"
+
+        logger.debug(
+            f"Granting comprehensive permissions on function {function_name} to principal {client_id}"
+        )
+
+        # 1. Grant EXECUTE permission on the function
+        try:
+            workspace_client.grants.update(
+                securable_type="function",
+                full_name=function_name,
+                changes=[
+                    PermissionsChange(principal=client_id, add=[Privilege.EXECUTE])
+                ],
+            )
+            logger.debug(f"Granted EXECUTE on function {function_name}")
+        except Exception as e:
+            logger.warning(f"Failed to grant EXECUTE on function {function_name}: {e}")
+
+        # 2. Grant USE_SCHEMA permission on the schema
+        try:
+            workspace_client.grants.update(
+                securable_type="schema",
+                full_name=schema_full_name,
+                changes=[
+                    PermissionsChange(
+                        principal=client_id,
+                        add=[Privilege.USE_SCHEMA],
+                    )
+                ],
+            )
+            logger.debug(f"Granted USE_SCHEMA on schema {schema_full_name}")
+        except Exception as e:
+            logger.warning(
+                f"Failed to grant USE_SCHEMA on schema {schema_full_name}: {e}"
+            )
+
+        # 3. Grant USE_CATALOG and BROWSE permissions on the catalog
+        try:
+            workspace_client.grants.update(
+                securable_type="catalog",
+                full_name=catalog_name,
+                changes=[
+                    PermissionsChange(
+                        principal=client_id,
+                        add=[Privilege.USE_CATALOG, Privilege.BROWSE],
+                    )
+                ],
+            )
+            logger.debug(f"Granted USE_CATALOG and BROWSE on catalog {catalog_name}")
+        except Exception as e:
+            logger.warning(
+                f"Failed to grant catalog permissions on {catalog_name}: {e}"
+            )
+
+        logger.debug(
+            f"Successfully granted comprehensive permissions on {function_name} to {client_id}"
+        )
+
+    except Exception as e:
+        logger.warning(
+            f"Failed to grant permissions on function {function_name} to {client_id}: {e}"
+        )
+        # Don't fail the tool creation if permission granting fails
+        pass
+
+
+def _create_filtered_schema(original_schema: type, exclude_fields: set[str]) -> type:
+    """
+    Create a new Pydantic model that excludes specified fields from the original schema.
+
+    Args:
+        original_schema: The original Pydantic model class
+        exclude_fields: Set of field names to exclude from the schema
+
+    Returns:
+        A new Pydantic model class with the specified fields removed
+    """
+    from pydantic import BaseModel, Field, create_model
+    from pydantic.fields import PydanticUndefined
+
+    try:
+        # Get the original model's fields (Pydantic v2)
+        original_fields = original_schema.model_fields
+        filtered_field_definitions = {}
+
+        for name, field in original_fields.items():
+            if name not in exclude_fields:
+                # Reconstruct the field definition for create_model
+                field_type = field.annotation
+                field_default = (
+                    field.default if field.default is not PydanticUndefined else ...
+                )
+                field_info = Field(default=field_default, description=field.description)
+                filtered_field_definitions[name] = (field_type, field_info)
+
+        # If no fields remain after filtering, return a generic empty schema
+        if not filtered_field_definitions:
+
+            class EmptySchema(BaseModel):
+                """Unity Catalog function with all parameters provided via partial args."""

-    logger.debug(f"Retrieved tools: {tools}")
+                pass

-    tools = [as_human_in_the_loop(tool=tool, function=function) for tool in tools]
+            return EmptySchema
+
+        # Create the new model dynamically
+        model_name = f"Filtered{original_schema.__name__}"
+        docstring = getattr(
+            original_schema, "__doc__", "Filtered Unity Catalog function parameters."
+        )
+
+        filtered_model = create_model(
+            model_name, __doc__=docstring, **filtered_field_definitions
+        )
+        return filtered_model
+
+    except Exception as e:
+        logger.warning(f"Failed to create filtered schema: {e}")
+
+        # Fallback to generic schema
+        class GenericFilteredSchema(BaseModel):
+            """Generic filtered schema for Unity Catalog function."""
+
+            pass
+
+        return GenericFilteredSchema
+
+
+def with_partial_args(
+    tool: Union[ToolModel, Dict[str, Any]],
+    partial_args: dict[str, AnyVariable] = {},
+) -> StructuredTool:
+    """
+    Create a Unity Catalog tool with partial arguments pre-filled.
+
+    This function creates a wrapper tool that calls the UC function with partial arguments
+    already resolved, so the caller only needs to provide the remaining parameters.
+
+    Args:
+        tool: ToolModel containing the Unity Catalog function configuration
+        partial_args: Dictionary of arguments to pre-fill in the tool
+
+    Returns:
+        StructuredTool: A LangChain tool with partial arguments pre-filled
+    """
+    from unitycatalog.ai.langchain.toolkit import generate_function_input_params_schema
+
+    logger.debug(f"with_partial_args: {tool}")
+
+    # Convert dict-based variables to CompositeVariableModel and resolve their values
+    resolved_args = {}
+    for k, v in partial_args.items():
+        if isinstance(v, dict):
+            resolved_args[k] = value_of(CompositeVariableModel(**v))
+        else:
+            resolved_args[k] = value_of(v)
+
+    logger.debug(f"Resolved partial args: {resolved_args.keys()}")
+
+    if isinstance(tool, dict):
+        tool = ToolModel(**tool)
+
+    unity_catalog_function = tool.function
+    if isinstance(unity_catalog_function, dict):
+        unity_catalog_function = UnityCatalogFunctionModel(**unity_catalog_function)
+
+    function_name: str = unity_catalog_function.full_name
+    logger.debug(f"Creating UC tool with partial args for: {function_name}")
+
+    # Grant permissions if we have credentials
+    if "client_id" in resolved_args:
+        client_id: str = resolved_args["client_id"]
+        host: Optional[str] = resolved_args.get("host")
+        try:
+            _grant_function_permissions(function_name, client_id, host)
+        except Exception as e:
+            logger.warning(f"Failed to grant permissions: {e}")
+
+    # Create the client for function execution
+    client: DatabricksFunctionClient = DatabricksFunctionClient()
+
+    # Try to get the function schema for better tool definition
+    try:
+        function_info = client.get_function(function_name)
+        schema_info = generate_function_input_params_schema(function_info)
+        tool_description = (
+            function_info.comment or f"Unity Catalog function: {function_name}"
+        )
+
+        logger.debug(
+            f"Generated schema for function {function_name}: {schema_info.pydantic_model}"
+        )
+        logger.debug(f"Tool description: {tool_description}")
+
+        # Create a modified schema that excludes partial args
+        original_schema = schema_info.pydantic_model
+        schema_model = _create_filtered_schema(original_schema, resolved_args.keys())
+        logger.debug(
+            f"Filtered schema excludes partial args: {list(resolved_args.keys())}"
+        )
+
+    except Exception as e:
+        logger.warning(f"Could not introspect function {function_name}: {e}")
+        # Fallback to a generic schema
+        from pydantic import BaseModel
+
+        class GenericUCParams(BaseModel):
+            """Generic parameters for Unity Catalog function."""
+
+            pass
+
+        schema_model = GenericUCParams
+        tool_description = f"Unity Catalog function: {function_name}"
+
+    # Create a wrapper function that calls _execute_uc_function with partial args
+    def uc_function_wrapper(**kwargs) -> str:
+        """Wrapper function that executes Unity Catalog function with partial args."""
+        return _execute_uc_function(
+            client=client,
+            function_name=function_name,
+            partial_args=resolved_args,
+            **kwargs,
+        )
+
+    # Set the function name for the decorator
+    uc_function_wrapper.__name__ = tool.name or function_name.replace(".", "_")
+
+    # Create the tool using LangChain's StructuredTool
+    from langchain_core.tools import StructuredTool
+
+    partial_tool = StructuredTool.from_function(
+        func=uc_function_wrapper,
+        name=tool.name or function_name.replace(".", "_"),
+        description=tool_description,
+        args_schema=schema_model,
+    )

-    return tools
+    return partial_tool
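
`with_partial_args` binds selected UC function parameters (typically credentials) at tool-construction time and hides them from the LLM-facing args schema. A sketch of direct usage; the function and field names below are assumptions, since the `ToolModel` and `SchemaModel` internals are not shown in this diff:

```python
# Hypothetical sketch of with_partial_args; model field names are assumptions.
from dao_ai.tools.unity_catalog import with_partial_args

tool = with_partial_args(
    tool={
        "name": "lookup_order",
        "function": {
            "name": "lookup_order",
            "schema": {"catalog_name": "main", "schema_name": "retail"},
            "partial_args": {"client_id": "my-service-principal-id"},
        },
    },
    partial_args={"client_id": "my-service-principal-id"},
)

# The returned StructuredTool exposes only the unbound parameters to the model;
# client_id is injected into every call and filtered out of the args schema.
print(tool.args_schema.model_json_schema())
```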
dao_ai-0.0.19.dist-info/METADATA → dao_ai-0.0.21.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: dao-ai
-Version: 0.0.19
+Version: 0.0.21
 Summary: DAO AI: A modular, multi-agent orchestration framework for complex AI workflows. Supports agent handoff, tool integration, and dynamic configuration via YAML.
 Project-URL: Homepage, https://github.com/natefleming/dao-ai
 Project-URL: Documentation, https://natefleming.github.io/dao-ai
dao_ai-0.0.19.dist-info/RECORD → dao_ai-0.0.21.dist-info/RECORD CHANGED
@@ -3,14 +3,14 @@ dao_ai/agent_as_code.py,sha256=kPSeDz2-1jRaed1TMs4LA3VECoyqe9_Ed2beRLB9gXQ,472
 dao_ai/catalog.py,sha256=sPZpHTD3lPx4EZUtIWeQV7VQM89WJ6YH__wluk1v2lE,4947
 dao_ai/chat_models.py,sha256=uhwwOTeLyHWqoTTgHrs4n5iSyTwe4EQcLKnh3jRxPWI,8626
 dao_ai/cli.py,sha256=Aez2TQW3Q8Ho1IaIkRggt0NevDxAAVPjXkePC5GPJF0,20429
-dao_ai/config.py,sha256=N_Vc-rJHvBzbia4TyAExGhCvZKXlk49bskrI_sbxwjg,51869
+dao_ai/config.py,sha256=GeaM00wNlYecwe3HhqeG88Hprt0SvGg4HtC7g_m-v98,52386
 dao_ai/graph.py,sha256=gmD9mxODfXuvn9xWeBfewm1FiuVAWMLEdnZz7DNmSH0,7859
-dao_ai/guardrails.py,sha256=-Qh0f_2Db9t4Nbrrx9FM7tnpqShjMoyxepZ0HByItfU,4027
+dao_ai/guardrails.py,sha256=4TKArDONRy8RwHzOT1plZ1rhy3x9GF_aeGpPCRl6wYA,4016
 dao_ai/messages.py,sha256=xl_3-WcFqZKCFCiov8sZOPljTdM3gX3fCHhxq-xFg2U,7005
-dao_ai/models.py,sha256=Xb23U-lhDG8KyNRIijcJ4InluadlaGNy4rrYx7Cjgfg,26939
+dao_ai/models.py,sha256=8r8GIG3EGxtVyWsRNI56lVaBjiNrPkzh4HdwMZRq8iw,31689
 dao_ai/nodes.py,sha256=SSuFNTXOdFaKg_aX-yUkQO7fM9wvNGu14lPXKDapU1U,8461
 dao_ai/prompts.py,sha256=vpmIbWs_szXUgNNDs5Gh2LcxKZti5pHDKSfoClUcgX0,1289
-dao_ai/state.py,sha256=GwbMbd1TWZx1T5iQrEOX6_rpxOitlmyeJ8dMr2o_pag,1031
+dao_ai/state.py,sha256=_lF9krAYYjvFDMUwZzVKOn0ZnXKcOrbjWKdre0C5B54,1137
 dao_ai/types.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 dao_ai/utils.py,sha256=dkZTXNN6q0xwkrvSWdNq8937W2xGuLCRWRb6hRQM6kA,4217
 dao_ai/vector_search.py,sha256=jlaFS_iizJ55wblgzZmswMM3UOL-qOp2BGJc0JqXYSg,2839
@@ -19,22 +19,22 @@ dao_ai/hooks/core.py,sha256=ZShHctUSoauhBgdf1cecy9-D7J6-sGn-pKjuRMumW5U,6663
 dao_ai/memory/__init__.py,sha256=1kHx_p9abKYFQ6EYD05nuc1GS5HXVEpufmjBGw_7Uho,260
 dao_ai/memory/base.py,sha256=99nfr2UZJ4jmfTL_KrqUlRSCoRxzkZyWyx5WqeUoMdQ,338
 dao_ai/memory/core.py,sha256=g7chjBgVgx3iKjR2hghl0QL1j3802uIM_e7mgszur9M,4151
-dao_ai/memory/postgres.py,sha256=ncvEKFYX-ZjUDYVmuWBMcZnykcp2eK4TP-ojzqkwDsk,17433
+dao_ai/memory/postgres.py,sha256=aWHRLhPm-9ywjlQe2B4XSdLbeaiuVV88p4PiQJFNEWo,13924
 dao_ai/providers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 dao_ai/providers/base.py,sha256=-fjKypCOk28h6vioPfMj9YZSw_3Kcbi2nMuAyY7vX9k,1383
-dao_ai/providers/databricks.py,sha256=fZ8mGotfA3W3t5yUej2xGmGHSybjBFYr895mOctT418,28203
+dao_ai/providers/databricks.py,sha256=PX5mBvZaIxSJIAHWVnPXsho1XvxcoR3Qs3I9UavFRsY,28306
 dao_ai/tools/__init__.py,sha256=ye6MHaJY7tUnJ8336YJiLxuZr55zDPNdOw6gm7j5jlc,1103
 dao_ai/tools/agent.py,sha256=WbQnyziiT12TLMrA7xK0VuOU029tdmUBXbUl-R1VZ0Q,1886
 dao_ai/tools/core.py,sha256=Kei33S8vrmvPOAyrFNekaWmV2jqZ-IPS1QDSvU7RZF0,1984
-dao_ai/tools/genie.py,sha256=GzV5lfDYKmzW_lSLxAsPaTwnzX6GxQOB1UcLaTDqpfY,2787
-dao_ai/tools/human_in_the_loop.py,sha256=IBmQJmpxkdDxnBNyABc_-dZhhsQlTNTkPyUXgkHKIgY,3466
+dao_ai/tools/genie.py,sha256=1CbLViNQ3KnmDtHXuwqCPug7rEhCGvuHP1NgsY-AJZ0,15050
+dao_ai/tools/human_in_the_loop.py,sha256=yk35MO9eNETnYFH-sqlgR-G24TrEgXpJlnZUustsLkI,3681
 dao_ai/tools/mcp.py,sha256=auEt_dwv4J26fr5AgLmwmnAsI894-cyuvkvjItzAUxs,4419
 dao_ai/tools/python.py,sha256=XcQiTMshZyLUTVR5peB3vqsoUoAAy8gol9_pcrhddfI,1831
 dao_ai/tools/time.py,sha256=Y-23qdnNHzwjvnfkWvYsE7PoWS1hfeKy44tA7sCnNac,8759
-dao_ai/tools/unity_catalog.py,sha256=PXfLj2EgyQgaXq4Qq3t25AmTC4KyVCF_-sCtg6enens,1404
+dao_ai/tools/unity_catalog.py,sha256=uX_h52BuBAr4c9UeqSMI7DNz3BPRLeai5tBVW4sJqRI,13113
 dao_ai/tools/vector_search.py,sha256=EDYQs51zIPaAP0ma1D81wJT77GQ-v-cjb2XrFVWfWdg,2621
-dao_ai-0.0.19.dist-info/METADATA,sha256=hus4RZHOCTgDR6Rs8zS9l0OusplrFzryWCLsXZpTxgw,41380
-dao_ai-0.0.19.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-dao_ai-0.0.19.dist-info/entry_points.txt,sha256=Xa-UFyc6gWGwMqMJOt06ZOog2vAfygV_DSwg1AiP46g,43
-dao_ai-0.0.19.dist-info/licenses/LICENSE,sha256=YZt3W32LtPYruuvHE9lGk2bw6ZPMMJD8yLrjgHybyz4,1069
-dao_ai-0.0.19.dist-info/RECORD,,
+dao_ai-0.0.21.dist-info/METADATA,sha256=PG-eOltuUpaJf4lYEw-DoVy5BFT9LbMCfe8GanIV7zQ,41380
+dao_ai-0.0.21.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+dao_ai-0.0.21.dist-info/entry_points.txt,sha256=Xa-UFyc6gWGwMqMJOt06ZOog2vAfygV_DSwg1AiP46g,43
+dao_ai-0.0.21.dist-info/licenses/LICENSE,sha256=YZt3W32LtPYruuvHE9lGk2bw6ZPMMJD8yLrjgHybyz4,1069
+dao_ai-0.0.21.dist-info/RECORD,,