universal-mcp-agents 0.1.3__py3-none-any.whl → 0.1.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- universal_mcp/agents/__init__.py +19 -0
- universal_mcp/agents/autoagent/__init__.py +1 -1
- universal_mcp/agents/autoagent/__main__.py +1 -1
- universal_mcp/agents/autoagent/graph.py +32 -13
- universal_mcp/agents/autoagent/studio.py +3 -8
- universal_mcp/agents/base.py +80 -22
- universal_mcp/agents/bigtool/__init__.py +13 -9
- universal_mcp/agents/bigtool/__main__.py +6 -7
- universal_mcp/agents/bigtool/graph.py +84 -40
- universal_mcp/agents/bigtool/prompts.py +3 -3
- universal_mcp/agents/bigtool2/__init__.py +16 -6
- universal_mcp/agents/bigtool2/__main__.py +7 -6
- universal_mcp/agents/bigtool2/agent.py +4 -2
- universal_mcp/agents/bigtool2/graph.py +78 -36
- universal_mcp/agents/bigtool2/prompts.py +1 -1
- universal_mcp/agents/bigtoolcache/__init__.py +8 -4
- universal_mcp/agents/bigtoolcache/__main__.py +1 -1
- universal_mcp/agents/bigtoolcache/agent.py +5 -3
- universal_mcp/agents/bigtoolcache/context.py +0 -1
- universal_mcp/agents/bigtoolcache/graph.py +99 -69
- universal_mcp/agents/bigtoolcache/prompts.py +28 -0
- universal_mcp/agents/bigtoolcache/tools_all.txt +956 -0
- universal_mcp/agents/bigtoolcache/tools_important.txt +474 -0
- universal_mcp/agents/builder.py +62 -20
- universal_mcp/agents/cli.py +19 -5
- universal_mcp/agents/codeact/__init__.py +16 -4
- universal_mcp/agents/codeact/test.py +2 -1
- universal_mcp/agents/hil.py +16 -4
- universal_mcp/agents/llm.py +12 -4
- universal_mcp/agents/planner/__init__.py +14 -4
- universal_mcp/agents/planner/__main__.py +10 -6
- universal_mcp/agents/planner/graph.py +9 -3
- universal_mcp/agents/planner/prompts.py +14 -1
- universal_mcp/agents/planner/state.py +0 -1
- universal_mcp/agents/react.py +36 -22
- universal_mcp/agents/shared/tool_node.py +26 -11
- universal_mcp/agents/simple.py +27 -4
- universal_mcp/agents/tools.py +9 -4
- universal_mcp/agents/ui_tools.py +305 -0
- universal_mcp/agents/utils.py +55 -17
- {universal_mcp_agents-0.1.3.dist-info → universal_mcp_agents-0.1.5.dist-info}/METADATA +3 -2
- universal_mcp_agents-0.1.5.dist-info/RECORD +52 -0
- universal_mcp/agents/bigtool/context.py +0 -24
- universal_mcp/agents/bigtool2/context.py +0 -33
- universal_mcp_agents-0.1.3.dist-info/RECORD +0 -51
- {universal_mcp_agents-0.1.3.dist-info → universal_mcp_agents-0.1.5.dist-info}/WHEEL +0 -0
universal_mcp/agents/__init__.py
CHANGED

@@ -7,6 +7,25 @@ from universal_mcp.agents.planner import PlannerAgent
 from universal_mcp.agents.react import ReactAgent
 from universal_mcp.agents.simple import SimpleAgent
 
+
+def get_agent(agent_name: str):
+    if agent_name == "auto":
+        return AutoAgent
+    elif agent_name == "react":
+        return ReactAgent
+    elif agent_name == "simple":
+        return SimpleAgent
+    elif agent_name == "builder":
+        return BuilderAgent
+    elif agent_name == "planner":
+        return PlannerAgent
+    elif agent_name == "bigtool":
+        return BigToolAgent
+    elif agent_name == "bigtool2":
+        return BigToolAgent2
+    else:
+        raise ValueError(f"Unknown agent: {agent_name}. Possible values: auto, react, simple, builder, planner, bigtool, bigtool2")
+
 __all__ = [
     "BaseAgent",
     "ReactAgent",
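For orientation, a minimal sketch of how the new get_agent factory might be consumed. The constructor arguments below mirror the bigtool __main__ example later in this diff and are illustrative assumptions, not part of the factory's contract:

    from universal_mcp.agentr.registry import AgentrRegistry
    from universal_mcp.agents import get_agent

    # Resolve an agent class by name; unknown names raise ValueError.
    AgentClass = get_agent("bigtool")
    agent = AgentClass(
        name="demo-agent",            # illustrative values
        instructions="Be concise.",
        model="azure/gpt-4.1",
        registry=AgentrRegistry(),    # assumed: bigtool-style agents accept a registry
    )
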
universal_mcp/agents/autoagent/__init__.py
CHANGED

@@ -1,8 +1,8 @@
 from langgraph.checkpoint.base import BaseCheckpointSaver
+from universal_mcp.tools.registry import ToolRegistry
 
 from universal_mcp.agents.autoagent.graph import build_graph
 from universal_mcp.agents.base import BaseAgent
-from universal_mcp.tools.registry import ToolRegistry
 
 
 class AutoAgent(BaseAgent):
universal_mcp/agents/autoagent/graph.py
CHANGED

@@ -6,13 +6,13 @@ from langchain_core.messages import AIMessage, ToolMessage
 from langchain_core.tools import tool
 from langgraph.graph import END, START, StateGraph
 from langgraph.runtime import Runtime
+from universal_mcp.tools.registry import ToolRegistry
+from universal_mcp.types import ToolFormat
 
 from universal_mcp.agents.autoagent.context import Context
 from universal_mcp.agents.autoagent.prompts import SYSTEM_PROMPT
 from universal_mcp.agents.autoagent.state import State
 from universal_mcp.agents.llm import load_chat_model
-from universal_mcp.tools.registry import ToolRegistry
-from universal_mcp.types import ToolFormat
 
 
 async def build_graph(tool_registry: ToolRegistry, instructions: str = ""):
@@ -22,7 +22,9 @@ async def build_graph(tool_registry: ToolRegistry, instructions: str = ""):
         tools_list = []
         if app_ids is not None:
             for app_id in app_ids:
-                tools_list.extend(await tool_registry.search_tools(query, limit=10, app_id=app_id))
+                tools_list.extend(
+                    await tool_registry.search_tools(query, limit=10, app_id=app_id)
+                )
         else:
             tools_list = await tool_registry.search_tools(query, limit=10)
         tools_list = [f"{tool['id']}: {tool['description']}" for tool in tools_list]
@@ -48,21 +50,33 @@ async def build_graph(tool_registry: ToolRegistry, instructions: str = ""):
         connections = await tool_registry.list_connected_apps()
         connection_ids = set([connection["app_id"] for connection in connections])
         connected_apps = [app["id"] for app in app_ids if app["id"] in connection_ids]
-        unconnected_apps = [app["id"] for app in app_ids if app["id"] not in connection_ids]
-        app_id_descriptions = (
-            "These are the apps connected to the user's account:\n" + "\n".join([f"{app}" for app in connected_apps])
+        unconnected_apps = [
+            app["id"] for app in app_ids if app["id"] not in connection_ids
+        ]
+        app_id_descriptions = (
+            "These are the apps connected to the user's account:\n"
+            + "\n".join([f"{app}" for app in connected_apps])
         )
         if unconnected_apps:
             app_id_descriptions += "\n\nOther (not connected) apps: " + "\n".join(
                 [f"{app}" for app in unconnected_apps]
             )
 
-        system_prompt = system_prompt.format(system_time=datetime.now(tz=UTC).isoformat(), app_ids=app_id_descriptions)
+        system_prompt = system_prompt.format(
+            system_time=datetime.now(tz=UTC).isoformat(), app_ids=app_id_descriptions
+        )
 
-        messages = [{"role": "system", "content": system_prompt + "\n" + instructions}, *state["messages"]]
+        messages = [
+            {"role": "system", "content": system_prompt + "\n" + instructions},
+            *state["messages"],
+        ]
         model = load_chat_model(runtime.context.model)
-        loaded_tools = await tool_registry.export_tools(tools=state["selected_tool_ids"], format=ToolFormat.LANGCHAIN)
-        model_with_tools = model.bind_tools([search_tools, ask_user, load_tools, *loaded_tools], tool_choice="auto")
+        loaded_tools = await tool_registry.export_tools(
+            tools=state["selected_tool_ids"], format=ToolFormat.LANGCHAIN
+        )
+        model_with_tools = model.bind_tools(
+            [search_tools, ask_user, load_tools, *loaded_tools], tool_choice="auto"
+        )
         response_raw = model_with_tools.invoke(messages)
         response = cast(AIMessage, response_raw)
         return {"messages": [response]}
@@ -102,7 +116,8 @@ async def build_graph(tool_registry: ToolRegistry, instructions: str = ""):
             tools = await search_tools.ainvoke(tool_call["args"])
             outputs.append(
                 ToolMessage(
-                    content=json.dumps(tools) + "\n\nUse the load_tools tool to load the tools you want to use.",
+                    content=json.dumps(tools)
+                    + "\n\nUse the load_tools tool to load the tools you want to use.",
                     name=tool_call["name"],
                     tool_call_id=tool_call["id"],
                 )
@@ -119,9 +134,13 @@ async def build_graph(tool_registry: ToolRegistry, instructions: str = ""):
                 )
             )
         else:
-            await tool_registry.export_tools([tool_call["name"]], ToolFormat.LANGCHAIN)
+            await tool_registry.export_tools(
+                [tool_call["name"]], ToolFormat.LANGCHAIN
+            )
             try:
-                tool_result = await tool_registry.call_tool(tool_call["name"], tool_call["args"])
+                tool_result = await tool_registry.call_tool(
+                    tool_call["name"], tool_call["args"]
+                )
                 outputs.append(
                     ToolMessage(
                         content=json.dumps(tool_result),
universal_mcp/agents/autoagent/studio.py
CHANGED

@@ -1,14 +1,14 @@
 import asyncio
 
 from universal_mcp.agentr.registry import AgentrRegistry
-from universal_mcp.agents.autoagent import build_graph
 from universal_mcp.tools import ToolManager
 
+from universal_mcp.agents.autoagent import build_graph
+
 tool_registry = AgentrRegistry()
 tool_manager = ToolManager()
 
 
-
 async def main():
     instructions = """
     You are a helpful assistant that can use tools to help the user. If a task requires multiple steps, you should perform separate different searches for different actions. Prefer completing one action before searching for another.

@@ -16,10 +16,5 @@ async def main():
     graph = await build_graph(tool_registry, instructions=instructions)
     return graph
 
-graph = asyncio.run(main())
-
-
-
-
-
 
+graph = asyncio.run(main())
universal_mcp/agents/base.py
CHANGED

@@ -2,15 +2,24 @@
 from typing import cast
 from uuid import uuid4
 
-from langchain_core.messages import AIMessageChunk
+from langchain_core.messages import AIMessage, AIMessageChunk
 from langgraph.checkpoint.base import BaseCheckpointSaver
+from langgraph.graph import StateGraph
 from langgraph.types import Command
+from universal_mcp.logger import logger
 
 from .utils import RichCLI
 
 
 class BaseAgent:
-    def __init__(self, name: str, instructions: str, model: str, memory: BaseCheckpointSaver | None = None, **kwargs):
+    def __init__(
+        self,
+        name: str,
+        instructions: str,
+        model: str,
+        memory: BaseCheckpointSaver | None = None,
+        **kwargs,
+    ):
         self.name = name
         self.instructions = instructions
         self.model = model

@@ -24,56 +33,96 @@ class BaseAgent:
         self._graph = await self._build_graph()
         self._initialized = True
 
-    async def _build_graph(self):
+    async def _build_graph(self) -> StateGraph:
         raise NotImplementedError("Subclasses must implement this method")
 
-    async def stream(self, thread_id: str, user_input: str):
+    async def stream(self, thread_id: str, user_input: str, metadata: dict = None):
         await self.ainit()
         aggregate = None
+
+        run_metadata = {
+            "agent_name": self.name,
+            "is_background_run": False,  # Default to False
+        }
+
+        if metadata:
+            run_metadata.update(metadata)
+
+        run_config = {
+            "configurable": {"thread_id": thread_id},
+            "metadata": run_metadata,
+        }
+
         async for event, metadata in self._graph.astream(
             {"messages": [{"role": "user", "content": user_input}]},
-            config={"configurable": {"thread_id": thread_id}},
+            config=run_config,
             context={"system_prompt": self.instructions, "model": self.model},
             stream_mode="messages",
             stream_usage=True,
         ):
             # Only forward assistant token chunks that are not tool-related.
             type_ = type(event)
-            if type_ != AIMessageChunk:
-                continue
-            event = cast(AIMessageChunk, event)
-            aggregate = event if aggregate is None else aggregate + event
             tags = metadata.get("tags", []) if isinstance(metadata, dict) else []
             is_quiet = isinstance(tags, list) and ("quiet" in tags)
-
             if is_quiet:
                 continue
+            # Handle different types of messages
+            if type_ in (AIMessage, AIMessageChunk):
+                # Accumulate billing and aggregate message
+                aggregate = event if aggregate is None else aggregate + event
+            # Ignore intermeddite finish messages
            if "finish_reason" in event.response_metadata:
                 # Got LLM finish reason ignore it
-
+                logger.debug(f"Finish event: {event}, Metadata: {metadata}")
                 pass
             else:
-
+                logger.debug(f"Event: {event}, Metadata: {metadata}")
                 yield event
         # Send a final finished message
         # The last event would be finish
         event = cast(AIMessageChunk, event)
+        event.usage_metadata = aggregate.usage_metadata
+        logger.debug(f"Usage metadata: {event.usage_metadata}")
         yield event
 
     async def stream_interactive(self, thread_id: str, user_input: str):
         await self.ainit()
         with self.cli.display_agent_response_streaming(self.name) as stream_updater:
             async for event in self.stream(thread_id, user_input):
-                stream_updater.update(event.content)
 
-    async def invoke(self, user_input: str, thread_id: str = str(uuid4())):
+                if isinstance(event.content, list):
+                    thinking_content = "".join([c.get("thinking", "") for c in event.content])
+                    stream_updater.update(thinking_content, type_="thinking")
+                    content = "".join([c.get("text", "") for c in event.content])
+                    stream_updater.update(content, type_="text")
+                else:
+                    stream_updater.update(event.content, type_="text")
+
+    async def invoke(
+        self, user_input: str, thread_id: str = str(uuid4()), metadata: dict = None
+    ):
         """Run the agent"""
         await self.ainit()
-        return await self._graph.ainvoke(
+
+        run_metadata = {
+            "agent_name": self.name,
+            "is_background_run": False,  # Default to False
+        }
+
+        if metadata:
+            run_metadata.update(metadata)
+
+        run_config = {
+            "configurable": {"thread_id": thread_id},
+            "metadata": run_metadata,
+        }
+
+        result = await self._graph.ainvoke(
             {"messages": [{"role": "user", "content": user_input}]},
-            config={"configurable": {"thread_id": thread_id}},
+            config=run_config,
             context={"system_prompt": self.instructions, "model": self.model},
         )
+        return result
 
     async def run_interactive(self, thread_id: str = str(uuid4())):
         """Main application loop"""

@@ -85,10 +134,15 @@ class BaseAgent:
         # Main loop
         while True:
             try:
-                state = self._graph.get_state(config={"configurable": {"thread_id": thread_id}})
+                state = self._graph.get_state(
+                    config={"configurable": {"thread_id": thread_id}}
+                )
                 if state.interrupts:
                     value = self.cli.handle_interrupt(state.interrupts[0])
-                    self._graph.invoke(Command(resume=value), config={"configurable": {"thread_id": thread_id}})
+                    self._graph.invoke(
+                        Command(resume=value),
+                        config={"configurable": {"thread_id": thread_id}},
+                    )
                     continue
 
                 user_input = self.cli.get_user_input()

@@ -99,9 +153,11 @@ class BaseAgent:
                 if user_input.startswith("/"):
                     command = user_input.lower().lstrip("/")
                     if command == "about":
-                        self.cli.display_info(f"Agent is {self.name}. {self.instructions}")
+                        self.cli.display_info(
+                            f"Agent is {self.name}. {self.instructions}"
+                        )
                         continue
-                    elif command
+                    elif command in {"exit", "quit", "q"}:
                         self.cli.display_info("Goodbye! 👋")
                         break
                     elif command == "reset":

@@ -110,7 +166,9 @@ class BaseAgent:
                         thread_id = str(uuid4())
                         continue
                     elif command == "help":
-                        self.cli.display_info("Available commands: /about, /exit, /quit, /q, /reset")
+                        self.cli.display_info(
+                            "Available commands: /about, /exit, /quit, /q, /reset"
+                        )
                         continue
                     else:
                         self.cli.display_error(f"Unknown command: {command}")

@@ -124,6 +182,6 @@ class BaseAgent:
                     break
             except Exception as e:
                 import traceback
-
                 traceback.print_exc()
                 self.cli.display_error(f"An error occurred: {str(e)}")
+                break
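A usage sketch of the new metadata plumbing in BaseAgent.invoke; the thread id and metadata values here are illustrative. Extra keys are merged into run_metadata (on top of agent_name and is_background_run) and passed to LangGraph through the run config:

    import asyncio
    from uuid import uuid4

    async def run(agent):  # agent: any constructed BaseAgent subclass
        result = await agent.invoke(
            user_input="Summarize my inbox",
            thread_id=str(uuid4()),
            metadata={"is_background_run": True},  # overrides the False default
        )
        for message in result["messages"]:
            print(message.content)

    # asyncio.run(run(agent))
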
universal_mcp/agents/bigtool/__init__.py
CHANGED

@@ -1,9 +1,9 @@
 from langgraph.checkpoint.base import BaseCheckpointSaver
+from universal_mcp.logger import logger
+from universal_mcp.tools.registry import ToolRegistry
 
 from universal_mcp.agents.base import BaseAgent
 from universal_mcp.agents.llm import load_chat_model
-from universal_mcp.logger import logger
-from universal_mcp.tools.registry import ToolRegistry
 
 from .graph import build_graph
 from .prompts import SYSTEM_PROMPT

@@ -19,15 +19,19 @@ class BigToolAgent(BaseAgent):
         memory: BaseCheckpointSaver | None = None,
         **kwargs,
     ):
-
-        full_instructions = f"{SYSTEM_PROMPT}\n\n**User Instructions:**\n{instructions}"
-        super().__init__(name, full_instructions, model, memory, **kwargs)
-
+        super().__init__(name, instructions, model, memory, **kwargs)
         self.registry = registry
         self.llm = load_chat_model(self.model)
-        self.tool_selection_llm = load_chat_model("gemini/gemini-2.0-flash-001")
 
-        logger.info(f"BigToolAgent '{self.name}' initialized with model '{self.model}'.")
+        logger.info(
+            f"BigToolAgent '{self.name}' initialized with model '{self.model}'."
+        )
+
+    def _build_system_message(self):
+        return SYSTEM_PROMPT.format(
+            name=self.name,
+            instructions=self.instructions,
+        )
 
     async def _build_graph(self):
         """Build the bigtool agent graph using the existing create_agent function."""

@@ -36,7 +40,7 @@ class BigToolAgent(BaseAgent):
         graph_builder = build_graph(
             tool_registry=self.registry,
             llm=self.llm,
-            tool_selection_llm=self.tool_selection_llm,
+            system_prompt=self._build_system_message(),
         )
 
         compiled_graph = graph_builder.compile(checkpointer=self.memory)
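The new _build_system_message implies the bigtool SYSTEM_PROMPT template now carries {name} and {instructions} placeholders rather than being prepended to the user instructions. A hypothetical stand-in to illustrate the contract (the real template lives in bigtool/prompts.py):

    # Hypothetical template, for illustration only.
    SYSTEM_PROMPT = (
        "You are {name}, an agent that discovers and loads tools on demand.\n\n"
        "**User Instructions:**\n{instructions}"
    )

    system_message = SYSTEM_PROMPT.format(
        name="bigtool",
        instructions="Prefer Gmail for anything email-related.",
    )
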
universal_mcp/agents/bigtool/__main__.py
CHANGED

@@ -1,9 +1,10 @@
 import asyncio
 
 from loguru import logger
-
 from universal_mcp.agentr.registry import AgentrRegistry
+
 from universal_mcp.agents.bigtool import BigToolAgent
+from universal_mcp.agents.utils import messages_to_list
 
 
 async def main():

@@ -13,12 +14,10 @@ async def main():
         model="azure/gpt-4.1",
         registry=AgentrRegistry(),
     )
-
-
-
-    )
-    logger.info(event.content)
-
+    await agent.ainit()
+    output = await agent.invoke(
+        user_input="Send an email to manoj@agentr.dev")
+    logger.info(messages_to_list(output["messages"]))
 
 if __name__ == "__main__":
     asyncio.run(main())
universal_mcp/agents/bigtool/graph.py
CHANGED

@@ -1,28 +1,24 @@
 import json
-from datetime import UTC, datetime
 from typing import Literal, TypedDict, cast
 
-from langchain_anthropic import ChatAnthropic
 from langchain_core.language_models import BaseChatModel
 from langchain_core.messages import AIMessage, ToolMessage
 from langchain_core.tools import tool
 from langgraph.graph import StateGraph
-from langgraph.runtime import Runtime
 from langgraph.types import Command
-
-from universal_mcp.agents.bigtool.context import Context
-from universal_mcp.agents.bigtool.state import State
 from universal_mcp.logger import logger
 from universal_mcp.tools.registry import ToolRegistry
 from universal_mcp.types import ToolFormat
 
+from universal_mcp.agents.bigtool.state import State
+
 from .prompts import SELECT_TOOL_PROMPT
 
 
 def build_graph(
     tool_registry: ToolRegistry,
     llm: BaseChatModel,
-    tool_selection_llm: BaseChatModel,
+    system_prompt: str,
 ):
     @tool
     async def retrieve_tools(task_query: str) -> list[str]:
@@ -32,29 +28,40 @@ def build_graph(
         logger.info(f"Retrieving tools for task: '{task_query}'")
         try:
             tools_list = await tool_registry.search_tools(task_query, limit=10)
-            tool_candidates = [f"{tool['id']}: {tool['description']}" for tool in tools_list]
+            tool_candidates = [
+                f"{tool['id']}: {tool['description']}" for tool in tools_list
+            ]
             logger.info(f"Found {len(tool_candidates)} candidate tools.")
 
             class ToolSelectionOutput(TypedDict):
                 tool_names: list[str]
 
-            model = tool_selection_llm
+            model = llm
             app_ids = await tool_registry.list_all_apps()
             connections = await tool_registry.list_connected_apps()
             connection_ids = set([connection["app_id"] for connection in connections])
-            connected_apps = [app["id"] for app in app_ids if app["id"] in connection_ids]
-            unconnected_apps = [app["id"] for app in app_ids if app["id"] not in connection_ids]
-            app_id_descriptions = (
-                "These are the apps connected to the user's account:\n" + "\n".join([f"{app}" for app in connected_apps])
+            connected_apps = [
+                app["id"] for app in app_ids if app["id"] in connection_ids
+            ]
+            unconnected_apps = [
+                app["id"] for app in app_ids if app["id"] not in connection_ids
+            ]
+            app_id_descriptions = (
+                "These are the apps connected to the user's account:\n"
+                + "\n".join([f"{app}" for app in connected_apps])
             )
             if unconnected_apps:
                 app_id_descriptions += "\n\nOther (not connected) apps: " + "\n".join(
                     [f"{app}" for app in unconnected_apps]
                 )
 
-            response = await model.with_structured_output(schema=ToolSelectionOutput, method="json_mode").ainvoke(
+            response = await model.with_structured_output(
+                schema=ToolSelectionOutput, method="json_mode"
+            ).ainvoke(
                 SELECT_TOOL_PROMPT.format(
-                    app_ids=app_id_descriptions, tool_candidates="\n - ".join(tool_candidates), task=task_query,
+                    app_ids=app_id_descriptions,
+                    tool_candidates="\n - ".join(tool_candidates),
+                    task=task_query,
                 )
             )
 
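The tool-selection step above leans on LangChain's with_structured_output to coerce the model's reply into ToolSelectionOutput. A self-contained sketch of the same pattern, with a placeholder model id and hard-coded candidates (json_mode requires the prompt to mention JSON):

    from typing import TypedDict

    from langchain.chat_models import init_chat_model

    class ToolSelectionOutput(TypedDict):
        tool_names: list[str]

    llm = init_chat_model("openai:gpt-4o-mini")  # placeholder model id
    structured = llm.with_structured_output(schema=ToolSelectionOutput, method="json_mode")
    result = structured.invoke(
        "Select the best tools for the task 'send an email' from:\n"
        " - google_mail__send_email: Send an email\n"
        " - slack__post_message: Post a Slack message\n"
        'Answer as JSON: {"tool_names": [...]}'
    )
    print(result["tool_names"])
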
@@ -65,45 +72,63 @@ def build_graph(
             logger.error(f"Error retrieving tools: {e}")
             return []
 
-    async def call_model(state: State) -> Command[Literal["select_tools", "call_tools"]]:
+
+    async def call_model(
+        state: State
+    ) -> Command[Literal["select_tools", "call_tools"]]:
         logger.info("Calling model...")
         try:
-
-
+            messages = [
+                {"role": "system", "content": system_prompt},
+                *state["messages"],
+            ]
 
             logger.info(f"Selected tool IDs: {state['selected_tool_ids']}")
             if len(state["selected_tool_ids"]) > 0:
-                selected_tools = await tool_registry.export_tools(tools=state["selected_tool_ids"], format=ToolFormat.LANGCHAIN)
+                selected_tools = await tool_registry.export_tools(
+                    tools=state["selected_tool_ids"], format=ToolFormat.LANGCHAIN
+                )
                 logger.info(f"Exported {len(selected_tools)} tools for model.")
             else:
                 selected_tools = []
 
-
-
-
-
-
-
-
-
+            model_with_tools = llm.bind_tools(
+                [retrieve_tools, *selected_tools], tool_choice="auto"
+            )
+
+
+            response = await model_with_tools.ainvoke(messages)
+            cast(AIMessage, response)
+            logger.debug(f"Response: {response}")
+
 
             if response.tool_calls:
-                logger.info(f"Model responded with {len(response.tool_calls)} tool calls.")
+                logger.info(
+                    f"Model responded with {len(response.tool_calls)} tool calls."
+                )
                 if len(response.tool_calls) > 1:
-                    raise Exception("Not possible in Claude with llm.bind_tools(tools=tools, tool_choice='auto')")
+                    raise Exception(
+                        "Not possible in Claude with llm.bind_tools(tools=tools, tool_choice='auto')"
+                    )
                 tool_call = response.tool_calls[0]
                 if tool_call["name"] == retrieve_tools.name:
                     logger.info("Model requested to select tools.")
                     return Command(goto="select_tools", update={"messages": [response]})
                 elif tool_call["name"] not in state["selected_tool_ids"]:
                     try:
-                        await tool_registry.export_tools([tool_call["name"]], ToolFormat.LANGCHAIN)
+                        await tool_registry.export_tools(
+                            [tool_call["name"]], ToolFormat.LANGCHAIN
+                        )
                         logger.info(
                             f"Tool '{tool_call['name']}' not in selected tools, but available. Proceeding to call."
                         )
-                        return Command(goto="call_tools", update={"messages": [response]})
+                        return Command(
+                            goto="call_tools", update={"messages": [response]}
+                        )
                     except Exception as e:
-                        logger.error(f"Unexpected tool call: {tool_call['name']}. Error: {e}")
+                        logger.error(
+                            f"Unexpected tool call: {tool_call['name']}. Error: {e}"
+                        )
                         raise Exception(
                             f"Unexpected tool call: {tool_call['name']}. Available tools: {state['selected_tool_ids']}"
                         ) from e
@@ -116,14 +141,24 @@ def build_graph(
             logger.error(f"Error in call_model: {e}")
             raise
 
-    async def select_tools(state: State) -> Command[Literal["call_model"]]:
+    async def select_tools(
+        state: State
+    ) -> Command[Literal["call_model"]]:
         logger.info("Selecting tools...")
         try:
             tool_call = state["messages"][-1].tool_calls[0]
             selected_tool_names = await retrieve_tools.ainvoke(input=tool_call["args"])
-            tool_msg = ToolMessage(f"Available tools: {selected_tool_names}", tool_call_id=tool_call["id"])
+            tool_msg = ToolMessage(
+                f"Available tools: {selected_tool_names}", tool_call_id=tool_call["id"]
+            )
             logger.info(f"Tools selected: {selected_tool_names}")
-            return Command(goto="call_model", update={"messages": [tool_msg], "selected_tool_ids": selected_tool_names})
+            return Command(
+                goto="call_model",
+                update={
+                    "messages": [tool_msg],
+                    "selected_tool_ids": selected_tool_names,
+                },
+            )
         except Exception as e:
             logger.error(f"Error in select_tools: {e}")
             raise
@@ -133,10 +168,16 @@ def build_graph(
         outputs = []
         recent_tool_ids = []
         for tool_call in state["messages"][-1].tool_calls:
-            logger.info(f"Executing tool: {tool_call['name']} with args: {tool_call['args']}")
+            logger.info(
+                f"Executing tool: {tool_call['name']} with args: {tool_call['args']}"
+            )
             try:
-                await tool_registry.export_tools([tool_call["name"]], ToolFormat.LANGCHAIN)
-                tool_result = await tool_registry.call_tool(tool_call["name"], tool_call["args"])
+                await tool_registry.export_tools(
+                    [tool_call["name"]], ToolFormat.LANGCHAIN
+                )
+                tool_result = await tool_registry.call_tool(
+                    tool_call["name"], tool_call["args"]
+                )
                 logger.info(f"Tool '{tool_call['name']}' executed successfully.")
                 outputs.append(
                     ToolMessage(

@@ -155,9 +196,12 @@ def build_graph(
                         tool_call_id=tool_call["id"],
                     )
                 )
-        return Command(goto="call_model", update={"messages": outputs, "selected_tool_ids": recent_tool_ids})
+        return Command(
+            goto="call_model",
+            update={"messages": outputs, "selected_tool_ids": recent_tool_ids},
+        )
 
-    builder = StateGraph(State, context_schema=Context)
+    builder = StateGraph(State)
 
     builder.add_node(call_model)
     builder.add_node(select_tools)