grasp_agents 0.2.11__py3-none-any.whl → 0.3.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- grasp_agents/__init__.py +15 -14
- grasp_agents/cloud_llm.py +118 -131
- grasp_agents/comm_processor.py +201 -0
- grasp_agents/generics_utils.py +15 -7
- grasp_agents/llm.py +60 -31
- grasp_agents/llm_agent.py +229 -273
- grasp_agents/llm_agent_memory.py +58 -0
- grasp_agents/llm_policy_executor.py +482 -0
- grasp_agents/memory.py +20 -134
- grasp_agents/message_history.py +140 -0
- grasp_agents/openai/__init__.py +54 -36
- grasp_agents/openai/completion_chunk_converters.py +78 -0
- grasp_agents/openai/completion_converters.py +53 -30
- grasp_agents/openai/content_converters.py +13 -14
- grasp_agents/openai/converters.py +44 -68
- grasp_agents/openai/message_converters.py +58 -72
- grasp_agents/openai/openai_llm.py +101 -42
- grasp_agents/openai/tool_converters.py +24 -19
- grasp_agents/packet.py +24 -0
- grasp_agents/packet_pool.py +91 -0
- grasp_agents/printer.py +29 -15
- grasp_agents/processor.py +193 -0
- grasp_agents/prompt_builder.py +175 -192
- grasp_agents/run_context.py +20 -37
- grasp_agents/typing/completion.py +58 -12
- grasp_agents/typing/completion_chunk.py +173 -0
- grasp_agents/typing/converters.py +8 -12
- grasp_agents/typing/events.py +86 -0
- grasp_agents/typing/io.py +4 -13
- grasp_agents/typing/message.py +12 -50
- grasp_agents/typing/tool.py +52 -26
- grasp_agents/usage_tracker.py +6 -6
- grasp_agents/utils.py +3 -3
- grasp_agents/workflow/looped_workflow.py +132 -0
- grasp_agents/workflow/parallel_processor.py +95 -0
- grasp_agents/workflow/sequential_workflow.py +66 -0
- grasp_agents/workflow/workflow_processor.py +78 -0
- {grasp_agents-0.2.11.dist-info → grasp_agents-0.3.2.dist-info}/METADATA +41 -50
- grasp_agents-0.3.2.dist-info/RECORD +51 -0
- grasp_agents/agent_message.py +0 -27
- grasp_agents/agent_message_pool.py +0 -92
- grasp_agents/base_agent.py +0 -51
- grasp_agents/comm_agent.py +0 -217
- grasp_agents/llm_agent_state.py +0 -79
- grasp_agents/tool_orchestrator.py +0 -203
- grasp_agents/workflow/looped_agent.py +0 -134
- grasp_agents/workflow/sequential_agent.py +0 -72
- grasp_agents/workflow/workflow_agent.py +0 -88
- grasp_agents-0.2.11.dist-info/RECORD +0 -46
- {grasp_agents-0.2.11.dist-info → grasp_agents-0.3.2.dist-info}/WHEEL +0 -0
- {grasp_agents-0.2.11.dist-info → grasp_agents-0.3.2.dist-info}/licenses/LICENSE.md +0 -0
@@ -1,92 +0,0 @@
|
|
1
|
-
import asyncio
|
2
|
-
import logging
|
3
|
-
from typing import Any, Generic, Protocol, TypeVar
|
4
|
-
|
5
|
-
from .agent_message import AgentMessage
|
6
|
-
from .run_context import CtxT, RunContextWrapper
|
7
|
-
from .typing.io import AgentID, AgentState
|
8
|
-
|
9
|
-
logger = logging.getLogger(__name__)
|
10
|
-
|
11
|
-
|
12
|
-
# Type parameters of the MessageHandler protocol. Both are contravariant; the
# state parameter is additionally bounded by AgentState.
_MH_PayloadT = TypeVar(  # noqa: PLC0105
    "_MH_PayloadT",
    contravariant=True,
)
_MH_StateT = TypeVar(  # noqa: PLC0105
    "_MH_StateT",
    bound=AgentState,
    contravariant=True,
)
|
14
|
-
|
15
|
-
|
16
|
-
class MessageHandler(Protocol[_MH_PayloadT, _MH_StateT, CtxT]):
    """Structural type of an async callback that consumes one agent message.

    The callback is awaited with the message, the run context (which may be
    ``None``) and any extra keyword arguments. It returns nothing.
    """

    async def __call__(
        self,
        message: AgentMessage[_MH_PayloadT, _MH_StateT],
        ctx: RunContextWrapper[CtxT] | None,
        **kwargs: Any,
    ) -> None:
        """Handle a single ``message``."""
        ...
|
23
|
-
|
24
|
-
|
25
|
-
class AgentMessagePool(Generic[CtxT]):
    """Per-agent asyncio mailboxes, each drained by one background task.

    A posted message is put on the queue of every recipient. An agent that
    registers a handler gets a task that takes messages from its queue one at
    a time and awaits the handler for each of them.
    """

    def __init__(self) -> None:
        # Pending messages, keyed by recipient.
        self._queues: dict[AgentID, asyncio.Queue[AgentMessage[Any, AgentState]]] = {}
        # Callback awaited for every message delivered to an agent.
        self._message_handlers: dict[
            AgentID, MessageHandler[Any, AgentState, CtxT]
        ] = {}
        # Background consumer task of each registered agent.
        self._tasks: dict[AgentID, asyncio.Task[None]] = {}

    async def post(self, message: AgentMessage[Any, AgentState]) -> None:
        """Put ``message`` on the queue of each of its recipients.

        A queue is created on demand, so messages may be posted before the
        recipient has registered a handler.
        """
        for recipient_id in message.recipient_ids:
            queue = self._queues.setdefault(recipient_id, asyncio.Queue())
            await queue.put(message)

    def register_message_handler(
        self,
        agent_id: AgentID,
        handler: MessageHandler[Any, AgentState, CtxT],
        ctx: RunContextWrapper[CtxT] | None = None,
        **run_kwargs: Any,
    ) -> None:
        """Install ``handler`` for ``agent_id`` and start its consumer task.

        Args:
            agent_id: Agent whose queue the handler consumes.
            handler: Callback awaited for every message.
            ctx: Run context forwarded to every handler call.
            **run_kwargs: Extra keyword arguments forwarded to every call.

        If a task already exists for the agent only the handler is replaced;
        ``ctx`` and ``run_kwargs`` of the running task stay as they were.
        """
        self._message_handlers[agent_id] = handler
        self._queues.setdefault(agent_id, asyncio.Queue())
        if agent_id not in self._tasks:
            self._tasks[agent_id] = asyncio.create_task(
                self._process_messages(agent_id, ctx=ctx, **run_kwargs)
            )

    async def _process_messages(
        self,
        agent_id: AgentID,
        ctx: RunContextWrapper[CtxT] | None = None,
        **run_kwargs: Any,
    ) -> None:
        """Consume the queue of ``agent_id`` until it is unregistered.

        Exceptions raised by the handler are logged and do not stop the loop.
        """
        queue = self._queues[agent_id]
        # The queue is removed from the pool when the agent is unregistered.
        # A handler may unregister its own agent (for example via stop_all());
        # that task is not cancelled, so it has to notice the removal itself
        # instead of blocking forever on a queue nobody can post to.
        while self._queues.get(agent_id) is queue:
            try:
                message = await queue.get()
                handler = self._message_handlers.get(agent_id)
                if handler is None:
                    break

                try:
                    await handler(message, ctx=ctx, **run_kwargs)
                except Exception:
                    logger.exception(f"Error handling message for {agent_id}")

                queue.task_done()

            except Exception:
                logger.exception(f"Unexpected error in processing loop for {agent_id}")

    async def unregister_message_handler(self, agent_id: AgentID) -> None:
        """Stop the consumer task of ``agent_id`` and drop its registration.

        Messages still waiting in the agent's queue are discarded together
        with the queue.
        """
        task = self._tasks.get(agent_id)
        # When called from inside the agent's own handler, ``task`` is the
        # task executing this code. Cancelling and awaiting it here would
        # deliver the cancellation to this very ``await`` and swallow it, so
        # the task is left alone and ends through its loop condition.
        if task is not None and task is not asyncio.current_task():
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                logger.debug(f"{agent_id} exited")

        self._tasks.pop(agent_id, None)
        self._queues.pop(agent_id, None)
        self._message_handlers.pop(agent_id, None)

    async def stop_all(self) -> None:
        """Unregister every agent that currently has a consumer task."""
        # Iterate over a snapshot: unregistering mutates ``self._tasks``.
        for agent_id in list(self._tasks):
            await self.unregister_message_handler(agent_id)
|
grasp_agents/base_agent.py
DELETED
@@ -1,51 +0,0 @@
|
|
1
|
-
from abc import ABC, abstractmethod
|
2
|
-
from typing import Any, ClassVar, Generic
|
3
|
-
|
4
|
-
from pydantic import TypeAdapter
|
5
|
-
|
6
|
-
from .generics_utils import AutoInstanceAttributesMixin
|
7
|
-
from .run_context import CtxT, RunContextWrapper
|
8
|
-
from .typing.io import AgentID, OutT, StateT
|
9
|
-
from .typing.tool import BaseTool
|
10
|
-
|
11
|
-
|
12
|
-
class BaseAgent(AutoInstanceAttributesMixin, ABC, Generic[OutT, StateT, CtxT]):
    """Abstract base of all agents.

    Holds the agent ID, the output type and a pydantic ``TypeAdapter`` for
    that type. Subclasses implement ``run`` and ``as_tool``.
    """

    # Generic argument 0 (OutT) corresponds to the ``_out_type`` attribute.
    # NOTE(review): presumably AutoInstanceAttributesMixin uses this mapping to
    # set the attribute from the generic arguments - confirm in generics_utils.
    _generic_arg_to_instance_attr_map: ClassVar[dict[int, str]] = {0: "_out_type"}

    @abstractmethod
    def __init__(self, agent_id: AgentID, **kwargs: Any) -> None:
        """Store ``agent_id`` and build the output type adapter."""
        # Declarations only: neither attribute is assigned in this class.
        self._out_type: type[OutT]
        self._state: StateT

        super().__init__()

        self._agent_id = agent_id
        self._out_type_adapter: TypeAdapter[OutT] = TypeAdapter(self._out_type)

    @property
    def out_type(self) -> type[OutT]:
        """Type of the values produced by this agent."""
        return self._out_type

    @property
    def agent_id(self) -> AgentID:
        """Identifier given to the agent at construction."""
        return self._agent_id

    @property
    def state(self) -> StateT:
        """Current agent state."""
        return self._state

    @abstractmethod
    async def run(
        self,
        chat_inputs: Any | None = None,
        *,
        ctx: RunContextWrapper[CtxT] | None = None,
        **kwargs: Any,
    ) -> Any:
        """Run the agent; the result type is defined by the subclass."""

    @abstractmethod
    def as_tool(
        self, tool_name: str, tool_description: str, tool_strict: bool = True
    ) -> BaseTool[Any, OutT, CtxT]:
        """Return a tool, with the given name and description, for this agent."""
|
grasp_agents/comm_agent.py
DELETED
@@ -1,217 +0,0 @@
|
|
1
|
-
import logging
|
2
|
-
from abc import abstractmethod
|
3
|
-
from collections.abc import Sequence
|
4
|
-
from typing import Any, ClassVar, Generic, Protocol, TypeVar, cast, final
|
5
|
-
|
6
|
-
from pydantic import BaseModel, TypeAdapter
|
7
|
-
from pydantic.json_schema import SkipJsonSchema
|
8
|
-
|
9
|
-
from .agent_message import AgentMessage
|
10
|
-
from .agent_message_pool import AgentMessagePool
|
11
|
-
from .base_agent import BaseAgent
|
12
|
-
from .run_context import CtxT, RunContextWrapper
|
13
|
-
from .typing.io import AgentID, AgentState, InT, OutT, StateT
|
14
|
-
from .typing.tool import BaseTool
|
15
|
-
|
16
|
-
logger = logging.getLogger(__name__)
|
17
|
-
|
18
|
-
|
19
|
-
# Base model for payloads that choose their own recipients (dynamic routing).
# Documented with comments rather than a docstring because pydantic copies a
# model's docstring into the description of its JSON schema.
class DynCommPayload(BaseModel):
    # IDs of the recipients selected for this payload. SkipJsonSchema keeps the
    # field out of the generated JSON schema.
    selected_recipient_ids: SkipJsonSchema[Sequence[AgentID]]
|
21
|
-
|
22
|
-
|
23
|
-
# Type parameters of the ExitHandler protocol. Both are contravariant; the
# state parameter is additionally bounded by AgentState.
_EH_OutT = TypeVar(  # noqa: PLC0105
    "_EH_OutT",
    contravariant=True,
)
_EH_StateT = TypeVar(  # noqa: PLC0105
    "_EH_StateT",
    bound=AgentState,
    contravariant=True,
)
|
25
|
-
|
26
|
-
|
27
|
-
class ExitHandler(Protocol[_EH_OutT, _EH_StateT, CtxT]):
    """Structural type of a synchronous predicate over an output message.

    It is called with the output message and the run context (which may be
    ``None``) and returns a boolean.
    """

    def __call__(
        self,
        output_message: AgentMessage[_EH_OutT, _EH_StateT],
        ctx: RunContextWrapper[CtxT] | None,
    ) -> bool:
        """Return the exit decision for ``output_message``."""
        ...
|
33
|
-
|
34
|
-
|
35
|
-
class CommunicatingAgent(
    BaseAgent[OutT, StateT, CtxT], Generic[InT, OutT, StateT, CtxT]
):
    """Agent that exchanges ``AgentMessage`` objects through a message pool.

    While listening, the agent runs on every message delivered to it and posts
    its output to the pool, unless the exit condition holds, in which case
    every listener of the pool is stopped.
    """

    # Generic arguments 0 (InT) and 1 (OutT) correspond to these attributes.
    _generic_arg_to_instance_attr_map: ClassVar[dict[int, str]] = {
        0: "_in_type",
        1: "_out_type",
    }

    def __init__(
        self,
        agent_id: AgentID,
        *,
        recipient_ids: Sequence[AgentID] | None = None,
        message_pool: AgentMessagePool[CtxT] | None = None,
        **kwargs: Any,
    ) -> None:
        """Create the agent.

        Args:
            agent_id: Identifier of this agent.
            recipient_ids: Agents this agent may send to; defaults to none.
            message_pool: Pool used for messaging; a new private pool is
                created when omitted.
            **kwargs: Forwarded to ``BaseAgent.__init__``.
        """
        # Declaration only: no value is assigned here.
        self._in_type: type[InT]
        super().__init__(agent_id=agent_id, **kwargs)

        self._in_args_type_adapter: TypeAdapter[InT] = TypeAdapter(self._in_type)
        self.recipient_ids = recipient_ids or []

        self._message_pool = message_pool or AgentMessagePool()
        self._is_listening = False
        self._exit_impl: ExitHandler[OutT, StateT, CtxT] | None = None

    @property
    def in_type(self) -> type[InT]:  # type: ignore
        """Type of the payloads this agent accepts."""
        # Exposing the type of a contravariant variable only, should be safe
        return self._in_type

    def _validate_routing(self, payloads: Sequence[OutT]) -> Sequence[AgentID]:
        """Return the recipient IDs implied by ``payloads``.

        Payloads that are all ``DynCommPayload`` must select the same set of
        recipients, each of which must be one of ``self.recipient_ids``; the
        selection of the first payload is returned. If no payload is a
        ``DynCommPayload`` (including when there are no payloads at all) the
        agent's own ``recipient_ids`` are returned.

        Raises:
            ValueError: If dynamic and non-dynamic payloads are mixed.
            AssertionError: If the dynamic selections differ between payloads
                or name a recipient that is not allowed.
        """
        dynamic_count = sum(isinstance(p, DynCommPayload) for p in payloads)

        # Also covers an empty ``payloads``: there is no first payload whose
        # selection could be read, so static routing applies.
        if dynamic_count == 0:
            return self.recipient_ids

        if dynamic_count != len(payloads):
            raise ValueError(
                "All payloads must be either DynCommPayload or not DynCommPayload"
            )

        payloads_ = cast("Sequence[DynCommPayload]", payloads)
        selected_recipient_ids_per_payload = [
            set(p.selected_recipient_ids or []) for p in payloads_
        ]
        # NOTE: these checks are assertions and are skipped under ``python -O``.
        assert all(
            x == selected_recipient_ids_per_payload[0]
            for x in selected_recipient_ids_per_payload
        ), "All payloads must have the same recipient IDs for dynamic routing"

        assert payloads_[0].selected_recipient_ids is not None
        selected_recipient_ids = payloads_[0].selected_recipient_ids

        assert all(rid in self.recipient_ids for rid in selected_recipient_ids), (
            "Dynamic routing is enabled, but recipient IDs are not in "
            "the allowed agent's recipient IDs"
        )

        return selected_recipient_ids

    async def post_message(self, message: AgentMessage[OutT, StateT]) -> None:
        """Validate the routing of ``message`` and post it to the pool."""
        # Called for its checks only; the returned IDs are not used here.
        self._validate_routing(message.payloads)

        await self._message_pool.post(message)

    @abstractmethod
    async def run(
        self,
        chat_inputs: Any | None = None,
        *,
        ctx: RunContextWrapper[CtxT] | None = None,
        in_message: AgentMessage[InT, AgentState] | None = None,
        entry_point: bool = False,
        forbid_state_change: bool = False,
        **kwargs: Any,
    ) -> AgentMessage[OutT, StateT]:
        """Run the agent and return its output message (subclass-defined)."""

    async def run_and_post(
        self, ctx: RunContextWrapper[CtxT] | None = None, **run_kwargs: Any
    ) -> None:
        """Run as an entry point, without an input message, and post the output."""
        output_message = await self.run(
            ctx=ctx, in_message=None, entry_point=True, **run_kwargs
        )
        await self.post_message(output_message)

    def exit_handler(
        self, func: ExitHandler[OutT, StateT, CtxT]
    ) -> ExitHandler[OutT, StateT, CtxT]:
        """Register ``func`` as the exit predicate; usable as a decorator."""
        self._exit_impl = func

        return func

    def _exit_condition(
        self,
        output_message: AgentMessage[OutT, StateT],
        ctx: RunContextWrapper[CtxT] | None,
    ) -> bool:
        """Return the registered predicate's result, or ``False`` without one."""
        if self._exit_impl:
            return self._exit_impl(output_message=output_message, ctx=ctx)

        return False

    async def _message_handler(
        self,
        message: AgentMessage[Any, AgentState],
        ctx: RunContextWrapper[CtxT] | None = None,
        **run_kwargs: Any,
    ) -> None:
        """Pool callback: run on ``message``, then stop the pool or forward."""
        in_message = cast("AgentMessage[InT, AgentState]", message)
        out_message = await self.run(ctx=ctx, in_message=in_message, **run_kwargs)

        if self._exit_condition(output_message=out_message, ctx=ctx):
            await self._message_pool.stop_all()
            return

        # Agents without recipients produce output but do not post it.
        if self.recipient_ids:
            await self.post_message(out_message)

    @property
    def is_listening(self) -> bool:
        """Whether ``start_listening`` has been called without a later stop."""
        return self._is_listening

    async def start_listening(
        self, ctx: RunContextWrapper[CtxT] | None = None, **run_kwargs: Any
    ) -> None:
        """Register this agent's handler with the pool; no-op if listening."""
        if self._is_listening:
            return

        self._is_listening = True
        self._message_pool.register_message_handler(
            agent_id=self.agent_id,
            handler=self._message_handler,
            ctx=ctx,
            **run_kwargs,
        )

    async def stop_listening(self) -> None:
        """Unregister this agent's handler from the pool."""
        self._is_listening = False
        await self._message_pool.unregister_message_handler(self.agent_id)

    @final
    def as_tool(
        self,
        tool_name: str,
        tool_description: str,
        tool_strict: bool = True,
    ) -> BaseTool[InT, OutT, Any]:  # type: ignore[override]
        """Wrap this agent in a tool that runs it on a single input.

        Raises:
            TypeError: If the agent's input type is not a ``BaseModel`` subclass.
        """
        # Will check if InT is a BaseModel at runtime
        agent_instance = self
        in_type = agent_instance.in_type
        out_type = agent_instance.out_type
        if not issubclass(in_type, BaseModel):
            raise TypeError(
                "Cannot create a tool from an agent with "
                f"non-BaseModel input type: {in_type}"
            )

        class AgentTool(BaseTool[in_type, out_type, Any]):
            name: str = tool_name
            description: str = tool_description
            strict: bool | None = tool_strict

            async def run(
                self,
                inp: InT,
                ctx: RunContextWrapper[CtxT] | None = None,
            ) -> OutT:
                # Validate the input, wrap it in a message addressed to the
                # agent, run the agent and return its first output payload.
                in_args = in_type.model_validate(inp)
                in_message = AgentMessage[in_type, AgentState](
                    payloads=[in_args],
                    sender_id="<tool_user>",
                    recipient_ids=[agent_instance.agent_id],
                )
                agent_result = await agent_instance.run(
                    in_message=in_message,
                    entry_point=False,
                    forbid_state_change=True,
                    ctx=ctx,
                )

                return agent_result.payloads[0]

        return AgentTool()  # type: ignore[return-value]
|
grasp_agents/llm_agent_state.py
DELETED
@@ -1,79 +0,0 @@
|
|
1
|
-
from copy import deepcopy
|
2
|
-
from typing import Any, Literal, Optional, Protocol
|
3
|
-
|
4
|
-
from pydantic import ConfigDict, Field
|
5
|
-
|
6
|
-
from .memory import MessageHistory
|
7
|
-
from .run_context import RunContextWrapper
|
8
|
-
from .typing.io import AgentState, LLMPrompt
|
9
|
-
|
10
|
-
# Names of the strategies accepted by LLMAgentState.from_cur_and_in_states.
SetAgentStateStrategy = Literal[
    "keep",
    "reset",
    "from_sender",
    "custom",
]
|
11
|
-
|
12
|
-
|
13
|
-
class SetAgentState(Protocol):
    """Structural type of a custom state setter.

    It is called with the current state plus, by keyword, the incoming state,
    the system prompt and the run context, and returns an ``LLMAgentState``.
    """

    def __call__(
        self,
        cur_state: "LLMAgentState",
        *,
        in_state: AgentState | None,
        sys_prompt: LLMPrompt | None,
        ctx: RunContextWrapper[Any] | None,
    ) -> "LLMAgentState":
        """Return the state to use from ``cur_state`` and ``in_state``."""
        ...
|
22
|
-
|
23
|
-
|
24
|
-
# Agent state that consists of the LLM message history.
# Documented with comments rather than a class docstring because pydantic
# copies a model's docstring into the description of its JSON schema.
class LLMAgentState(AgentState):
    # Conversation history; a new empty MessageHistory by default.
    message_history: MessageHistory = Field(default_factory=MessageHistory)

    @property
    def batch_size(self) -> int:
        """Batch size of the underlying message history."""
        return self.message_history.batch_size

    @classmethod
    def from_cur_and_in_states(
        cls,
        cur_state: "LLMAgentState",
        *,
        in_state: Optional["AgentState"] = None,
        sys_prompt: LLMPrompt | None = None,
        strategy: SetAgentStateStrategy = "from_sender",
        set_agent_state_impl: SetAgentState | None = None,
        ctx: RunContextWrapper[Any] | None = None,
    ) -> "LLMAgentState":
        """Build the next state from the current and the incoming state.

        Args:
            cur_state: State the agent currently has.
            in_state: State received with the incoming message, if any.
            sys_prompt: System prompt used when a history is created or reset.
            strategy: ``"keep"`` keeps the current history, ``"reset"`` resets
                it, ``"from_sender"`` adopts a deep copy of the sender's
                history when the sender has a non-empty one and resets
                otherwise, ``"custom"`` delegates to ``set_agent_state_impl``.
            set_agent_state_impl: Setter used by the ``"custom"`` strategy.
            ctx: Run context forwarded to the custom setter.

        Raises:
            AssertionError: If ``strategy`` is ``"custom"`` and no setter is
                given.
        """
        upd_mh = cur_state.message_history if cur_state else None
        if upd_mh is None or len(upd_mh) == 0:
            upd_mh = MessageHistory(sys_prompt=sys_prompt)

        if strategy == "keep":
            pass

        elif strategy == "reset":
            upd_mh.reset(sys_prompt)

        elif strategy == "from_sender":
            # isinstance() needs the class itself: passing the name as a
            # string raises TypeError.
            in_mh = (
                in_state.message_history
                if isinstance(in_state, LLMAgentState)
                else None
            )
            if in_mh:
                # Copy so that later changes do not leak back to the sender.
                upd_mh = deepcopy(in_mh)
            else:
                upd_mh.reset(sys_prompt)

        elif strategy == "custom":
            assert set_agent_state_impl is not None, (
                "Agent state setter implementation is not provided."
            )
            return set_agent_state_impl(
                cur_state=cur_state,
                in_state=in_state,
                sys_prompt=sys_prompt,
                ctx=ctx,
            )

        return cls.model_construct(message_history=upd_mh)

    def __repr__(self) -> str:
        return f"Message History: {len(self.message_history)}"

    model_config = ConfigDict(arbitrary_types_allowed=True)
|
@@ -1,203 +0,0 @@
|
|
1
|
-
import asyncio
|
2
|
-
import json
|
3
|
-
from collections.abc import Coroutine, Sequence
|
4
|
-
from logging import getLogger
|
5
|
-
from typing import Any, Generic, Protocol
|
6
|
-
|
7
|
-
from pydantic import BaseModel
|
8
|
-
|
9
|
-
from .llm import LLM, LLMSettings
|
10
|
-
from .llm_agent_state import LLMAgentState
|
11
|
-
from .run_context import CtxT, RunContextWrapper
|
12
|
-
from .typing.converters import Converters
|
13
|
-
from .typing.message import AssistantMessage, Conversation, Message, ToolMessage
|
14
|
-
from .typing.tool import BaseTool, ToolCall, ToolChoice
|
15
|
-
|
16
|
-
logger = getLogger(__name__)
|
17
|
-
|
18
|
-
|
19
|
-
class ExitToolCallLoopHandler(Protocol[CtxT]):
    """Structural type of a predicate that decides whether the tool loop ends.

    It is called with the conversation plus, by keyword, the run context and
    any extra arguments, and returns a boolean.
    """

    def __call__(
        self,
        conversation: Conversation,
        *,
        ctx: RunContextWrapper[CtxT] | None,
        **kwargs: Any,
    ) -> bool:
        """Return ``True`` to leave the tool call loop."""
        ...
|
27
|
-
|
28
|
-
|
29
|
-
class ManageAgentStateHandler(Protocol[CtxT]):
    """Structural type of a hook that is called with the agent state.

    It is called with the state plus, by keyword, the run context and any
    extra arguments. It returns nothing.
    """

    def __call__(
        self,
        state: LLMAgentState,
        *,
        ctx: RunContextWrapper[CtxT] | None,
        **kwargs: Any,
    ) -> None:
        """Inspect or modify ``state``."""
        ...
|
37
|
-
|
38
|
-
|
39
|
-
class ToolOrchestrator(Generic[CtxT]):
    """Drives the generate / call-tools loop of one agent.

    Each turn asks the LLM for a message, executes the tool calls it contains
    and appends the results to the message history, until the exit condition
    holds or ``max_turns`` is reached.
    """

    def __init__(
        self,
        agent_id: str,
        llm: LLM[LLMSettings, Converters],
        tools: list[BaseTool[BaseModel, Any, CtxT]] | None,
        max_turns: int,
        react_mode: bool = False,
    ) -> None:
        """Store the configuration and assign ``tools`` to ``llm.tools``."""
        self._agent_id = agent_id

        self._llm = llm
        # The tools are kept on the LLM object; see the ``tools`` property.
        self._llm.tools = tools

        self._max_turns = max_turns
        self._react_mode = react_mode

        # Optional hooks that replace or extend the default loop behaviour.
        self.exit_tool_call_loop_impl: ExitToolCallLoopHandler[CtxT] | None = None
        self.manage_agent_state_impl: ManageAgentStateHandler[CtxT] | None = None

    @property
    def agent_id(self) -> str:
        """Identifier of the agent this orchestrator works for."""
        return self._agent_id

    @property
    def llm(self) -> LLM[LLMSettings, Converters]:
        """LLM used for generation."""
        return self._llm

    @property
    def tools(self) -> dict[str, BaseTool[BaseModel, Any, CtxT]]:
        """The LLM's tools, or an empty dict when it has none."""
        return self._llm.tools or {}

    @property
    def max_turns(self) -> int:
        """Upper bound on the number of loop turns."""
        return self._max_turns

    def _exit_tool_call_loop(
        self,
        conversation: Conversation,
        *,
        ctx: RunContextWrapper[CtxT] | None = None,
        **kwargs: Any,
    ) -> bool:
        """Decide whether the loop should stop.

        A registered ``exit_tool_call_loop_impl`` decides alone. Otherwise the
        loop stops when the last message, which must be an
        ``AssistantMessage``, carries no tool calls.
        """
        custom_exit = self.exit_tool_call_loop_impl
        if custom_exit:
            return custom_exit(conversation=conversation, ctx=ctx, **kwargs)

        assert conversation, "Conversation must not be empty"
        last_message = conversation[-1]
        assert isinstance(last_message, AssistantMessage), (
            "Last message in conversation must be an AssistantMessage"
        )

        return not last_message.tool_calls

    def _manage_agent_state(
        self,
        state: LLMAgentState,
        *,
        ctx: RunContextWrapper[CtxT] | None = None,
        **kwargs: Any,
    ) -> None:
        """Call the registered state hook, if any."""
        state_hook = self.manage_agent_state_impl
        if state_hook:
            state_hook(state=state, ctx=ctx, **kwargs)

    async def generate_once(
        self,
        state: LLMAgentState,
        tool_choice: ToolChoice | None = None,
        ctx: RunContextWrapper[CtxT] | None = None,
    ) -> Sequence[AssistantMessage]:
        """Generate one message batch, append it to the history, return it."""
        history = state.message_history
        generated = await self.llm.generate_message_batch(
            history, tool_choice=tool_choice
        )
        history.add_message_batch(generated)

        self._print_messages_and_track_usage(generated, ctx=ctx)

        return generated

    async def run_loop(
        self,
        state: LLMAgentState,
        ctx: RunContextWrapper[CtxT] | None = None,
    ) -> None:
        """Run the tool call loop on ``state``; results land in its history.

        Raises:
            AssertionError: If the history's batch size is not 1.
        """
        # The history object is read once; every later step uses this reference.
        history = state.message_history
        assert history.batch_size == 1, (
            "Batch size must be 1 for tool call loop"
        )

        # In react mode the first generation is made with tools disabled.
        tool_choice: ToolChoice = "none" if self._react_mode else "auto"
        generated = await self.generate_once(state, tool_choice=tool_choice, ctx=ctx)

        turns = 0

        while True:
            self._manage_agent_state(state=state, ctx=ctx, num_turns=turns)

            should_exit = self._exit_tool_call_loop(
                history.batched_conversations[0], ctx=ctx, num_turns=turns
            )
            if should_exit:
                return
            if turns >= self.max_turns:
                logger.info(
                    f"Max turns reached: {self.max_turns}. Stopping tool call loop."
                )
                return

            last_message = generated[0]
            if last_message.tool_calls:
                tool_results = await self.call_tools(last_message.tool_calls, ctx=ctx)
                history.add_messages(tool_results)

            # In react mode, tools are disabled for the generation that
            # follows a message with tool calls.
            if self._react_mode and last_message.tool_calls:
                tool_choice = "none"
            else:
                tool_choice = "auto"
            generated = await self.generate_once(
                state, tool_choice=tool_choice, ctx=ctx
            )

            turns += 1

    async def call_tools(
        self,
        calls: Sequence[ToolCall],
        ctx: RunContextWrapper[CtxT] | None = None,
    ) -> Sequence[ToolMessage]:
        """Run all ``calls`` concurrently and return one message per call.

        Tool arguments are decoded from JSON and passed as keyword arguments.
        """
        pending: list[Coroutine[Any, Any, BaseModel]] = [
            self.tools[call.tool_name](ctx=ctx, **json.loads(call.tool_arguments))
            for call in calls
        ]

        outputs = await asyncio.gather(*pending)

        tool_messages = [
            ToolMessage.from_tool_output(output, call, model_id=self.agent_id)
            for output, call in zip(outputs, calls, strict=False)
        ]

        self._print_messages(tool_messages, ctx=ctx)

        return tool_messages

    def _print_messages(
        self,
        message_batch: Sequence[Message],
        ctx: RunContextWrapper[CtxT] | None = None,
    ) -> None:
        """Print the messages through the context's printer, if there is a context."""
        if not ctx:
            return
        ctx.printer.print_llm_messages(message_batch, agent_id=self.agent_id)

    def _print_messages_and_track_usage(
        self,
        message_batch: Sequence[AssistantMessage],
        ctx: RunContextWrapper[CtxT] | None = None,
    ) -> None:
        """Print the messages and update the usage tracker, if there is a context."""
        if not ctx:
            return
        self._print_messages(message_batch, ctx=ctx)
        ctx.usage_tracker.update(
            messages=message_batch, model_name=self.llm.model_name
        )
|