universal-mcp-agents 0.1.3__py3-none-any.whl → 0.1.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. universal_mcp/agents/__init__.py +19 -0
  2. universal_mcp/agents/autoagent/__init__.py +1 -1
  3. universal_mcp/agents/autoagent/__main__.py +1 -1
  4. universal_mcp/agents/autoagent/graph.py +32 -13
  5. universal_mcp/agents/autoagent/studio.py +3 -8
  6. universal_mcp/agents/base.py +80 -22
  7. universal_mcp/agents/bigtool/__init__.py +13 -9
  8. universal_mcp/agents/bigtool/__main__.py +6 -7
  9. universal_mcp/agents/bigtool/graph.py +84 -40
  10. universal_mcp/agents/bigtool/prompts.py +3 -3
  11. universal_mcp/agents/bigtool2/__init__.py +16 -6
  12. universal_mcp/agents/bigtool2/__main__.py +7 -6
  13. universal_mcp/agents/bigtool2/agent.py +4 -2
  14. universal_mcp/agents/bigtool2/graph.py +78 -36
  15. universal_mcp/agents/bigtool2/prompts.py +1 -1
  16. universal_mcp/agents/bigtoolcache/__init__.py +8 -4
  17. universal_mcp/agents/bigtoolcache/__main__.py +1 -1
  18. universal_mcp/agents/bigtoolcache/agent.py +5 -3
  19. universal_mcp/agents/bigtoolcache/context.py +0 -1
  20. universal_mcp/agents/bigtoolcache/graph.py +99 -69
  21. universal_mcp/agents/bigtoolcache/prompts.py +28 -0
  22. universal_mcp/agents/bigtoolcache/tools_all.txt +956 -0
  23. universal_mcp/agents/bigtoolcache/tools_important.txt +474 -0
  24. universal_mcp/agents/builder.py +62 -20
  25. universal_mcp/agents/cli.py +19 -5
  26. universal_mcp/agents/codeact/__init__.py +16 -4
  27. universal_mcp/agents/codeact/test.py +2 -1
  28. universal_mcp/agents/hil.py +16 -4
  29. universal_mcp/agents/llm.py +12 -4
  30. universal_mcp/agents/planner/__init__.py +14 -4
  31. universal_mcp/agents/planner/__main__.py +10 -6
  32. universal_mcp/agents/planner/graph.py +9 -3
  33. universal_mcp/agents/planner/prompts.py +14 -1
  34. universal_mcp/agents/planner/state.py +0 -1
  35. universal_mcp/agents/react.py +36 -22
  36. universal_mcp/agents/shared/tool_node.py +26 -11
  37. universal_mcp/agents/simple.py +27 -4
  38. universal_mcp/agents/tools.py +9 -4
  39. universal_mcp/agents/ui_tools.py +305 -0
  40. universal_mcp/agents/utils.py +55 -17
  41. {universal_mcp_agents-0.1.3.dist-info → universal_mcp_agents-0.1.5.dist-info}/METADATA +3 -2
  42. universal_mcp_agents-0.1.5.dist-info/RECORD +52 -0
  43. universal_mcp/agents/bigtool/context.py +0 -24
  44. universal_mcp/agents/bigtool2/context.py +0 -33
  45. universal_mcp_agents-0.1.3.dist-info/RECORD +0 -51
  46. {universal_mcp_agents-0.1.3.dist-info → universal_mcp_agents-0.1.5.dist-info}/WHEEL +0 -0
@@ -7,6 +7,25 @@ from universal_mcp.agents.planner import PlannerAgent
7
7
  from universal_mcp.agents.react import ReactAgent
8
8
  from universal_mcp.agents.simple import SimpleAgent
9
9
 
10
+
11
+ def get_agent(agent_name: str):
12
+ if agent_name == "auto":
13
+ return AutoAgent
14
+ elif agent_name == "react":
15
+ return ReactAgent
16
+ elif agent_name == "simple":
17
+ return SimpleAgent
18
+ elif agent_name == "builder":
19
+ return BuilderAgent
20
+ elif agent_name == "planner":
21
+ return PlannerAgent
22
+ elif agent_name == "bigtool":
23
+ return BigToolAgent
24
+ elif agent_name == "bigtool2":
25
+ return BigToolAgent2
26
+ else:
27
+ raise ValueError(f"Unknown agent: {agent_name}. Possible values: auto, react, simple, builder, planner, bigtool, bigtool2")
28
+
10
29
  __all__ = [
11
30
  "BaseAgent",
12
31
  "ReactAgent",
@@ -1,8 +1,8 @@
1
1
  from langgraph.checkpoint.base import BaseCheckpointSaver
2
+ from universal_mcp.tools.registry import ToolRegistry
2
3
 
3
4
  from universal_mcp.agents.autoagent.graph import build_graph
4
5
  from universal_mcp.agents.base import BaseAgent
5
- from universal_mcp.tools.registry import ToolRegistry
6
6
 
7
7
 
8
8
  class AutoAgent(BaseAgent):
@@ -1,8 +1,8 @@
1
1
  import asyncio
2
2
 
3
3
  from loguru import logger
4
-
5
4
  from universal_mcp.agentr.registry import AgentrRegistry
5
+
6
6
  from universal_mcp.agents.autoagent import AutoAgent
7
7
 
8
8
 
@@ -6,13 +6,13 @@ from langchain_core.messages import AIMessage, ToolMessage
6
6
  from langchain_core.tools import tool
7
7
  from langgraph.graph import END, START, StateGraph
8
8
  from langgraph.runtime import Runtime
9
+ from universal_mcp.tools.registry import ToolRegistry
10
+ from universal_mcp.types import ToolFormat
9
11
 
10
12
  from universal_mcp.agents.autoagent.context import Context
11
13
  from universal_mcp.agents.autoagent.prompts import SYSTEM_PROMPT
12
14
  from universal_mcp.agents.autoagent.state import State
13
15
  from universal_mcp.agents.llm import load_chat_model
14
- from universal_mcp.tools.registry import ToolRegistry
15
- from universal_mcp.types import ToolFormat
16
16
 
17
17
 
18
18
  async def build_graph(tool_registry: ToolRegistry, instructions: str = ""):
@@ -22,7 +22,9 @@ async def build_graph(tool_registry: ToolRegistry, instructions: str = ""):
22
22
  tools_list = []
23
23
  if app_ids is not None:
24
24
  for app_id in app_ids:
25
- tools_list.extend(await tool_registry.search_tools(query, limit=10, app_id=app_id))
25
+ tools_list.extend(
26
+ await tool_registry.search_tools(query, limit=10, app_id=app_id)
27
+ )
26
28
  else:
27
29
  tools_list = await tool_registry.search_tools(query, limit=10)
28
30
  tools_list = [f"{tool['id']}: {tool['description']}" for tool in tools_list]
@@ -48,21 +50,33 @@ async def build_graph(tool_registry: ToolRegistry, instructions: str = ""):
48
50
  connections = await tool_registry.list_connected_apps()
49
51
  connection_ids = set([connection["app_id"] for connection in connections])
50
52
  connected_apps = [app["id"] for app in app_ids if app["id"] in connection_ids]
51
- unconnected_apps = [app["id"] for app in app_ids if app["id"] not in connection_ids]
52
- app_id_descriptions = "These are the apps connected to the user's account:\n" + "\n".join(
53
- [f"{app}" for app in connected_apps]
53
+ unconnected_apps = [
54
+ app["id"] for app in app_ids if app["id"] not in connection_ids
55
+ ]
56
+ app_id_descriptions = (
57
+ "These are the apps connected to the user's account:\n"
58
+ + "\n".join([f"{app}" for app in connected_apps])
54
59
  )
55
60
  if unconnected_apps:
56
61
  app_id_descriptions += "\n\nOther (not connected) apps: " + "\n".join(
57
62
  [f"{app}" for app in unconnected_apps]
58
63
  )
59
64
 
60
- system_prompt = system_prompt.format(system_time=datetime.now(tz=UTC).isoformat(), app_ids=app_id_descriptions)
65
+ system_prompt = system_prompt.format(
66
+ system_time=datetime.now(tz=UTC).isoformat(), app_ids=app_id_descriptions
67
+ )
61
68
 
62
- messages = [{"role": "system", "content": system_prompt + "\n" + instructions}, *state["messages"]]
69
+ messages = [
70
+ {"role": "system", "content": system_prompt + "\n" + instructions},
71
+ *state["messages"],
72
+ ]
63
73
  model = load_chat_model(runtime.context.model)
64
- loaded_tools = await tool_registry.export_tools(tools=state["selected_tool_ids"], format=ToolFormat.LANGCHAIN)
65
- model_with_tools = model.bind_tools([search_tools, ask_user, load_tools, *loaded_tools], tool_choice="auto")
74
+ loaded_tools = await tool_registry.export_tools(
75
+ tools=state["selected_tool_ids"], format=ToolFormat.LANGCHAIN
76
+ )
77
+ model_with_tools = model.bind_tools(
78
+ [search_tools, ask_user, load_tools, *loaded_tools], tool_choice="auto"
79
+ )
66
80
  response_raw = model_with_tools.invoke(messages)
67
81
  response = cast(AIMessage, response_raw)
68
82
  return {"messages": [response]}
@@ -102,7 +116,8 @@ async def build_graph(tool_registry: ToolRegistry, instructions: str = ""):
102
116
  tools = await search_tools.ainvoke(tool_call["args"])
103
117
  outputs.append(
104
118
  ToolMessage(
105
- content=json.dumps(tools) + "\n\nUse the load_tools tool to load the tools you want to use.",
119
+ content=json.dumps(tools)
120
+ + "\n\nUse the load_tools tool to load the tools you want to use.",
106
121
  name=tool_call["name"],
107
122
  tool_call_id=tool_call["id"],
108
123
  )
@@ -119,9 +134,13 @@ async def build_graph(tool_registry: ToolRegistry, instructions: str = ""):
119
134
  )
120
135
  )
121
136
  else:
122
- await tool_registry.export_tools([tool_call["name"]], ToolFormat.LANGCHAIN)
137
+ await tool_registry.export_tools(
138
+ [tool_call["name"]], ToolFormat.LANGCHAIN
139
+ )
123
140
  try:
124
- tool_result = await tool_registry.call_tool(tool_call["name"], tool_call["args"])
141
+ tool_result = await tool_registry.call_tool(
142
+ tool_call["name"], tool_call["args"]
143
+ )
125
144
  outputs.append(
126
145
  ToolMessage(
127
146
  content=json.dumps(tool_result),
@@ -1,14 +1,14 @@
1
1
  import asyncio
2
2
 
3
3
  from universal_mcp.agentr.registry import AgentrRegistry
4
- from universal_mcp.agents.autoagent import build_graph
5
4
  from universal_mcp.tools import ToolManager
6
5
 
6
+ from universal_mcp.agents.autoagent import build_graph
7
+
7
8
# Module-level registry and manager instances. tool_registry is passed to
# build_graph() inside main(). tool_manager is not referenced anywhere else in
# this module as shown — NOTE(review): confirm whether it is still needed.
tool_registry = AgentrRegistry()
tool_manager = ToolManager()
9
10
 
10
11
 
11
-
12
12
  async def main():
13
13
  instructions = """
14
14
  You are a helpful assistant that can use tools to help the user. If a task requires multiple steps, you should perform separate different searches for different actions. Prefer completing one action before searching for another.
@@ -16,10 +16,5 @@ async def main():
16
16
  graph = await build_graph(tool_registry, instructions=instructions)
17
17
  return graph
18
18
 
19
# Build the agent graph eagerly at import time and expose it as the
# module-level `graph` object (presumably so a LangGraph studio/dev server
# can load it by name — confirm against the project's langgraph config).
graph = asyncio.run(main())
@@ -2,15 +2,24 @@
2
2
  from typing import cast
3
3
  from uuid import uuid4
4
4
 
5
- from langchain_core.messages import AIMessageChunk
5
+ from langchain_core.messages import AIMessage, AIMessageChunk
6
6
  from langgraph.checkpoint.base import BaseCheckpointSaver
7
+ from langgraph.graph import StateGraph
7
8
  from langgraph.types import Command
9
+ from universal_mcp.logger import logger
8
10
 
9
11
  from .utils import RichCLI
10
12
 
11
13
 
12
14
  class BaseAgent:
13
- def __init__(self, name: str, instructions: str, model: str, memory: BaseCheckpointSaver | None = None, **kwargs):
15
+ def __init__(
16
+ self,
17
+ name: str,
18
+ instructions: str,
19
+ model: str,
20
+ memory: BaseCheckpointSaver | None = None,
21
+ **kwargs,
22
+ ):
14
23
  self.name = name
15
24
  self.instructions = instructions
16
25
  self.model = model
@@ -24,56 +33,96 @@ class BaseAgent:
24
33
  self._graph = await self._build_graph()
25
34
  self._initialized = True
26
35
 
27
async def _build_graph(self) -> StateGraph:
    """Build and return the graph this agent runs; subclasses must override.

    The result is assigned to ``self._graph`` during initialization and is
    then driven with ``astream``/``ainvoke``/``get_state``/``invoke``.

    NOTE(review): those are methods of a *compiled* graph, and the
    BigToolAgent subclass returns ``graph_builder.compile(...)``, so the
    real return type is presumably the compiled graph rather than
    ``StateGraph`` — confirm and tighten the annotation.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    missing_impl_message = "Subclasses must implement this method"
    raise NotImplementedError(missing_impl_message)
29
38
 
30
async def stream(self, thread_id: str, user_input: str, metadata: dict | None = None):
    """Stream assistant message events for a single user turn.

    Runs ``self._graph.astream`` with ``stream_mode="messages"`` on the given
    thread. It yields each non-quiet ``AIMessage``/``AIMessageChunk`` whose
    ``response_metadata`` has no ``finish_reason``. After the stream ends,
    the last streamed event is yielded once more, with its ``usage_metadata``
    taken from the aggregate of all non-quiet AI events. Consumers can read
    token usage from that final item.

    Args:
        thread_id: Checkpointer thread to run on (``configurable.thread_id``).
        user_input: Text of the user message for this turn.
        metadata: Optional run metadata merged over the defaults
            ``{"agent_name": self.name, "is_background_run": False}``.

    Yields:
        Assistant message events, followed by one final usage-bearing event.
        Nothing is yielded if the graph emits no events at all.
    """
    await self.ainit()
    aggregate = None
    last_event = None

    run_metadata = {
        "agent_name": self.name,
        "is_background_run": False,  # Default to False
    }
    if metadata:
        run_metadata.update(metadata)

    run_config = {
        "configurable": {"thread_id": thread_id},
        "metadata": run_metadata,
    }

    # Per-event metadata gets its own name so it does not shadow the
    # `metadata` parameter above.
    async for event, event_metadata in self._graph.astream(
        {"messages": [{"role": "user", "content": user_input}]},
        config=run_config,
        context={"system_prompt": self.instructions, "model": self.model},
        stream_mode="messages",
        stream_usage=True,
    ):
        last_event = event
        tags = event_metadata.get("tags", []) if isinstance(event_metadata, dict) else []
        is_quiet = isinstance(tags, list) and ("quiet" in tags)
        if is_quiet:
            continue
        # Only forward assistant messages; other message types are dropped.
        if type(event) not in (AIMessage, AIMessageChunk):
            continue
        # Accumulate billing and aggregate message.
        # NOTE(review): `+` merges AIMessageChunks; confirm the result if a
        # full (non-chunk) AIMessage is ever streamed here.
        aggregate = event if aggregate is None else aggregate + event
        if "finish_reason" in event.response_metadata:
            # Intermediate finish messages are logged but not forwarded.
            logger.debug(f"Finish event: {event}, Metadata: {event_metadata}")
        else:
            logger.debug(f"Event: {event}, Metadata: {event_metadata}")
            yield event

    if last_event is None:
        # Nothing was streamed, so there is no final event to annotate.
        return

    # Send a final finished message carrying the aggregated token usage.
    final_event = cast(AIMessageChunk, last_event)
    if aggregate is not None:
        final_event.usage_metadata = aggregate.usage_metadata
    logger.debug(f"Usage metadata: {getattr(final_event, 'usage_metadata', None)}")
    yield final_event
62
87
 
63
88
async def stream_interactive(self, thread_id: str, user_input: str):
    """Stream one user turn to the rich CLI display.

    Message content is either a string or a list of content blocks. In the
    list form, "thinking" text is concatenated and sent as a ``thinking``
    update, and the text parts are sent as a ``text`` update. Plain-string
    items in such a list are treated as text, since LangChain content lists
    may mix strings and dict blocks.

    Args:
        thread_id: Checkpointer thread to run on.
        user_input: Text of the user message for this turn.
    """
    await self.ainit()
    with self.cli.display_agent_response_streaming(self.name) as stream_updater:
        async for event in self.stream(thread_id, user_input):
            if isinstance(event.content, list):
                # Dict blocks carry "thinking"/"text" keys; bare strings are text only.
                thinking_content = "".join(
                    c.get("thinking", "") for c in event.content if isinstance(c, dict)
                )
                stream_updater.update(thinking_content, type_="thinking")
                content = "".join(
                    c if isinstance(c, str) else c.get("text", "")
                    for c in event.content
                )
                stream_updater.update(content, type_="text")
            else:
                stream_updater.update(event.content, type_="text")
100
+
101
+ async def invoke(
102
+ self, user_input: str, thread_id: str = str(uuid4()), metadata: dict = None
103
+ ):
70
104
  """Run the agent"""
71
105
  await self.ainit()
72
- return await self._graph.ainvoke(
106
+
107
+ run_metadata = {
108
+ "agent_name": self.name,
109
+ "is_background_run": False, # Default to False
110
+ }
111
+
112
+ if metadata:
113
+ run_metadata.update(metadata)
114
+
115
+ run_config = {
116
+ "configurable": {"thread_id": thread_id},
117
+ "metadata": run_metadata,
118
+ }
119
+
120
+ result = await self._graph.ainvoke(
73
121
  {"messages": [{"role": "user", "content": user_input}]},
74
- config={"configurable": {"thread_id": thread_id}},
122
+ config=run_config,
75
123
  context={"system_prompt": self.instructions, "model": self.model},
76
124
  )
125
+ return result
77
126
 
78
127
  async def run_interactive(self, thread_id: str = str(uuid4())):
79
128
  """Main application loop"""
@@ -85,10 +134,15 @@ class BaseAgent:
85
134
  # Main loop
86
135
  while True:
87
136
  try:
88
- state = self._graph.get_state(config={"configurable": {"thread_id": thread_id}})
137
+ state = self._graph.get_state(
138
+ config={"configurable": {"thread_id": thread_id}}
139
+ )
89
140
  if state.interrupts:
90
141
  value = self.cli.handle_interrupt(state.interrupts[0])
91
- self._graph.invoke(Command(resume=value), config={"configurable": {"thread_id": thread_id}})
142
+ self._graph.invoke(
143
+ Command(resume=value),
144
+ config={"configurable": {"thread_id": thread_id}},
145
+ )
92
146
  continue
93
147
 
94
148
  user_input = self.cli.get_user_input()
@@ -99,9 +153,11 @@ class BaseAgent:
99
153
  if user_input.startswith("/"):
100
154
  command = user_input.lower().lstrip("/")
101
155
  if command == "about":
102
- self.cli.display_info(f"Agent is {self.name}. {self.instructions}")
156
+ self.cli.display_info(
157
+ f"Agent is {self.name}. {self.instructions}"
158
+ )
103
159
  continue
104
- elif command == "exit" or command == "quit" or command == "q":
160
+ elif command in {"exit", "quit", "q"}:
105
161
  self.cli.display_info("Goodbye! 👋")
106
162
  break
107
163
  elif command == "reset":
@@ -110,7 +166,9 @@ class BaseAgent:
110
166
  thread_id = str(uuid4())
111
167
  continue
112
168
  elif command == "help":
113
- self.cli.display_info("Available commands: /about, /exit, /quit, /q, /reset")
169
+ self.cli.display_info(
170
+ "Available commands: /about, /exit, /quit, /q, /reset"
171
+ )
114
172
  continue
115
173
  else:
116
174
  self.cli.display_error(f"Unknown command: {command}")
@@ -124,6 +182,6 @@ class BaseAgent:
124
182
  break
125
183
  except Exception as e:
126
184
  import traceback
127
-
128
185
  traceback.print_exc()
129
186
  self.cli.display_error(f"An error occurred: {str(e)}")
187
+ break
@@ -1,9 +1,9 @@
1
1
  from langgraph.checkpoint.base import BaseCheckpointSaver
2
+ from universal_mcp.logger import logger
3
+ from universal_mcp.tools.registry import ToolRegistry
2
4
 
3
5
  from universal_mcp.agents.base import BaseAgent
4
6
  from universal_mcp.agents.llm import load_chat_model
5
- from universal_mcp.logger import logger
6
- from universal_mcp.tools.registry import ToolRegistry
7
7
 
8
8
  from .graph import build_graph
9
9
  from .prompts import SYSTEM_PROMPT
@@ -19,15 +19,19 @@ class BigToolAgent(BaseAgent):
19
19
  memory: BaseCheckpointSaver | None = None,
20
20
  **kwargs,
21
21
  ):
22
- # Combine the base system prompt with agent-specific instructions
23
- full_instructions = f"{SYSTEM_PROMPT}\n\n**User Instructions:**\n{instructions}"
24
- super().__init__(name, full_instructions, model, memory, **kwargs)
25
-
22
+ super().__init__(name, instructions, model, memory, **kwargs)
26
23
  self.registry = registry
27
24
  self.llm = load_chat_model(self.model)
28
- self.tool_selection_llm = load_chat_model("gemini/gemini-2.0-flash-001")
29
25
 
30
- logger.info(f"BigToolAgent '{self.name}' initialized with model '{self.model}'.")
26
+ logger.info(
27
+ f"BigToolAgent '{self.name}' initialized with model '{self.model}'."
28
+ )
29
+
30
def _build_system_message(self):
    """Render SYSTEM_PROMPT with this agent's name and instructions filled in."""
    prompt_fields = {"name": self.name, "instructions": self.instructions}
    return SYSTEM_PROMPT.format(**prompt_fields)
31
35
 
32
36
  async def _build_graph(self):
33
37
  """Build the bigtool agent graph using the existing create_agent function."""
@@ -36,7 +40,7 @@ class BigToolAgent(BaseAgent):
36
40
  graph_builder = build_graph(
37
41
  tool_registry=self.registry,
38
42
  llm=self.llm,
39
- tool_selection_llm=self.tool_selection_llm,
43
+ system_prompt=self._build_system_message(),
40
44
  )
41
45
 
42
46
  compiled_graph = graph_builder.compile(checkpointer=self.memory)
@@ -1,9 +1,10 @@
1
1
  import asyncio
2
2
 
3
3
  from loguru import logger
4
-
5
4
  from universal_mcp.agentr.registry import AgentrRegistry
5
+
6
6
  from universal_mcp.agents.bigtool import BigToolAgent
7
+ from universal_mcp.agents.utils import messages_to_list
7
8
 
8
9
 
9
10
  async def main():
@@ -13,12 +14,10 @@ async def main():
13
14
  model="azure/gpt-4.1",
14
15
  registry=AgentrRegistry(),
15
16
  )
16
- async for event in agent.stream(
17
- user_input="Send an email to manoj@agentr.dev",
18
- thread_id="test123",
19
- ):
20
- logger.info(event.content)
21
-
17
+ await agent.ainit()
18
+ output = await agent.invoke(
19
+ user_input="Send an email to manoj@agentr.dev")
20
+ logger.info(messages_to_list(output["messages"]))
22
21
 
23
22
# Script entry point: run the async main() when this module is executed directly.
if __name__ == "__main__":
    asyncio.run(main())
@@ -1,28 +1,24 @@
1
1
  import json
2
- from datetime import UTC, datetime
3
2
  from typing import Literal, TypedDict, cast
4
3
 
5
- from langchain_anthropic import ChatAnthropic
6
4
  from langchain_core.language_models import BaseChatModel
7
5
  from langchain_core.messages import AIMessage, ToolMessage
8
6
  from langchain_core.tools import tool
9
7
  from langgraph.graph import StateGraph
10
- from langgraph.runtime import Runtime
11
8
  from langgraph.types import Command
12
-
13
- from universal_mcp.agents.bigtool.context import Context
14
- from universal_mcp.agents.bigtool.state import State
15
9
  from universal_mcp.logger import logger
16
10
  from universal_mcp.tools.registry import ToolRegistry
17
11
  from universal_mcp.types import ToolFormat
18
12
 
13
+ from universal_mcp.agents.bigtool.state import State
14
+
19
15
  from .prompts import SELECT_TOOL_PROMPT
20
16
 
21
17
 
22
18
  def build_graph(
23
19
  tool_registry: ToolRegistry,
24
20
  llm: BaseChatModel,
25
- tool_selection_llm: BaseChatModel,
21
+ system_prompt: str,
26
22
  ):
27
23
  @tool
28
24
  async def retrieve_tools(task_query: str) -> list[str]:
@@ -32,29 +28,40 @@ def build_graph(
32
28
  logger.info(f"Retrieving tools for task: '{task_query}'")
33
29
  try:
34
30
  tools_list = await tool_registry.search_tools(task_query, limit=10)
35
- tool_candidates = [f"{tool['id']}: {tool['description']}" for tool in tools_list]
31
+ tool_candidates = [
32
+ f"{tool['id']}: {tool['description']}" for tool in tools_list
33
+ ]
36
34
  logger.info(f"Found {len(tool_candidates)} candidate tools.")
37
35
 
38
36
class ToolSelectionOutput(TypedDict):
    # Schema handed to with_structured_output(method="json_mode") for the
    # tool-selection LLM call; presumably holds the tool IDs picked from
    # the candidate list — confirm against the hidden return logic.
    tool_names: list[str]
40
38
 
41
- model = tool_selection_llm
39
+ model = llm
42
40
  app_ids = await tool_registry.list_all_apps()
43
41
  connections = await tool_registry.list_connected_apps()
44
42
  connection_ids = set([connection["app_id"] for connection in connections])
45
- connected_apps = [app["id"] for app in app_ids if app["id"] in connection_ids]
46
- unconnected_apps = [app["id"] for app in app_ids if app["id"] not in connection_ids]
47
- app_id_descriptions = "These are the apps connected to the user's account:\n" + "\n".join(
48
- [f"{app}" for app in connected_apps]
43
+ connected_apps = [
44
+ app["id"] for app in app_ids if app["id"] in connection_ids
45
+ ]
46
+ unconnected_apps = [
47
+ app["id"] for app in app_ids if app["id"] not in connection_ids
48
+ ]
49
+ app_id_descriptions = (
50
+ "These are the apps connected to the user's account:\n"
51
+ + "\n".join([f"{app}" for app in connected_apps])
49
52
  )
50
53
  if unconnected_apps:
51
54
  app_id_descriptions += "\n\nOther (not connected) apps: " + "\n".join(
52
55
  [f"{app}" for app in unconnected_apps]
53
56
  )
54
57
 
55
- response = await model.with_structured_output(schema=ToolSelectionOutput, method="json_mode").ainvoke(
58
+ response = await model.with_structured_output(
59
+ schema=ToolSelectionOutput, method="json_mode"
60
+ ).ainvoke(
56
61
  SELECT_TOOL_PROMPT.format(
57
- app_ids=app_id_descriptions, tool_candidates="\n - ".join(tool_candidates), task=task_query
62
+ app_ids=app_id_descriptions,
63
+ tool_candidates="\n - ".join(tool_candidates),
64
+ task=task_query,
58
65
  )
59
66
  )
60
67
 
@@ -65,45 +72,63 @@ def build_graph(
65
72
  logger.error(f"Error retrieving tools: {e}")
66
73
  return []
67
74
 
68
- async def call_model(state: State, runtime: Runtime[Context]) -> Command[Literal["select_tools", "call_tools"]]:
75
+
76
+ async def call_model(
77
+ state: State
78
+ ) -> Command[Literal["select_tools", "call_tools"]]:
69
79
  logger.info("Calling model...")
70
80
  try:
71
- system_message = runtime.context.system_prompt.format(system_time=datetime.now(tz=UTC).isoformat())
72
- messages = [{"role": "system", "content": system_message}, *state["messages"]]
81
+ messages = [
82
+ {"role": "system", "content": system_prompt},
83
+ *state["messages"],
84
+ ]
73
85
 
74
86
  logger.info(f"Selected tool IDs: {state['selected_tool_ids']}")
75
87
  if len(state["selected_tool_ids"]) > 0:
76
- selected_tools = await tool_registry.export_tools(tools=state["selected_tool_ids"], format=ToolFormat.LANGCHAIN)
88
+ selected_tools = await tool_registry.export_tools(
89
+ tools=state["selected_tool_ids"], format=ToolFormat.LANGCHAIN
90
+ )
77
91
  logger.info(f"Exported {len(selected_tools)} tools for model.")
78
92
  else:
79
93
  selected_tools = []
80
94
 
81
- model = llm
82
- if isinstance(model, ChatAnthropic):
83
- model_with_tools = model.bind_tools(
84
- [retrieve_tools, *selected_tools], tool_choice="auto", cache_control={"type": "ephemeral"}
85
- )
86
- else:
87
- model_with_tools = model.bind_tools([retrieve_tools, *selected_tools], tool_choice="auto")
88
- response = cast(AIMessage, await model_with_tools.ainvoke(messages))
95
+ model_with_tools = llm.bind_tools(
96
+ [retrieve_tools, *selected_tools], tool_choice="auto"
97
+ )
98
+
99
+
100
+ response = await model_with_tools.ainvoke(messages)
101
+ cast(AIMessage, response)
102
+ logger.debug(f"Response: {response}")
103
+
89
104
 
90
105
  if response.tool_calls:
91
- logger.info(f"Model responded with {len(response.tool_calls)} tool calls.")
106
+ logger.info(
107
+ f"Model responded with {len(response.tool_calls)} tool calls."
108
+ )
92
109
  if len(response.tool_calls) > 1:
93
- raise Exception("Not possible in Claude with llm.bind_tools(tools=tools, tool_choice='auto')")
110
+ raise Exception(
111
+ "Not possible in Claude with llm.bind_tools(tools=tools, tool_choice='auto')"
112
+ )
94
113
  tool_call = response.tool_calls[0]
95
114
  if tool_call["name"] == retrieve_tools.name:
96
115
  logger.info("Model requested to select tools.")
97
116
  return Command(goto="select_tools", update={"messages": [response]})
98
117
  elif tool_call["name"] not in state["selected_tool_ids"]:
99
118
  try:
100
- await tool_registry.export_tools([tool_call["name"]], ToolFormat.LANGCHAIN)
119
+ await tool_registry.export_tools(
120
+ [tool_call["name"]], ToolFormat.LANGCHAIN
121
+ )
101
122
  logger.info(
102
123
  f"Tool '{tool_call['name']}' not in selected tools, but available. Proceeding to call."
103
124
  )
104
- return Command(goto="call_tools", update={"messages": [response]})
125
+ return Command(
126
+ goto="call_tools", update={"messages": [response]}
127
+ )
105
128
  except Exception as e:
106
- logger.error(f"Unexpected tool call: {tool_call['name']}. Error: {e}")
129
+ logger.error(
130
+ f"Unexpected tool call: {tool_call['name']}. Error: {e}"
131
+ )
107
132
  raise Exception(
108
133
  f"Unexpected tool call: {tool_call['name']}. Available tools: {state['selected_tool_ids']}"
109
134
  ) from e
@@ -116,14 +141,24 @@ def build_graph(
116
141
  logger.error(f"Error in call_model: {e}")
117
142
  raise
118
143
 
119
async def select_tools(
    state: State
) -> Command[Literal["call_model"]]:
    """Execute the pending retrieve_tools call and route back to call_model.

    Reads the first tool call on the latest message and runs
    ``retrieve_tools`` with its arguments. It replies with a ToolMessage
    listing the result, and stores that result as ``selected_tool_ids``.
    Any error is logged and re-raised.
    """
    logger.info("Selecting tools...")
    try:
        pending_call = state["messages"][-1].tool_calls[0]
        chosen_tools = await retrieve_tools.ainvoke(input=pending_call["args"])
        reply = ToolMessage(
            f"Available tools: {chosen_tools}", tool_call_id=pending_call["id"]
        )
        logger.info(f"Tools selected: {chosen_tools}")
        state_update = {"messages": [reply], "selected_tool_ids": chosen_tools}
        return Command(goto="call_model", update=state_update)
    except Exception as exc:
        logger.error(f"Error in select_tools: {exc}")
        raise
@@ -133,10 +168,16 @@ def build_graph(
133
168
  outputs = []
134
169
  recent_tool_ids = []
135
170
  for tool_call in state["messages"][-1].tool_calls:
136
- logger.info(f"Executing tool: {tool_call['name']} with args: {tool_call['args']}")
171
+ logger.info(
172
+ f"Executing tool: {tool_call['name']} with args: {tool_call['args']}"
173
+ )
137
174
  try:
138
- await tool_registry.export_tools([tool_call["name"]], ToolFormat.LANGCHAIN)
139
- tool_result = await tool_registry.call_tool(tool_call["name"], tool_call["args"])
175
+ await tool_registry.export_tools(
176
+ [tool_call["name"]], ToolFormat.LANGCHAIN
177
+ )
178
+ tool_result = await tool_registry.call_tool(
179
+ tool_call["name"], tool_call["args"]
180
+ )
140
181
  logger.info(f"Tool '{tool_call['name']}' executed successfully.")
141
182
  outputs.append(
142
183
  ToolMessage(
@@ -155,9 +196,12 @@ def build_graph(
155
196
  tool_call_id=tool_call["id"],
156
197
  )
157
198
  )
158
- return Command(goto="call_model", update={"messages": outputs, "selected_tool_ids": recent_tool_ids})
199
+ return Command(
200
+ goto="call_model",
201
+ update={"messages": outputs, "selected_tool_ids": recent_tool_ids},
202
+ )
159
203
 
160
- builder = StateGraph(State, context_schema=Context)
204
+ builder = StateGraph(State)
161
205
 
162
206
  builder.add_node(call_model)
163
207
  builder.add_node(select_tools)