grasp_agents 0.2.11__py3-none-any.whl → 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. grasp_agents/__init__.py +15 -14
  2. grasp_agents/cloud_llm.py +118 -131
  3. grasp_agents/comm_processor.py +201 -0
  4. grasp_agents/generics_utils.py +15 -7
  5. grasp_agents/llm.py +60 -31
  6. grasp_agents/llm_agent.py +229 -273
  7. grasp_agents/llm_agent_memory.py +58 -0
  8. grasp_agents/llm_policy_executor.py +482 -0
  9. grasp_agents/memory.py +20 -134
  10. grasp_agents/message_history.py +140 -0
  11. grasp_agents/openai/__init__.py +54 -36
  12. grasp_agents/openai/completion_chunk_converters.py +78 -0
  13. grasp_agents/openai/completion_converters.py +53 -30
  14. grasp_agents/openai/content_converters.py +13 -14
  15. grasp_agents/openai/converters.py +44 -68
  16. grasp_agents/openai/message_converters.py +58 -72
  17. grasp_agents/openai/openai_llm.py +101 -42
  18. grasp_agents/openai/tool_converters.py +24 -19
  19. grasp_agents/packet.py +24 -0
  20. grasp_agents/packet_pool.py +91 -0
  21. grasp_agents/printer.py +29 -15
  22. grasp_agents/processor.py +193 -0
  23. grasp_agents/prompt_builder.py +175 -192
  24. grasp_agents/run_context.py +20 -37
  25. grasp_agents/typing/completion.py +58 -12
  26. grasp_agents/typing/completion_chunk.py +173 -0
  27. grasp_agents/typing/converters.py +8 -12
  28. grasp_agents/typing/events.py +86 -0
  29. grasp_agents/typing/io.py +4 -13
  30. grasp_agents/typing/message.py +12 -50
  31. grasp_agents/typing/tool.py +52 -26
  32. grasp_agents/usage_tracker.py +6 -6
  33. grasp_agents/utils.py +3 -3
  34. grasp_agents/workflow/looped_workflow.py +132 -0
  35. grasp_agents/workflow/parallel_processor.py +95 -0
  36. grasp_agents/workflow/sequential_workflow.py +66 -0
  37. grasp_agents/workflow/workflow_processor.py +78 -0
  38. {grasp_agents-0.2.11.dist-info → grasp_agents-0.3.2.dist-info}/METADATA +41 -50
  39. grasp_agents-0.3.2.dist-info/RECORD +51 -0
  40. grasp_agents/agent_message.py +0 -27
  41. grasp_agents/agent_message_pool.py +0 -92
  42. grasp_agents/base_agent.py +0 -51
  43. grasp_agents/comm_agent.py +0 -217
  44. grasp_agents/llm_agent_state.py +0 -79
  45. grasp_agents/tool_orchestrator.py +0 -203
  46. grasp_agents/workflow/looped_agent.py +0 -134
  47. grasp_agents/workflow/sequential_agent.py +0 -72
  48. grasp_agents/workflow/workflow_agent.py +0 -88
  49. grasp_agents-0.2.11.dist-info/RECORD +0 -46
  50. {grasp_agents-0.2.11.dist-info → grasp_agents-0.3.2.dist-info}/WHEEL +0 -0
  51. {grasp_agents-0.2.11.dist-info → grasp_agents-0.3.2.dist-info}/licenses/LICENSE.md +0 -0
@@ -1,92 +0,0 @@
1
- import asyncio
2
- import logging
3
- from typing import Any, Generic, Protocol, TypeVar
4
-
5
- from .agent_message import AgentMessage
6
- from .run_context import CtxT, RunContextWrapper
7
- from .typing.io import AgentID, AgentState
8
-
9
_MH_PayloadT = TypeVar("_MH_PayloadT", contravariant=True)  # noqa: PLC0105
_MH_StateT = TypeVar("_MH_StateT", bound=AgentState, contravariant=True)  # noqa: PLC0105

logger = logging.getLogger(__name__)


class MessageHandler(Protocol[_MH_PayloadT, _MH_StateT, CtxT]):
    """Async callable that consumes one message taken from an agent's queue.

    ``AgentMessagePool`` calls it with the message, the run context the
    handler was registered with, and any extra keyword arguments supplied at
    registration time.
    """

    async def __call__(
        self,
        message: AgentMessage[_MH_PayloadT, _MH_StateT],
        ctx: RunContextWrapper[CtxT] | None,
        **kwargs: Any,
    ) -> None:
        """Process ``message``; any return value is ignored by the pool."""
23
-
24
-
25
class AgentMessagePool(Generic[CtxT]):
    """In-process message bus with one ``asyncio.Queue`` per agent.

    ``post`` fans a message out to the queues of its recipients. Every agent
    that registers a handler gets a background task that feeds its queue to
    that handler one message at a time.
    """

    def __init__(self) -> None:
        self._queues: dict[AgentID, asyncio.Queue[AgentMessage[Any, AgentState]]] = {}
        self._message_handlers: dict[
            AgentID, MessageHandler[Any, AgentState, CtxT]
        ] = {}
        self._tasks: dict[AgentID, asyncio.Task[None]] = {}

    async def post(self, message: AgentMessage[Any, AgentState]) -> None:
        """Enqueue ``message`` for every ID in ``message.recipient_ids``.

        Queues are created on demand, so messages addressed to an agent that
        has not registered yet are buffered until it does.
        """
        for recipient_id in message.recipient_ids:
            queue = self._queues.setdefault(recipient_id, asyncio.Queue())
            await queue.put(message)

    def register_message_handler(
        self,
        agent_id: AgentID,
        handler: MessageHandler[Any, AgentState, CtxT],
        ctx: RunContextWrapper[CtxT] | None = None,
        **run_kwargs: Any,
    ) -> None:
        """Install ``handler`` for ``agent_id`` and start its processing task.

        Must be called while an event loop is running (the task is created with
        ``asyncio.create_task``). If a task already exists for ``agent_id``,
        only the handler is swapped; the running task keeps the ``ctx`` and
        ``run_kwargs`` it was started with.
        """
        self._message_handlers[agent_id] = handler
        self._queues.setdefault(agent_id, asyncio.Queue())
        if agent_id not in self._tasks:
            self._tasks[agent_id] = asyncio.create_task(
                self._process_messages(agent_id, ctx=ctx, **run_kwargs)
            )

    async def _process_messages(
        self,
        agent_id: AgentID,
        ctx: RunContextWrapper[CtxT] | None = None,
        **run_kwargs: Any,
    ) -> None:
        """Feed queued messages for ``agent_id`` to its handler until stopped.

        Handler exceptions are logged and do not stop the loop. The loop ends
        when the handler has been unregistered or the task is cancelled.
        """
        queue = self._queues[agent_id]
        while True:
            try:
                message = await queue.get()
                try:
                    handler = self._message_handlers.get(agent_id)
                    if handler is None:
                        # Agent was unregistered; stop consuming its queue.
                        break
                    try:
                        # Call the handler we just looked up: re-indexing the
                        # dict could raise KeyError if it was removed meanwhile.
                        await handler(message, ctx=ctx, **run_kwargs)
                    except Exception:
                        logger.exception(f"Error handling message for {agent_id}")
                finally:
                    # Mark the dequeued item done on every path (including the
                    # ``break`` above) so ``queue.join()`` cannot hang.
                    queue.task_done()

            except Exception:
                logger.exception(f"Unexpected error in processing loop for {agent_id}")

    async def unregister_message_handler(self, agent_id: AgentID) -> None:
        """Stop ``agent_id``'s processing task and drop its queue and handler.

        If called from inside that agent's own processing task (e.g. a handler
        whose exit condition triggers ``stop_all``), the task is only marked for
        cancellation: a task cannot await itself. The cancellation is then
        delivered at its next suspension point.
        """
        task = self._tasks.pop(agent_id, None)
        self._queues.pop(agent_id, None)
        self._message_handlers.pop(agent_id, None)
        if task is None:
            return

        task.cancel()
        if task is asyncio.current_task():
            # Awaiting ourselves would turn the cancellation into a swallowed
            # CancelledError and leave this task blocked forever on a queue
            # that is no longer reachable.
            return

        try:
            await task
        except asyncio.CancelledError:
            logger.debug(f"{agent_id} exited")

    async def stop_all(self) -> None:
        """Unregister every agent that has a processing task.

        The caller's own processing task, if any, is handled last. This keeps
        its pending cancellation from interrupting the awaits on the other
        tasks.
        """
        current = asyncio.current_task()
        # ``False`` sorts before ``True``; the sort is stable for the rest.
        agent_ids = sorted(self._tasks, key=lambda aid: self._tasks[aid] is current)
        for agent_id in agent_ids:
            await self.unregister_message_handler(agent_id)
@@ -1,51 +0,0 @@
1
- from abc import ABC, abstractmethod
2
- from typing import Any, ClassVar, Generic
3
-
4
- from pydantic import TypeAdapter
5
-
6
- from .generics_utils import AutoInstanceAttributesMixin
7
- from .run_context import CtxT, RunContextWrapper
8
- from .typing.io import AgentID, OutT, StateT
9
- from .typing.tool import BaseTool
10
-
11
-
12
class BaseAgent(AutoInstanceAttributesMixin, ABC, Generic[OutT, StateT, CtxT]):
    """Abstract root of the agent hierarchy.

    The first generic argument is the agent's output type.
    ``_generic_arg_to_instance_attr_map`` maps it to the ``_out_type``
    attribute. That attribute is presumably populated by
    ``AutoInstanceAttributesMixin`` during ``super().__init__()``, since
    ``__init__`` reads it immediately afterwards.
    """

    _generic_arg_to_instance_attr_map: ClassVar[dict[int, str]] = {0: "_out_type"}

    @abstractmethod
    def __init__(self, agent_id: AgentID, **kwargs: Any) -> None:
        # Annotation-only declarations for type checkers; nothing is bound here.
        self._out_type: type[OutT]
        self._state: StateT

        super().__init__()

        self._agent_id = agent_id
        adapter: TypeAdapter[OutT] = TypeAdapter(self._out_type)
        self._out_type_adapter = adapter

    @property
    def agent_id(self) -> AgentID:
        """Identifier passed at construction."""
        return self._agent_id

    @property
    def out_type(self) -> type[OutT]:
        """Output type bound from the first generic argument."""
        return self._out_type

    @property
    def state(self) -> StateT:
        """Current agent state (``_state`` is presumably set by subclasses)."""
        return self._state

    @abstractmethod
    async def run(
        self,
        chat_inputs: Any | None = None,
        *,
        ctx: RunContextWrapper[CtxT] | None = None,
        **kwargs: Any,
    ) -> Any:
        """Execute the agent once; subclasses define inputs and result."""

    @abstractmethod
    def as_tool(
        self, tool_name: str, tool_description: str, tool_strict: bool = True
    ) -> BaseTool[Any, OutT, CtxT]:
        """Wrap this agent as a tool named ``tool_name``."""
@@ -1,217 +0,0 @@
1
- import logging
2
- from abc import abstractmethod
3
- from collections.abc import Sequence
4
- from typing import Any, ClassVar, Generic, Protocol, TypeVar, cast, final
5
-
6
- from pydantic import BaseModel, TypeAdapter
7
- from pydantic.json_schema import SkipJsonSchema
8
-
9
- from .agent_message import AgentMessage
10
- from .agent_message_pool import AgentMessagePool
11
- from .base_agent import BaseAgent
12
- from .run_context import CtxT, RunContextWrapper
13
- from .typing.io import AgentID, AgentState, InT, OutT, StateT
14
- from .typing.tool import BaseTool
15
-
16
logger = logging.getLogger(__name__)


class DynCommPayload(BaseModel):
    """Base for payloads that pick their recipients at run time.

    ``SkipJsonSchema`` keeps ``selected_recipient_ids`` out of the generated
    JSON schema.
    """

    selected_recipient_ids: SkipJsonSchema[Sequence[AgentID]]


_EH_OutT = TypeVar("_EH_OutT", contravariant=True)  # noqa: PLC0105
_EH_StateT = TypeVar("_EH_StateT", bound=AgentState, contravariant=True)  # noqa: PLC0105


class ExitHandler(Protocol[_EH_OutT, _EH_StateT, CtxT]):
    """Predicate consulted after an agent produces an output message."""

    def __call__(
        self,
        output_message: AgentMessage[_EH_OutT, _EH_StateT],
        ctx: RunContextWrapper[CtxT] | None,
    ) -> bool:
        """Return ``True`` when message processing should stop."""
33
-
34
-
35
class CommunicatingAgent(
    BaseAgent[OutT, StateT, CtxT], Generic[InT, OutT, StateT, CtxT]
):
    """Agent that exchanges ``AgentMessage`` objects through an ``AgentMessagePool``.

    Generic arguments 0 and 1 are mapped to ``_in_type`` and ``_out_type``
    through ``_generic_arg_to_instance_attr_map``. Outgoing payloads are checked
    by ``_validate_routing``, which works in one of two modes:

    * every payload is a ``DynCommPayload``: the payloads select a subset of
      ``recipient_ids``;
    * no payload is a ``DynCommPayload``: the static ``recipient_ids`` apply.
    """

    _generic_arg_to_instance_attr_map: ClassVar[dict[int, str]] = {
        0: "_in_type",
        1: "_out_type",
    }

    def __init__(
        self,
        agent_id: AgentID,
        *,
        recipient_ids: Sequence[AgentID] | None = None,
        message_pool: AgentMessagePool[CtxT] | None = None,
        **kwargs: Any,
    ) -> None:
        """Create the agent.

        Args:
            agent_id: Identifier of this agent in the message pool.
            recipient_ids: Agents this one may post to; empty when omitted.
            message_pool: Pool to communicate through; a new private pool is
                created when omitted.
            **kwargs: Forwarded to ``BaseAgent.__init__``.
        """
        # Annotation only; presumably populated from generic argument 0 during
        # ``super().__init__`` (it is read immediately afterwards).
        self._in_type: type[InT]
        super().__init__(agent_id=agent_id, **kwargs)

        self._in_args_type_adapter: TypeAdapter[InT] = TypeAdapter(self._in_type)
        self.recipient_ids = recipient_ids or []

        self._message_pool = message_pool or AgentMessagePool()
        self._is_listening = False
        self._exit_impl: ExitHandler[OutT, StateT, CtxT] | None = None

    @property
    def in_type(self) -> type[InT]:  # type: ignore
        # Exposing the type of a contravariant variable only, should be safe
        return self._in_type

    def _validate_routing(self, payloads: Sequence[OutT]) -> Sequence[AgentID]:
        """Return the recipient IDs that ``payloads`` should be routed to.

        Raises:
            AssertionError: if dynamic payloads disagree on their recipients or
                select an ID outside ``recipient_ids``.
            ValueError: if dynamic and non-dynamic payloads are mixed.
        """
        if not payloads:
            # Nothing to select recipients from. ``all()`` over an empty
            # sequence is True, so the dynamic branch would index an empty list.
            return self.recipient_ids

        if all(isinstance(p, DynCommPayload) for p in payloads):
            payloads_ = cast("Sequence[DynCommPayload]", payloads)
            selected_recipient_ids_per_payload = [
                set(p.selected_recipient_ids or []) for p in payloads_
            ]
            assert all(
                x == selected_recipient_ids_per_payload[0]
                for x in selected_recipient_ids_per_payload
            ), "All payloads must have the same recipient IDs for dynamic routing"

            assert payloads_[0].selected_recipient_ids is not None
            selected_recipient_ids = payloads_[0].selected_recipient_ids

            assert all(rid in self.recipient_ids for rid in selected_recipient_ids), (
                "Dynamic routing is enabled, but recipient IDs are not in "
                "the allowed agent's recipient IDs"
            )

            return selected_recipient_ids

        if all((not isinstance(p, DynCommPayload)) for p in payloads):
            return self.recipient_ids

        raise ValueError(
            "All payloads must be either DynCommPayload or not DynCommPayload"
        )

    async def post_message(self, message: AgentMessage[OutT, StateT]) -> None:
        """Validate ``message``'s payload routing and post it to the pool."""
        # NOTE(review): the IDs returned by _validate_routing are discarded;
        # the pool delivers to ``message.recipient_ids`` as set by the caller.
        # Confirm that ``run`` already applies any dynamic selection.
        self._validate_routing(message.payloads)

        await self._message_pool.post(message)

    @abstractmethod
    async def run(
        self,
        chat_inputs: Any | None = None,
        *,
        ctx: RunContextWrapper[CtxT] | None = None,
        in_message: AgentMessage[InT, AgentState] | None = None,
        entry_point: bool = False,
        forbid_state_change: bool = False,
        **kwargs: Any,
    ) -> AgentMessage[OutT, StateT]:
        pass

    async def run_and_post(
        self, ctx: RunContextWrapper[CtxT] | None = None, **run_kwargs: Any
    ) -> None:
        """Run as the entry point (no incoming message) and post the result."""
        output_message = await self.run(
            ctx=ctx, in_message=None, entry_point=True, **run_kwargs
        )
        await self.post_message(output_message)

    def exit_handler(
        self, func: ExitHandler[OutT, StateT, CtxT]
    ) -> ExitHandler[OutT, StateT, CtxT]:
        """Decorator registering ``func`` as this agent's exit condition."""
        self._exit_impl = func

        return func

    def _exit_condition(
        self,
        output_message: AgentMessage[OutT, StateT],
        ctx: RunContextWrapper[CtxT] | None,
    ) -> bool:
        """Evaluate the registered exit handler; ``False`` when none is set."""
        if self._exit_impl:
            return self._exit_impl(output_message=output_message, ctx=ctx)

        return False

    async def _message_handler(
        self,
        message: AgentMessage[Any, AgentState],
        ctx: RunContextWrapper[CtxT] | None = None,
        **run_kwargs: Any,
    ) -> None:
        """Pool callback: run on ``message``, then stop the pool or forward.

        When the exit condition holds, every agent on the pool is stopped.
        Otherwise the output is posted if this agent has recipients.
        """
        in_message = cast("AgentMessage[InT, AgentState]", message)
        out_message = await self.run(ctx=ctx, in_message=in_message, **run_kwargs)

        if self._exit_condition(output_message=out_message, ctx=ctx):
            await self._message_pool.stop_all()
            return

        if self.recipient_ids:
            await self.post_message(out_message)

    @property
    def is_listening(self) -> bool:
        return self._is_listening

    async def start_listening(
        self, ctx: RunContextWrapper[CtxT] | None = None, **run_kwargs: Any
    ) -> None:
        """Register this agent's handler on the pool; no-op if already listening."""
        if self._is_listening:
            return

        self._is_listening = True
        self._message_pool.register_message_handler(
            agent_id=self.agent_id,
            handler=self._message_handler,
            ctx=ctx,
            **run_kwargs,
        )

    async def stop_listening(self) -> None:
        """Unregister this agent's handler from the pool."""
        self._is_listening = False
        await self._message_pool.unregister_message_handler(self.agent_id)

    @final
    def as_tool(
        self,
        tool_name: str,
        tool_description: str,
        tool_strict: bool = True,
    ) -> BaseTool[InT, OutT, Any]:  # type: ignore[override]
        """Wrap this agent as a ``BaseTool`` whose input is the agent's input.

        The tool runs the agent with ``forbid_state_change=True`` and returns
        the first output payload.

        Raises:
            TypeError: if the agent's input type is not a pydantic ``BaseModel``.
        """
        # Will check if InT is a BaseModel at runtime
        agent_instance = self
        in_type = agent_instance.in_type
        out_type = agent_instance.out_type
        if not issubclass(in_type, BaseModel):
            raise TypeError(
                "Cannot create a tool from an agent with "
                f"non-BaseModel input type: {in_type}"
            )

        class AgentTool(BaseTool[in_type, out_type, Any]):
            name: str = tool_name
            description: str = tool_description
            strict: bool | None = tool_strict

            async def run(
                self,
                inp: InT,
                ctx: RunContextWrapper[CtxT] | None = None,
            ) -> OutT:
                in_args = in_type.model_validate(inp)
                in_message = AgentMessage[in_type, AgentState](
                    payloads=[in_args],
                    sender_id="<tool_user>",
                    recipient_ids=[agent_instance.agent_id],
                )
                agent_result = await agent_instance.run(
                    in_message=in_message,
                    entry_point=False,
                    forbid_state_change=True,
                    ctx=ctx,
                )

                return agent_result.payloads[0]

        return AgentTool()  # type: ignore[return-value]
@@ -1,79 +0,0 @@
1
- from copy import deepcopy
2
- from typing import Any, Literal, Optional, Protocol
3
-
4
- from pydantic import ConfigDict, Field
5
-
6
- from .memory import MessageHistory
7
- from .run_context import RunContextWrapper
8
- from .typing.io import AgentState, LLMPrompt
9
-
10
# Strategy names accepted by ``LLMAgentState.from_cur_and_in_states``.
SetAgentStateStrategy = Literal["keep", "reset", "from_sender", "custom"]


class SetAgentState(Protocol):
    """User hook that builds the next agent state under the "custom" strategy."""

    def __call__(
        self,
        cur_state: "LLMAgentState",
        *,
        in_state: AgentState | None,
        sys_prompt: LLMPrompt | None,
        ctx: RunContextWrapper[Any] | None,
    ) -> "LLMAgentState":
        """Return the state the agent should continue with."""
22
-
23
-
24
class LLMAgentState(AgentState):
    """Agent state that holds the LLM conversation as a ``MessageHistory``."""

    message_history: MessageHistory = Field(default_factory=MessageHistory)

    @property
    def batch_size(self) -> int:
        """Batch size of the underlying message history."""
        return self.message_history.batch_size

    @classmethod
    def from_cur_and_in_states(
        cls,
        cur_state: "LLMAgentState",
        *,
        in_state: Optional["AgentState"] = None,
        sys_prompt: LLMPrompt | None = None,
        strategy: SetAgentStateStrategy = "from_sender",
        set_agent_state_impl: SetAgentState | None = None,
        ctx: RunContextWrapper[Any] | None = None,
    ) -> "LLMAgentState":
        """Build the next state from the current state and the sender's state.

        Strategies:
            "keep": continue with the current history.
            "reset": reset the history to ``sys_prompt``. A non-empty current
                history is reset in place.
            "from_sender": continue from a deep copy of the sender's history
                when ``in_state`` is an ``LLMAgentState`` with a non-empty
                history; otherwise reset as above.
            "custom": delegate entirely to ``set_agent_state_impl``.

        In all non-custom strategies, an empty or missing current history is
        first replaced by a fresh ``MessageHistory(sys_prompt=sys_prompt)``.

        Raises:
            AssertionError: if ``strategy == "custom"`` and no
                ``set_agent_state_impl`` is given.
        """
        upd_mh = cur_state.message_history if cur_state else None
        if upd_mh is None or len(upd_mh) == 0:
            upd_mh = MessageHistory(sys_prompt=sys_prompt)

        if strategy == "keep":
            pass

        elif strategy == "reset":
            upd_mh.reset(sys_prompt)

        elif strategy == "from_sender":
            # isinstance needs the class object, not its name as a string
            # (a string raised TypeError here).
            in_mh = (
                in_state.message_history
                if isinstance(in_state, LLMAgentState)
                else None
            )
            if in_mh:
                # Copy so this agent's future messages don't mutate the sender's history.
                upd_mh = deepcopy(in_mh)
            else:
                upd_mh.reset(sys_prompt)

        elif strategy == "custom":
            assert set_agent_state_impl is not None, (
                "Agent state setter implementation is not provided."
            )
            return set_agent_state_impl(
                cur_state=cur_state,
                in_state=in_state,
                sys_prompt=sys_prompt,
                ctx=ctx,
            )

        return cls.model_construct(message_history=upd_mh)

    def __repr__(self) -> str:
        return f"Message History: {len(self.message_history)}"

    model_config = ConfigDict(arbitrary_types_allowed=True)
@@ -1,203 +0,0 @@
1
- import asyncio
2
- import json
3
- from collections.abc import Coroutine, Sequence
4
- from logging import getLogger
5
- from typing import Any, Generic, Protocol
6
-
7
- from pydantic import BaseModel
8
-
9
- from .llm import LLM, LLMSettings
10
- from .llm_agent_state import LLMAgentState
11
- from .run_context import CtxT, RunContextWrapper
12
- from .typing.converters import Converters
13
- from .typing.message import AssistantMessage, Conversation, Message, ToolMessage
14
- from .typing.tool import BaseTool, ToolCall, ToolChoice
15
-
16
logger = getLogger(__name__)


class ExitToolCallLoopHandler(Protocol[CtxT]):
    """Custom stop condition for ``ToolOrchestrator.run_loop``."""

    def __call__(
        self,
        conversation: Conversation,
        *,
        ctx: RunContextWrapper[CtxT] | None,
        **kwargs: Any,
    ) -> bool:
        """Return ``True`` to end the loop after inspecting ``conversation``.

        ``run_loop`` passes the current turn index as the ``num_turns`` keyword.
        """


class ManageAgentStateHandler(Protocol[CtxT]):
    """Hook that may adjust the agent state before each loop iteration."""

    def __call__(
        self,
        state: LLMAgentState,
        *,
        ctx: RunContextWrapper[CtxT] | None,
        **kwargs: Any,
    ) -> None:
        """Inspect or modify ``state``; ``num_turns`` is passed as a keyword."""
37
-
38
-
39
class ToolOrchestrator(Generic[CtxT]):
    """Drives an LLM through alternating generation and tool-call turns.

    ``run_loop`` generates, executes any requested tool calls, and generates
    again. It stops when the exit predicate fires or ``max_turns`` is reached.
    In ReAct mode, the first generation and every generation right after tool
    calls use ``tool_choice="none"``.
    """

    def __init__(
        self,
        agent_id: str,
        llm: LLM[LLMSettings, Converters],
        tools: list[BaseTool[BaseModel, Any, CtxT]] | None,
        max_turns: int,
        react_mode: bool = False,
    ) -> None:
        self._agent_id = agent_id
        self._max_turns = max_turns
        self._react_mode = react_mode

        # Tools are stored on the LLM itself; the ``tools`` property reads them back.
        self._llm = llm
        self._llm.tools = tools

        # Optional user hooks consulted by _exit_tool_call_loop / _manage_agent_state.
        self.exit_tool_call_loop_impl: ExitToolCallLoopHandler[CtxT] | None = None
        self.manage_agent_state_impl: ManageAgentStateHandler[CtxT] | None = None

    @property
    def agent_id(self) -> str:
        """ID used when printing messages and tagging tool messages."""
        return self._agent_id

    @property
    def llm(self) -> LLM[LLMSettings, Converters]:
        """The wrapped LLM."""
        return self._llm

    @property
    def tools(self) -> dict[str, BaseTool[BaseModel, Any, CtxT]]:
        """Tools registered on the LLM, keyed by name (``{}`` when none)."""
        return self._llm.tools or {}

    @property
    def max_turns(self) -> int:
        """Maximum number of tool-call turns ``run_loop`` performs."""
        return self._max_turns

    def _exit_tool_call_loop(
        self,
        conversation: Conversation,
        *,
        ctx: RunContextWrapper[CtxT] | None = None,
        **kwargs: Any,
    ) -> bool:
        """Decide whether ``run_loop`` should stop now.

        Uses ``exit_tool_call_loop_impl`` when set. Otherwise the loop stops
        once the last assistant message requests no tool calls.
        """
        custom_exit = self.exit_tool_call_loop_impl
        if custom_exit:
            return custom_exit(conversation=conversation, ctx=ctx, **kwargs)

        assert conversation, "Conversation must not be empty"
        last_message = conversation[-1]
        assert isinstance(last_message, AssistantMessage), (
            "Last message in conversation must be an AssistantMessage"
        )
        return not last_message.tool_calls

    def _manage_agent_state(
        self,
        state: LLMAgentState,
        *,
        ctx: RunContextWrapper[CtxT] | None = None,
        **kwargs: Any,
    ) -> None:
        """Invoke ``manage_agent_state_impl`` if one is set."""
        hook = self.manage_agent_state_impl
        if hook:
            hook(state=state, ctx=ctx, **kwargs)

    async def generate_once(
        self,
        state: LLMAgentState,
        tool_choice: ToolChoice | None = None,
        ctx: RunContextWrapper[CtxT] | None = None,
    ) -> Sequence[AssistantMessage]:
        """Generate one message batch, append it to the history, and return it."""
        history = state.message_history
        generated = await self.llm.generate_message_batch(
            history, tool_choice=tool_choice
        )
        history.add_message_batch(generated)

        self._print_messages_and_track_usage(generated, ctx=ctx)

        return generated

    async def run_loop(
        self,
        state: LLMAgentState,
        ctx: RunContextWrapper[CtxT] | None = None,
    ) -> None:
        """Alternate generation and tool execution until an exit condition holds.

        Requires a history with batch size 1.
        """
        history = state.message_history
        assert history.batch_size == 1, "Batch size must be 1 for tool call loop"

        choice: ToolChoice = "none" if self._react_mode else "auto"
        latest_batch = await self.generate_once(state, tool_choice=choice, ctx=ctx)

        turns = 0
        while True:
            self._manage_agent_state(state=state, ctx=ctx, num_turns=turns)

            conversation = history.batched_conversations[0]
            if self._exit_tool_call_loop(conversation, ctx=ctx, num_turns=turns):
                return

            if turns >= self.max_turns:
                logger.info(
                    f"Max turns reached: {self.max_turns}. Stopping tool call loop."
                )
                return

            reply = latest_batch[0]
            if reply.tool_calls:
                tool_messages = await self.call_tools(reply.tool_calls, ctx=ctx)
                history.add_messages(tool_messages)

            # ReAct: after acting on tools, force a tool-free "reasoning" turn.
            choice = "none" if (self._react_mode and reply.tool_calls) else "auto"
            latest_batch = await self.generate_once(state, tool_choice=choice, ctx=ctx)

            turns += 1

    async def call_tools(
        self,
        calls: Sequence[ToolCall],
        ctx: RunContextWrapper[CtxT] | None = None,
    ) -> Sequence[ToolMessage]:
        """Run all ``calls`` concurrently and wrap their outputs as tool messages.

        Arguments are decoded from each call's JSON ``tool_arguments``.
        """
        pending: list[Coroutine[Any, Any, BaseModel]] = [
            self.tools[call.tool_name](ctx=ctx, **json.loads(call.tool_arguments))
            for call in calls
        ]

        results = await asyncio.gather(*pending)

        tool_messages = [
            ToolMessage.from_tool_output(result, call, model_id=self.agent_id)
            for result, call in zip(results, calls, strict=False)
        ]

        self._print_messages(tool_messages, ctx=ctx)

        return tool_messages

    def _print_messages(
        self,
        message_batch: Sequence[Message],
        ctx: RunContextWrapper[CtxT] | None = None,
    ) -> None:
        """Print messages through the context's printer when a context exists."""
        if not ctx:
            return
        ctx.printer.print_llm_messages(message_batch, agent_id=self.agent_id)

    def _print_messages_and_track_usage(
        self,
        message_batch: Sequence[AssistantMessage],
        ctx: RunContextWrapper[CtxT] | None = None,
    ) -> None:
        """Print generated messages and record their usage for ``llm.model_name``."""
        if not ctx:
            return
        self._print_messages(message_batch, ctx=ctx)
        ctx.usage_tracker.update(
            messages=message_batch, model_name=self.llm.model_name
        )