dao-ai 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. dao_ai/agent_as_code.py +2 -5
  2. dao_ai/cli.py +65 -15
  3. dao_ai/config.py +672 -218
  4. dao_ai/genie/cache/core.py +6 -2
  5. dao_ai/genie/cache/lru.py +29 -11
  6. dao_ai/genie/cache/semantic.py +95 -44
  7. dao_ai/hooks/core.py +5 -5
  8. dao_ai/logging.py +56 -0
  9. dao_ai/memory/core.py +61 -44
  10. dao_ai/memory/databricks.py +54 -41
  11. dao_ai/memory/postgres.py +77 -36
  12. dao_ai/middleware/assertions.py +45 -17
  13. dao_ai/middleware/core.py +13 -7
  14. dao_ai/middleware/guardrails.py +30 -25
  15. dao_ai/middleware/human_in_the_loop.py +9 -5
  16. dao_ai/middleware/message_validation.py +61 -29
  17. dao_ai/middleware/summarization.py +16 -11
  18. dao_ai/models.py +172 -69
  19. dao_ai/nodes.py +148 -19
  20. dao_ai/optimization.py +26 -16
  21. dao_ai/orchestration/core.py +15 -8
  22. dao_ai/orchestration/supervisor.py +22 -8
  23. dao_ai/orchestration/swarm.py +57 -12
  24. dao_ai/prompts.py +17 -17
  25. dao_ai/providers/databricks.py +365 -155
  26. dao_ai/state.py +24 -6
  27. dao_ai/tools/__init__.py +2 -0
  28. dao_ai/tools/agent.py +1 -3
  29. dao_ai/tools/core.py +7 -7
  30. dao_ai/tools/email.py +29 -77
  31. dao_ai/tools/genie.py +18 -13
  32. dao_ai/tools/mcp.py +223 -156
  33. dao_ai/tools/python.py +5 -2
  34. dao_ai/tools/search.py +1 -1
  35. dao_ai/tools/slack.py +21 -9
  36. dao_ai/tools/sql.py +202 -0
  37. dao_ai/tools/time.py +30 -7
  38. dao_ai/tools/unity_catalog.py +129 -86
  39. dao_ai/tools/vector_search.py +318 -244
  40. dao_ai/utils.py +15 -10
  41. dao_ai-0.1.3.dist-info/METADATA +455 -0
  42. dao_ai-0.1.3.dist-info/RECORD +64 -0
  43. dao_ai-0.1.1.dist-info/METADATA +0 -1878
  44. dao_ai-0.1.1.dist-info/RECORD +0 -62
  45. {dao_ai-0.1.1.dist-info → dao_ai-0.1.3.dist-info}/WHEEL +0 -0
  46. {dao_ai-0.1.1.dist-info → dao_ai-0.1.3.dist-info}/entry_points.txt +0 -0
  47. {dao_ai-0.1.1.dist-info → dao_ai-0.1.3.dist-info}/licenses/LICENSE +0 -0
dao_ai/tools/mcp.py CHANGED
@@ -1,196 +1,263 @@
1
+ """
2
+ MCP (Model Context Protocol) tool creation for LangChain agents.
3
+
4
+ This module provides tools for connecting to MCP servers using the
5
+ MCP SDK and langchain-mcp-adapters library.
6
+
7
+ For compatibility with Databricks APIs, we use manual tool wrappers
8
+ that give us full control over the response format.
9
+
10
+ Reference: https://docs.langchain.com/oss/python/langchain/mcp
11
+ """
12
+
1
13
  import asyncio
2
14
  from typing import Any, Sequence
3
15
 
4
- from databricks_mcp import DatabricksOAuthClientProvider
5
16
  from langchain_core.runnables.base import RunnableLike
6
17
  from langchain_core.tools import tool as create_tool
7
18
  from langchain_mcp_adapters.client import MultiServerMCPClient
8
19
  from loguru import logger
9
- from mcp import ClientSession
10
- from mcp.client.streamable_http import streamablehttp_client
11
- from mcp.types import ListToolsResult, Tool
20
+ from mcp.types import CallToolResult, TextContent, Tool
12
21
 
13
22
  from dao_ai.config import (
14
23
  McpFunctionModel,
15
24
  TransportType,
25
+ value_of,
16
26
  )
17
27
 
18
28
 
19
- def create_mcp_tools(
29
+ def _build_connection_config(
20
30
  function: McpFunctionModel,
21
- ) -> Sequence[RunnableLike]:
31
+ ) -> dict[str, Any]:
22
32
  """
23
- Create tools for invoking Databricks MCP functions.
33
+ Build the connection configuration dictionary for MultiServerMCPClient.
24
34
 
25
- Supports both direct MCP connections and UC Connection-based MCP access.
26
- Uses session-based approach to handle authentication token expiration properly.
35
+ Args:
36
+ function: The MCP function model configuration.
27
37
 
28
- Based on: https://docs.databricks.com/aws/en/generative-ai/mcp/external-mcp
38
+ Returns:
39
+ A dictionary containing the transport-specific connection settings.
29
40
  """
30
- logger.debug(f"create_mcp_tools: {function}")
41
+ if function.transport == TransportType.STDIO:
42
+ return {
43
+ "command": function.command,
44
+ "args": function.args,
45
+ "transport": function.transport.value,
46
+ }
47
+
48
+ # For HTTP transport with UC Connection, use DatabricksOAuthClientProvider
49
+ if function.connection:
50
+ from databricks_mcp import DatabricksOAuthClientProvider
31
51
 
32
- # Get MCP URL - handles all convenience objects (connection, genie_room, warehouse, etc.)
33
- mcp_url = function.mcp_url
34
- logger.debug(f"Using MCP URL: {mcp_url}")
52
+ workspace_client = function.connection.workspace_client
53
+ auth_provider = DatabricksOAuthClientProvider(workspace_client)
35
54
 
36
- # Check if using UC Connection or direct MCP connection
37
- if function.connection:
38
- # Use UC Connection approach with DatabricksOAuthClientProvider
39
- logger.debug(f"Using UC Connection for MCP: {function.connection.name}")
55
+ logger.trace(
56
+ "Using DatabricksOAuthClientProvider for authentication",
57
+ connection_name=function.connection.name,
58
+ )
40
59
 
41
- async def _list_tools_with_connection():
42
- """List available tools using DatabricksOAuthClientProvider."""
43
- workspace_client = function.connection.workspace_client
60
+ return {
61
+ "url": function.mcp_url,
62
+ "transport": "http",
63
+ "auth": auth_provider,
64
+ }
44
65
 
45
- async with streamablehttp_client(
46
- mcp_url, auth=DatabricksOAuthClientProvider(workspace_client)
47
- ) as (read_stream, write_stream, _):
48
- async with ClientSession(read_stream, write_stream) as session:
49
- # Initialize and list tools
50
- await session.initialize()
51
- return await session.list_tools()
66
+ # For HTTP transport with headers-based authentication
67
+ headers: dict[str, str] = {
68
+ key: str(value_of(val)) for key, val in function.headers.items()
69
+ }
70
+
71
+ if "Authorization" not in headers:
72
+ logger.trace("Generating fresh authentication token")
73
+
74
+ from dao_ai.providers.databricks import DatabricksProvider
52
75
 
53
76
  try:
54
- mcp_tools: list[Tool] | ListToolsResult = asyncio.run(
55
- _list_tools_with_connection()
77
+ provider = DatabricksProvider(
78
+ workspace_host=value_of(function.workspace_host),
79
+ client_id=value_of(function.client_id),
80
+ client_secret=value_of(function.client_secret),
81
+ pat=value_of(function.pat),
56
82
  )
57
- if isinstance(mcp_tools, ListToolsResult):
58
- mcp_tools = mcp_tools.tools
83
+ headers["Authorization"] = f"Bearer {provider.create_token()}"
84
+ logger.trace("Generated fresh authentication token")
85
+ except Exception as e:
86
+ logger.error("Failed to create fresh token", error=str(e))
87
+ else:
88
+ logger.trace("Using existing authentication token")
59
89
 
60
- logger.debug(f"Retrieved {len(mcp_tools)} MCP tools via UC Connection")
90
+ return {
91
+ "url": function.mcp_url,
92
+ "transport": "http",
93
+ "headers": headers,
94
+ }
61
95
 
62
- except Exception as e:
63
- logger.error(f"Failed to get tools from MCP server via UC Connection: {e}")
64
- raise RuntimeError(
65
- f"Failed to list MCP tools for function '{function.name}' via UC Connection '{function.connection.name}': {e}"
66
- )
67
96
 
68
- # Create wrapper tools with fresh session per invocation
69
- def _create_tool_wrapper_with_connection(mcp_tool: Tool) -> RunnableLike:
70
- @create_tool(
71
- mcp_tool.name,
72
- description=mcp_tool.description or f"MCP tool: {mcp_tool.name}",
73
- args_schema=mcp_tool.inputSchema,
74
- )
75
- async def tool_wrapper(**kwargs):
76
- """Execute MCP tool with fresh UC Connection session."""
77
- logger.debug(
78
- f"Invoking MCP tool {mcp_tool.name} with fresh UC Connection session"
79
- )
80
- workspace_client = function.connection.workspace_client
81
-
82
- try:
83
- async with streamablehttp_client(
84
- mcp_url, auth=DatabricksOAuthClientProvider(workspace_client)
85
- ) as (read_stream, write_stream, _):
86
- async with ClientSession(read_stream, write_stream) as session:
87
- await session.initialize()
88
- result = await session.call_tool(mcp_tool.name, kwargs)
89
- logger.debug(
90
- f"MCP tool {mcp_tool.name} completed successfully"
91
- )
92
- return result
93
- except Exception as e:
94
- logger.error(f"MCP tool {mcp_tool.name} failed: {e}")
95
- raise
96
-
97
- # HITL is now handled at middleware level via HumanInTheLoopMiddleware
98
- return tool_wrapper
99
-
100
- return [_create_tool_wrapper_with_connection(tool) for tool in mcp_tools]
97
+ def _extract_text_content(result: CallToolResult) -> str:
98
+ """
99
+ Extract text content from an MCP CallToolResult.
101
100
 
102
- else:
103
- # Use direct MCP connection with MultiServerMCPClient
104
- logger.debug("Using direct MCP connection with MultiServerMCPClient")
105
-
106
- def _create_fresh_connection() -> dict[str, Any]:
107
- """Create connection config with fresh authentication headers."""
108
- logger.debug("Creating fresh connection...")
109
-
110
- if function.transport == TransportType.STDIO:
111
- return {
112
- "command": function.command,
113
- "args": function.args,
114
- "transport": function.transport,
115
- }
116
-
117
- # For HTTP transport, generate fresh headers
118
- headers = function.headers.copy() if function.headers else {}
119
-
120
- if "Authorization" not in headers:
121
- logger.debug("Generating fresh authentication token for MCP function")
122
-
123
- from dao_ai.config import value_of
124
- from dao_ai.providers.databricks import DatabricksProvider
125
-
126
- try:
127
- provider = DatabricksProvider(
128
- workspace_host=value_of(function.workspace_host),
129
- client_id=value_of(function.client_id),
130
- client_secret=value_of(function.client_secret),
131
- pat=value_of(function.pat),
132
- )
133
- headers["Authorization"] = f"Bearer {provider.create_token()}"
134
- logger.debug("Generated fresh authentication token")
135
- except Exception as e:
136
- logger.error(f"Failed to create fresh token: {e}")
137
- else:
138
- logger.debug("Using existing authentication token")
139
-
140
- return {
141
- "url": mcp_url, # Use the resolved MCP URL
142
- "transport": function.transport,
143
- "headers": headers,
144
- }
145
-
146
- # Get available tools from MCP server
147
- async def _list_mcp_tools():
148
- connection = _create_fresh_connection()
149
- client = MultiServerMCPClient({function.name: connection})
101
+ Converts the MCP result content to a plain string format that is
102
+ compatible with all LLM APIs (avoiding extra fields like 'id').
150
103
 
151
- try:
152
- async with client.session(function.name) as session:
153
- return await session.list_tools()
154
- except Exception as e:
155
- logger.error(f"Failed to list MCP tools: {e}")
156
- return []
104
+ Args:
105
+ result: The MCP tool call result.
157
106
 
158
- # Note: This still needs to run sync during tool creation/registration
159
- # The actual tool execution will be async
160
- try:
161
- mcp_tools: list[Tool] | ListToolsResult = asyncio.run(_list_mcp_tools())
162
- if isinstance(mcp_tools, ListToolsResult):
163
- mcp_tools = mcp_tools.tools
107
+ Returns:
108
+ A string containing the concatenated text content.
109
+ """
110
+ if not result.content:
111
+ return ""
164
112
 
165
- logger.debug(f"Retrieved {len(mcp_tools)} MCP tools")
166
- except Exception as e:
167
- logger.error(f"Failed to get tools from MCP server: {e}")
168
- raise RuntimeError(
169
- f"Failed to list MCP tools for function '{function.name}' with transport '{function.transport}' and URL '{function.url}': {e}"
113
+ text_parts: list[str] = []
114
+ for item in result.content:
115
+ if isinstance(item, TextContent):
116
+ text_parts.append(item.text)
117
+ elif hasattr(item, "text"):
118
+ # Handle other content types that have text
119
+ text_parts.append(str(item.text))
120
+ else:
121
+ # Fallback: convert to string representation
122
+ text_parts.append(str(item))
123
+
124
+ return "\n".join(text_parts)
125
+
126
+
127
+ def create_mcp_tools(
128
+ function: McpFunctionModel,
129
+ ) -> Sequence[RunnableLike]:
130
+ """
131
+ Create tools for invoking Databricks MCP functions.
132
+
133
+ Supports both direct MCP connections and UC Connection-based MCP access.
134
+ Uses manual tool wrappers to ensure response format compatibility with
135
+ Databricks APIs (which reject extra fields in tool results).
136
+
137
+ Based on: https://docs.databricks.com/aws/en/generative-ai/mcp/external-mcp
138
+
139
+ Args:
140
+ function: The MCP function model configuration.
141
+
142
+ Returns:
143
+ A sequence of LangChain tools that can be used by agents.
144
+ """
145
+ mcp_url = function.mcp_url
146
+ logger.debug("Creating MCP tools", mcp_url=mcp_url)
147
+
148
+ connection_config = _build_connection_config(function)
149
+
150
+ if function.connection:
151
+ logger.debug(
152
+ "Using UC Connection for MCP",
153
+ connection_name=function.connection.name,
154
+ mcp_url=mcp_url,
155
+ )
156
+ else:
157
+ logger.debug(
158
+ "Using direct connection for MCP",
159
+ transport=function.transport,
160
+ mcp_url=mcp_url,
161
+ )
162
+
163
+ # Create client to list available tools
164
+ client = MultiServerMCPClient({"mcp_function": connection_config})
165
+
166
+ async def _list_tools() -> list[Tool]:
167
+ """List available MCP tools from the server."""
168
+ async with client.session("mcp_function") as session:
169
+ result = await session.list_tools()
170
+ return result.tools if hasattr(result, "tools") else list(result)
171
+
172
+ try:
173
+ mcp_tools: list[Tool] = asyncio.run(_list_tools())
174
+
175
+ # Log discovered tools
176
+ logger.info(
177
+ "Discovered MCP tools",
178
+ tools_count=len(mcp_tools),
179
+ mcp_url=mcp_url,
180
+ )
181
+ for mcp_tool in mcp_tools:
182
+ logger.debug(
183
+ "MCP tool discovered",
184
+ tool_name=mcp_tool.name,
185
+ tool_description=(
186
+ mcp_tool.description[:100] if mcp_tool.description else None
187
+ ),
170
188
  )
171
189
 
172
- # Create wrapper tools with fresh session per invocation
173
- def _create_tool_wrapper(mcp_tool: Tool) -> RunnableLike:
174
- @create_tool(
175
- mcp_tool.name,
176
- description=mcp_tool.description or f"MCP tool: {mcp_tool.name}",
177
- args_schema=mcp_tool.inputSchema,
190
+ except Exception as e:
191
+ if function.connection:
192
+ logger.error(
193
+ "Failed to get tools from MCP server via UC Connection",
194
+ connection_name=function.connection.name,
195
+ error=str(e),
196
+ )
197
+ raise RuntimeError(
198
+ f"Failed to list MCP tools via UC Connection "
199
+ f"'{function.connection.name}': {e}"
200
+ ) from e
201
+ else:
202
+ logger.error(
203
+ "Failed to get tools from MCP server",
204
+ transport=function.transport,
205
+ url=function.url,
206
+ error=str(e),
207
+ )
208
+ raise RuntimeError(
209
+ f"Failed to list MCP tools with transport '{function.transport}' "
210
+ f"and URL '{function.url}': {e}"
211
+ ) from e
212
+
213
+ def _create_tool_wrapper(mcp_tool: Tool) -> RunnableLike:
214
+ """
215
+ Create a LangChain tool wrapper for an MCP tool.
216
+
217
+ This wrapper handles:
218
+ - Fresh session creation per invocation (stateless)
219
+ - Content extraction to plain text (avoiding extra fields)
220
+ """
221
+
222
+ @create_tool(
223
+ mcp_tool.name,
224
+ description=mcp_tool.description or f"MCP tool: {mcp_tool.name}",
225
+ args_schema=mcp_tool.inputSchema,
226
+ )
227
+ async def tool_wrapper(**kwargs: Any) -> str:
228
+ """Execute MCP tool with fresh session."""
229
+ logger.trace("Invoking MCP tool", tool_name=mcp_tool.name, args=kwargs)
230
+
231
+ # Create a fresh client/session for each invocation
232
+ invocation_client = MultiServerMCPClient(
233
+ {"mcp_function": _build_connection_config(function)}
178
234
  )
179
- async def tool_wrapper(**kwargs):
180
- """Execute MCP tool with fresh session and authentication."""
181
- logger.debug(f"Invoking MCP tool {mcp_tool.name} with fresh session")
182
235
 
183
- connection = _create_fresh_connection()
184
- client = MultiServerMCPClient({function.name: connection})
236
+ try:
237
+ async with invocation_client.session("mcp_function") as session:
238
+ result: CallToolResult = await session.call_tool(
239
+ mcp_tool.name, kwargs
240
+ )
241
+
242
+ # Extract text content, avoiding extra fields
243
+ text_result = _extract_text_content(result)
244
+
245
+ logger.trace(
246
+ "MCP tool completed",
247
+ tool_name=mcp_tool.name,
248
+ result_length=len(text_result),
249
+ )
185
250
 
186
- try:
187
- async with client.session(function.name) as session:
188
- return await session.call_tool(mcp_tool.name, kwargs)
189
- except Exception as e:
190
- logger.error(f"MCP tool {mcp_tool.name} failed: {e}")
191
- raise
251
+ return text_result
252
+
253
+ except Exception as e:
254
+ logger.error(
255
+ "MCP tool failed",
256
+ tool_name=mcp_tool.name,
257
+ error=str(e),
258
+ )
259
+ raise
192
260
 
193
- # HITL is now handled at middleware level via HumanInTheLoopMiddleware
194
- return tool_wrapper
261
+ return tool_wrapper
195
262
 
196
- return [_create_tool_wrapper(tool) for tool in mcp_tools]
263
+ return [_create_tool_wrapper(tool) for tool in mcp_tools]
dao_ai/tools/python.py CHANGED
@@ -21,7 +21,7 @@ def create_factory_tool(
21
21
  Returns:
22
22
  A callable tool function that wraps the specified factory function
23
23
  """
24
- logger.debug(f"create_factory_tool: {function}")
24
+ logger.trace("Creating factory tool", function=function.full_name)
25
25
 
26
26
  factory: Callable[..., Any] = load_function(function_name=function.full_name)
27
27
  tool: RunnableLike = factory(**function.args)
@@ -41,7 +41,10 @@ def create_python_tool(
41
41
  Returns:
42
42
  A callable tool function that wraps the specified Python function
43
43
  """
44
- logger.debug(f"create_python_tool: {function}")
44
+ function_name = (
45
+ function.full_name if isinstance(function, PythonFunctionModel) else function
46
+ )
47
+ logger.trace("Creating Python tool", function=function_name)
45
48
 
46
49
  if isinstance(function, PythonFunctionModel):
47
50
  function = function.full_name
dao_ai/tools/search.py CHANGED
@@ -10,5 +10,5 @@ def create_search_tool() -> RunnableLike:
10
10
  Returns:
11
11
  RunnableLike: A DuckDuckGo search tool that returns results as a list
12
12
  """
13
- logger.debug("Creating DuckDuckGo search tool")
13
+ logger.trace("Creating DuckDuckGo search tool")
14
14
  return DuckDuckGoSearchRun(output_format="list")
dao_ai/tools/slack.py CHANGED
@@ -26,7 +26,7 @@ def _find_channel_id_by_name(
26
26
  # Remove '#' prefix if present
27
27
  clean_name = channel_name.lstrip("#")
28
28
 
29
- logger.debug(f"Looking up Slack channel ID for channel name: {clean_name}")
29
+ logger.trace("Looking up Slack channel ID", channel_name=clean_name)
30
30
 
31
31
  try:
32
32
  # Call Slack API to list conversations
@@ -37,14 +37,18 @@ def _find_channel_id_by_name(
37
37
  )
38
38
 
39
39
  if response.status_code != 200:
40
- logger.error(f"Failed to list Slack channels: {response.text}")
40
+ logger.error(
41
+ "Failed to list Slack channels",
42
+ status_code=response.status_code,
43
+ response=response.text,
44
+ )
41
45
  return None
42
46
 
43
47
  # Parse response
44
48
  data = response.json()
45
49
 
46
50
  if not data.get("ok"):
47
- logger.error(f"Slack API returned error: {data.get('error')}")
51
+ logger.error("Slack API returned error", error=data.get("error"))
48
52
  return None
49
53
 
50
54
  # Search for channel by name
@@ -53,15 +57,19 @@ def _find_channel_id_by_name(
53
57
  if channel.get("name") == clean_name:
54
58
  channel_id = channel.get("id")
55
59
  logger.debug(
56
- f"Found channel ID '{channel_id}' for channel name '{clean_name}'"
60
+ "Found Slack channel ID",
61
+ channel_id=channel_id,
62
+ channel_name=clean_name,
57
63
  )
58
64
  return channel_id
59
65
 
60
- logger.warning(f"Channel '{clean_name}' not found in Slack workspace")
66
+ logger.warning("Slack channel not found", channel_name=clean_name)
61
67
  return None
62
68
 
63
69
  except Exception as e:
64
- logger.error(f"Error looking up Slack channel: {e}")
70
+ logger.error(
71
+ "Error looking up Slack channel", channel_name=clean_name, error=str(e)
72
+ )
65
73
  return None
66
74
 
67
75
 
@@ -87,7 +95,7 @@ def create_send_slack_message_tool(
87
95
 
88
96
  Based on: https://docs.databricks.com/aws/en/generative-ai/agent-framework/slack-agent
89
97
  """
90
- logger.debug("create_send_slack_message_tool")
98
+ logger.trace("Creating send Slack message tool")
91
99
 
92
100
  # Validate inputs
93
101
  if channel_id is None and channel_name is None:
@@ -99,12 +107,16 @@ def create_send_slack_message_tool(
99
107
 
100
108
  # Look up channel_id from channel_name if needed
101
109
  if channel_id is None and channel_name is not None:
102
- logger.debug(f"Looking up channel_id for channel_name: {channel_name}")
110
+ logger.trace(
111
+ "Looking up channel ID for channel name", channel_name=channel_name
112
+ )
103
113
  channel_id = _find_channel_id_by_name(connection, channel_name)
104
114
  if channel_id is None:
105
115
  raise ValueError(f"Could not find Slack channel with name '{channel_name}'")
106
116
  logger.debug(
107
- f"Resolved channel_name '{channel_name}' to channel_id '{channel_id}'"
117
+ "Resolved channel name to ID",
118
+ channel_name=channel_name,
119
+ channel_id=channel_id,
108
120
  )
109
121
 
110
122
  if name is None: