dao-ai 0.0.25__py3-none-any.whl → 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dao_ai/__init__.py +29 -0
- dao_ai/agent_as_code.py +5 -5
- dao_ai/cli.py +245 -40
- dao_ai/config.py +1863 -338
- dao_ai/genie/__init__.py +38 -0
- dao_ai/genie/cache/__init__.py +43 -0
- dao_ai/genie/cache/base.py +72 -0
- dao_ai/genie/cache/core.py +79 -0
- dao_ai/genie/cache/lru.py +347 -0
- dao_ai/genie/cache/semantic.py +970 -0
- dao_ai/genie/core.py +35 -0
- dao_ai/graph.py +27 -228
- dao_ai/hooks/__init__.py +9 -6
- dao_ai/hooks/core.py +27 -195
- dao_ai/logging.py +56 -0
- dao_ai/memory/__init__.py +10 -0
- dao_ai/memory/core.py +65 -30
- dao_ai/memory/databricks.py +402 -0
- dao_ai/memory/postgres.py +79 -38
- dao_ai/messages.py +6 -4
- dao_ai/middleware/__init__.py +125 -0
- dao_ai/middleware/assertions.py +806 -0
- dao_ai/middleware/base.py +50 -0
- dao_ai/middleware/core.py +67 -0
- dao_ai/middleware/guardrails.py +420 -0
- dao_ai/middleware/human_in_the_loop.py +232 -0
- dao_ai/middleware/message_validation.py +586 -0
- dao_ai/middleware/summarization.py +197 -0
- dao_ai/models.py +1306 -114
- dao_ai/nodes.py +261 -166
- dao_ai/optimization.py +674 -0
- dao_ai/orchestration/__init__.py +52 -0
- dao_ai/orchestration/core.py +294 -0
- dao_ai/orchestration/supervisor.py +278 -0
- dao_ai/orchestration/swarm.py +271 -0
- dao_ai/prompts.py +128 -31
- dao_ai/providers/databricks.py +645 -172
- dao_ai/state.py +157 -21
- dao_ai/tools/__init__.py +13 -5
- dao_ai/tools/agent.py +1 -3
- dao_ai/tools/core.py +64 -11
- dao_ai/tools/email.py +232 -0
- dao_ai/tools/genie.py +144 -295
- dao_ai/tools/mcp.py +220 -133
- dao_ai/tools/memory.py +50 -0
- dao_ai/tools/python.py +9 -14
- dao_ai/tools/search.py +14 -0
- dao_ai/tools/slack.py +22 -10
- dao_ai/tools/sql.py +202 -0
- dao_ai/tools/time.py +30 -7
- dao_ai/tools/unity_catalog.py +165 -88
- dao_ai/tools/vector_search.py +360 -40
- dao_ai/utils.py +218 -16
- dao_ai-0.1.2.dist-info/METADATA +455 -0
- dao_ai-0.1.2.dist-info/RECORD +64 -0
- {dao_ai-0.0.25.dist-info → dao_ai-0.1.2.dist-info}/WHEEL +1 -1
- dao_ai/chat_models.py +0 -204
- dao_ai/guardrails.py +0 -112
- dao_ai/tools/human_in_the_loop.py +0 -100
- dao_ai-0.0.25.dist-info/METADATA +0 -1165
- dao_ai-0.0.25.dist-info/RECORD +0 -41
- {dao_ai-0.0.25.dist-info → dao_ai-0.1.2.dist-info}/entry_points.txt +0 -0
- {dao_ai-0.0.25.dist-info → dao_ai-0.1.2.dist-info}/licenses/LICENSE +0 -0
dao_ai/tools/mcp.py
CHANGED
@@ -1,21 +1,127 @@
+"""
+MCP (Model Context Protocol) tool creation for LangChain agents.
+
+This module provides tools for connecting to MCP servers using the
+MCP SDK and langchain-mcp-adapters library.
+
+For compatibility with Databricks APIs, we use manual tool wrappers
+that give us full control over the response format.
+
+Reference: https://docs.langchain.com/oss/python/langchain/mcp
+"""
+
 import asyncio
 from typing import Any, Sequence

-from databricks_mcp import DatabricksOAuthClientProvider
 from langchain_core.runnables.base import RunnableLike
 from langchain_core.tools import tool as create_tool
 from langchain_mcp_adapters.client import MultiServerMCPClient
-from langchain_mcp_adapters.tools import load_mcp_tools
 from loguru import logger
-from mcp import
-from mcp.client.streamable_http import streamablehttp_client
-from mcp.types import ListToolsResult, Tool
+from mcp.types import CallToolResult, TextContent, Tool

 from dao_ai.config import (
     McpFunctionModel,
     TransportType,
+    value_of,
 )
-
+
+
+def _build_connection_config(
+    function: McpFunctionModel,
+) -> dict[str, Any]:
+    """
+    Build the connection configuration dictionary for MultiServerMCPClient.
+
+    Args:
+        function: The MCP function model configuration.
+
+    Returns:
+        A dictionary containing the transport-specific connection settings.
+    """
+    if function.transport == TransportType.STDIO:
+        return {
+            "command": function.command,
+            "args": function.args,
+            "transport": function.transport.value,
+        }
+
+    # For HTTP transport with UC Connection, use DatabricksOAuthClientProvider
+    if function.connection:
+        from databricks_mcp import DatabricksOAuthClientProvider
+
+        workspace_client = function.connection.workspace_client
+        auth_provider = DatabricksOAuthClientProvider(workspace_client)
+
+        logger.trace(
+            "Using DatabricksOAuthClientProvider for authentication",
+            connection_name=function.connection.name,
+        )
+
+        return {
+            "url": function.mcp_url,
+            "transport": "http",
+            "auth": auth_provider,
+        }
+
+    # For HTTP transport with headers-based authentication
+    headers: dict[str, str] = {
+        key: str(value_of(val)) for key, val in function.headers.items()
+    }
+
+    if "Authorization" not in headers:
+        logger.trace("Generating fresh authentication token")
+
+        from dao_ai.providers.databricks import DatabricksProvider
+
+        try:
+            provider = DatabricksProvider(
+                workspace_host=value_of(function.workspace_host),
+                client_id=value_of(function.client_id),
+                client_secret=value_of(function.client_secret),
+                pat=value_of(function.pat),
+            )
+            headers["Authorization"] = f"Bearer {provider.create_token()}"
+            logger.trace("Generated fresh authentication token")
+        except Exception as e:
+            logger.error("Failed to create fresh token", error=str(e))
+    else:
+        logger.trace("Using existing authentication token")
+
+    return {
+        "url": function.mcp_url,
+        "transport": "http",
+        "headers": headers,
+    }
+
+
+def _extract_text_content(result: CallToolResult) -> str:
+    """
+    Extract text content from an MCP CallToolResult.
+
+    Converts the MCP result content to a plain string format that is
+    compatible with all LLM APIs (avoiding extra fields like 'id').
+
+    Args:
+        result: The MCP tool call result.
+
+    Returns:
+        A string containing the concatenated text content.
+    """
+    if not result.content:
+        return ""
+
+    text_parts: list[str] = []
+    for item in result.content:
+        if isinstance(item, TextContent):
+            text_parts.append(item.text)
+        elif hasattr(item, "text"):
+            # Handle other content types that have text
+            text_parts.append(str(item.text))
+        else:
+            # Fallback: convert to string representation
+            text_parts.append(str(item))
+
+    return "\n".join(text_parts)


 def create_mcp_tools(
@@ -25,152 +131,133 @@ def create_mcp_tools(
     Create tools for invoking Databricks MCP functions.

     Supports both direct MCP connections and UC Connection-based MCP access.
-    Uses
+    Uses manual tool wrappers to ensure response format compatibility with
+    Databricks APIs (which reject extra fields in tool results).

     Based on: https://docs.databricks.com/aws/en/generative-ai/mcp/external-mcp
+
+    Args:
+        function: The MCP function model configuration.
+
+    Returns:
+        A sequence of LangChain tools that can be used by agents.
     """
-
+    mcp_url = function.mcp_url
+    logger.debug("Creating MCP tools", mcp_url=mcp_url)
+
+    connection_config = _build_connection_config(function)

-    # Check if using UC Connection or direct MCP connection
     if function.connection:
-
-
+        logger.debug(
+            "Using UC Connection for MCP",
+            connection_name=function.connection.name,
+            mcp_url=mcp_url,
+        )
+    else:
+        logger.debug(
+            "Using direct connection for MCP",
+            transport=function.transport,
+            mcp_url=mcp_url,
+        )

-
-
-        mcp_url = function.url
-        logger.debug(f"Using provided MCP URL: {mcp_url}")
-    else:
-        # Construct URL from workspace host and connection name
-        # Pattern: https://{workspace_host}/api/2.0/mcp/external/{connection_name}
-        workspace_client = function.connection.workspace_client
-        workspace_host = workspace_client.config.host
-        connection_name = function.connection.name
-        mcp_url = f"{workspace_host}/api/2.0/mcp/external/{connection_name}"
-        logger.debug(f"Constructed MCP URL from connection: {mcp_url}")
-
-    async def _get_tools_with_connection():
-        """Get tools using DatabricksOAuthClientProvider."""
-        workspace_client = function.connection.workspace_client
-
-        async with streamablehttp_client(
-            mcp_url, auth=DatabricksOAuthClientProvider(workspace_client)
-        ) as (read_stream, write_stream, _):
-            async with ClientSession(read_stream, write_stream) as session:
-                # Initialize and list tools
-                await session.initialize()
-                tools = await load_mcp_tools(session)
-                return tools
+    # Create client to list available tools
+    client = MultiServerMCPClient({"mcp_function": connection_config})

-
-
+    async def _list_tools() -> list[Tool]:
+        """List available MCP tools from the server."""
+        async with client.session("mcp_function") as session:
+            result = await session.list_tools()
+            return result.tools if hasattr(result, "tools") else list(result)
+
+    try:
+        mcp_tools: list[Tool] = asyncio.run(_list_tools())
+
+        # Log discovered tools
+        logger.info(
+            "Discovered MCP tools",
+            tools_count=len(mcp_tools),
+            mcp_url=mcp_url,
+        )
+        for mcp_tool in mcp_tools:
             logger.debug(
-
+                "MCP tool discovered",
+                tool_name=mcp_tool.name,
+                tool_description=(
+                    mcp_tool.description[:100] if mcp_tool.description else None
+                ),
             )

-
-
-
-
-
-
-
-        logger.error(f"Failed to get tools from MCP server via UC Connection: {e}")
+    except Exception as e:
+        if function.connection:
+            logger.error(
+                "Failed to get tools from MCP server via UC Connection",
+                connection_name=function.connection.name,
+                error=str(e),
+            )
             raise RuntimeError(
-                f"Failed to list MCP tools
+                f"Failed to list MCP tools via UC Connection "
+                f"'{function.connection.name}': {e}"
+            ) from e
+        else:
+            logger.error(
+                "Failed to get tools from MCP server",
+                transport=function.transport,
+                url=function.url,
+                error=str(e),
             )
+            raise RuntimeError(
+                f"Failed to list MCP tools with transport '{function.transport}' "
+                f"and URL '{function.url}': {e}"
+            ) from e

-
-
-
-
-    def _create_fresh_connection() -> dict[str, Any]:
-        """Create connection config with fresh authentication headers."""
-        logger.debug("Creating fresh connection...")
-
-        if function.transport == TransportType.STDIO:
-            return {
-                "command": function.command,
-                "args": function.args,
-                "transport": function.transport,
-            }
-
-        # For HTTP transport, generate fresh headers
-        headers = function.headers.copy() if function.headers else {}
-
-        if "Authorization" not in headers:
-            logger.debug("Generating fresh authentication token for MCP function")
-
-            from dao_ai.config import value_of
-            from dao_ai.providers.databricks import DatabricksProvider
-
-            try:
-                provider = DatabricksProvider(
-                    workspace_host=value_of(function.workspace_host),
-                    client_id=value_of(function.client_id),
-                    client_secret=value_of(function.client_secret),
-                    pat=value_of(function.pat),
-                )
-                headers["Authorization"] = f"Bearer {provider.create_token()}"
-                logger.debug("Generated fresh authentication token")
-            except Exception as e:
-                logger.error(f"Failed to create fresh token: {e}")
-        else:
-            logger.debug("Using existing authentication token")
-
-        return {
-            "url": function.url,
-            "transport": function.transport,
-            "headers": headers,
-        }
-
-    # Get available tools from MCP server
-    async def _list_mcp_tools():
-        connection = _create_fresh_connection()
-        client = MultiServerMCPClient({function.name: connection})
+    def _create_tool_wrapper(mcp_tool: Tool) -> RunnableLike:
+        """
+        Create a LangChain tool wrapper for an MCP tool.

-
-
-
-
-        logger.error(f"Failed to list MCP tools: {e}")
-        return []
+        This wrapper handles:
+        - Fresh session creation per invocation (stateless)
+        - Content extraction to plain text (avoiding extra fields)
+        """

-
-
-
-
-
-
+        @create_tool(
+            mcp_tool.name,
+            description=mcp_tool.description or f"MCP tool: {mcp_tool.name}",
+            args_schema=mcp_tool.inputSchema,
+        )
+        async def tool_wrapper(**kwargs: Any) -> str:
+            """Execute MCP tool with fresh session."""
+            logger.trace("Invoking MCP tool", tool_name=mcp_tool.name, args=kwargs)

-
-
-
-            raise RuntimeError(
-                f"Failed to list MCP tools for function '{function.name}' with transport '{function.transport}' and URL '{function.url}': {e}"
+            # Create a fresh client/session for each invocation
+            invocation_client = MultiServerMCPClient(
+                {"mcp_function": _build_connection_config(function)}
             )

-
-
-
-
-
-            args_schema=mcp_tool.inputSchema,
-        )
-        async def tool_wrapper(**kwargs):
-            """Execute MCP tool with fresh session and authentication."""
-            logger.debug(f"Invoking MCP tool {mcp_tool.name} with fresh session")
+            try:
+                async with invocation_client.session("mcp_function") as session:
+                    result: CallToolResult = await session.call_tool(
+                        mcp_tool.name, kwargs
+                    )

-
-
+                    # Extract text content, avoiding extra fields
+                    text_result = _extract_text_content(result)

-
-
-
-
-
-
+                    logger.trace(
+                        "MCP tool completed",
+                        tool_name=mcp_tool.name,
+                        result_length=len(text_result),
+                    )
+
+                    return text_result
+
+            except Exception as e:
+                logger.error(
+                    "MCP tool failed",
+                    tool_name=mcp_tool.name,
+                    error=str(e),
+                )
+                raise

-
+        return tool_wrapper

-
+    return [_create_tool_wrapper(tool) for tool in mcp_tools]
dao_ai/tools/memory.py
ADDED
@@ -0,0 +1,50 @@
+"""Memory tools for DAO AI."""
+
+from typing import Any
+
+from langchain_core.tools import BaseTool, StructuredTool
+from langmem import create_search_memory_tool as langmem_create_search_memory_tool
+from pydantic import BaseModel, Field
+
+
+def create_search_memory_tool(namespace: tuple[str, ...]) -> BaseTool:
+    """
+    Create a Databricks-compatible search_memory tool.
+
+    The langmem search_memory tool has a 'filter' field with additionalProperties: true
+    in its schema, which Databricks LLM endpoints reject. This function creates a
+    wrapper tool that omits the problematic filter field.
+
+    Args:
+        namespace: The memory namespace tuple
+
+    Returns:
+        A StructuredTool compatible with Databricks
+    """
+    # Get the original tool
+    original_tool = langmem_create_search_memory_tool(namespace=namespace)
+
+    # Create a schema without the problematic filter field
+    class SearchMemoryInput(BaseModel):
+        """Input for search_memory tool."""
+
+        query: str = Field(..., description="The search query")
+        limit: int = Field(default=10, description="Maximum number of results")
+        offset: int = Field(default=0, description="Offset for pagination")
+
+    # Create a wrapper function
+    async def search_memory_wrapper(
+        query: str, limit: int = 10, offset: int = 0
+    ) -> Any:
+        """Search your long-term memories for information relevant to your current context."""
+        return await original_tool.ainvoke(
+            {"query": query, "limit": limit, "offset": offset}
+        )
+
+    # Create the new tool
+    return StructuredTool.from_function(
+        coroutine=search_memory_wrapper,
+        name="search_memory",
+        description="Search your long-term memories for information relevant to your current context.",
+        args_schema=SearchMemoryInput,
+    )
dao_ai/tools/python.py
CHANGED
@@ -7,7 +7,6 @@ from dao_ai.config import (
     FactoryFunctionModel,
     PythonFunctionModel,
 )
-from dao_ai.tools.human_in_the_loop import as_human_in_the_loop
 from dao_ai.utils import load_function


@@ -22,14 +21,11 @@ def create_factory_tool(
     Returns:
         A callable tool function that wraps the specified factory function
     """
-    logger.
+    logger.trace("Creating factory tool", function=function.full_name)

     factory: Callable[..., Any] = load_function(function_name=function.full_name)
-    tool:
-
-        tool=tool,
-        function=function,
-    )
+    tool: RunnableLike = factory(**function.args)
+    # HITL is now handled at middleware level via HumanInTheLoopMiddleware
     return tool


@@ -45,16 +41,15 @@ def create_python_tool(
     Returns:
         A callable tool function that wraps the specified Python function
     """
-
+    function_name = (
+        function.full_name if isinstance(function, PythonFunctionModel) else function
+    )
+    logger.trace("Creating Python tool", function=function_name)

     if isinstance(function, PythonFunctionModel):
         function = function.full_name

     # Load the Python function dynamically
-    tool:
-
-    tool = as_human_in_the_loop(
-        tool=tool,
-        function=function,
-    )
+    tool: RunnableLike = load_function(function_name=function)
+    # HITL is now handled at middleware level via HumanInTheLoopMiddleware
     return tool
dao_ai/tools/search.py
ADDED
@@ -0,0 +1,14 @@
+from langchain_community.tools import DuckDuckGoSearchRun
+from langchain_core.runnables.base import RunnableLike
+from loguru import logger
+
+
+def create_search_tool() -> RunnableLike:
+    """
+    Create a DuckDuckGo search tool.
+
+    Returns:
+        RunnableLike: A DuckDuckGo search tool that returns results as a list
+    """
+    logger.trace("Creating DuckDuckGo search tool")
+    return DuckDuckGoSearchRun(output_format="list")
dao_ai/tools/slack.py
CHANGED
@@ -26,7 +26,7 @@ def _find_channel_id_by_name(
     # Remove '#' prefix if present
     clean_name = channel_name.lstrip("#")

-    logger.
+    logger.trace("Looking up Slack channel ID", channel_name=clean_name)

     try:
         # Call Slack API to list conversations
@@ -37,14 +37,18 @@ def _find_channel_id_by_name(
         )

         if response.status_code != 200:
-            logger.error(
+            logger.error(
+                "Failed to list Slack channels",
+                status_code=response.status_code,
+                response=response.text,
+            )
             return None

         # Parse response
         data = response.json()

         if not data.get("ok"):
-            logger.error(
+            logger.error("Slack API returned error", error=data.get("error"))
             return None

         # Search for channel by name
@@ -53,15 +57,19 @@ def _find_channel_id_by_name(
             if channel.get("name") == clean_name:
                 channel_id = channel.get("id")
                 logger.debug(
-
+                    "Found Slack channel ID",
+                    channel_id=channel_id,
+                    channel_name=clean_name,
                 )
                 return channel_id

-        logger.warning(
+        logger.warning("Slack channel not found", channel_name=clean_name)
         return None

     except Exception as e:
-        logger.error(
+        logger.error(
+            "Error looking up Slack channel", channel_name=clean_name, error=str(e)
+        )
         return None


@@ -71,7 +79,7 @@ def create_send_slack_message_tool(
     channel_name: Optional[str] = None,
     name: Optional[str] = None,
     description: Optional[str] = None,
-) -> Callable[[str],
+) -> Callable[[str], str]:
     """
     Create a tool that sends a message to a Slack channel.

@@ -87,7 +95,7 @@ def create_send_slack_message_tool(

     Based on: https://docs.databricks.com/aws/en/generative-ai/agent-framework/slack-agent
     """
-    logger.
+    logger.trace("Creating send Slack message tool")

     # Validate inputs
     if channel_id is None and channel_name is None:
@@ -99,12 +107,16 @@ def create_send_slack_message_tool(

     # Look up channel_id from channel_name if needed
     if channel_id is None and channel_name is not None:
-        logger.
+        logger.trace(
+            "Looking up channel ID for channel name", channel_name=channel_name
+        )
         channel_id = _find_channel_id_by_name(connection, channel_name)
         if channel_id is None:
             raise ValueError(f"Could not find Slack channel with name '{channel_name}'")
         logger.debug(
-
+            "Resolved channel name to ID",
+            channel_name=channel_name,
+            channel_id=channel_id,
         )

     if name is None: