grasp_agents 0.2.10__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- grasp_agents/__init__.py +15 -14
- grasp_agents/cloud_llm.py +118 -131
- grasp_agents/comm_processor.py +201 -0
- grasp_agents/generics_utils.py +15 -7
- grasp_agents/llm.py +60 -31
- grasp_agents/llm_agent.py +229 -278
- grasp_agents/llm_agent_memory.py +58 -0
- grasp_agents/llm_policy_executor.py +482 -0
- grasp_agents/memory.py +20 -134
- grasp_agents/message_history.py +140 -0
- grasp_agents/openai/__init__.py +54 -36
- grasp_agents/openai/completion_chunk_converters.py +78 -0
- grasp_agents/openai/completion_converters.py +53 -30
- grasp_agents/openai/content_converters.py +13 -14
- grasp_agents/openai/converters.py +44 -68
- grasp_agents/openai/message_converters.py +58 -72
- grasp_agents/openai/openai_llm.py +101 -42
- grasp_agents/openai/tool_converters.py +24 -19
- grasp_agents/packet.py +24 -0
- grasp_agents/packet_pool.py +91 -0
- grasp_agents/printer.py +29 -15
- grasp_agents/processor.py +194 -0
- grasp_agents/prompt_builder.py +173 -176
- grasp_agents/run_context.py +21 -41
- grasp_agents/typing/completion.py +58 -12
- grasp_agents/typing/completion_chunk.py +173 -0
- grasp_agents/typing/converters.py +8 -12
- grasp_agents/typing/events.py +86 -0
- grasp_agents/typing/io.py +4 -13
- grasp_agents/typing/message.py +12 -50
- grasp_agents/typing/tool.py +52 -26
- grasp_agents/usage_tracker.py +6 -6
- grasp_agents/utils.py +3 -3
- grasp_agents/workflow/looped_workflow.py +132 -0
- grasp_agents/workflow/parallel_processor.py +95 -0
- grasp_agents/workflow/sequential_workflow.py +66 -0
- grasp_agents/workflow/workflow_processor.py +78 -0
- {grasp_agents-0.2.10.dist-info → grasp_agents-0.3.1.dist-info}/METADATA +41 -50
- grasp_agents-0.3.1.dist-info/RECORD +51 -0
- grasp_agents/agent_message.py +0 -27
- grasp_agents/agent_message_pool.py +0 -92
- grasp_agents/base_agent.py +0 -51
- grasp_agents/comm_agent.py +0 -217
- grasp_agents/llm_agent_state.py +0 -79
- grasp_agents/tool_orchestrator.py +0 -203
- grasp_agents/workflow/looped_agent.py +0 -120
- grasp_agents/workflow/sequential_agent.py +0 -63
- grasp_agents/workflow/workflow_agent.py +0 -73
- grasp_agents-0.2.10.dist-info/RECORD +0 -46
- {grasp_agents-0.2.10.dist-info → grasp_agents-0.3.1.dist-info}/WHEEL +0 -0
- {grasp_agents-0.2.10.dist-info → grasp_agents-0.3.1.dist-info}/licenses/LICENSE.md +0 -0
grasp_agents/llm_agent_memory.py
ADDED
@@ -0,0 +1,58 @@
+from collections.abc import Sequence
+from typing import Any, Protocol
+
+from pydantic import Field
+
+from .memory import Memory
+from .message_history import MessageHistory
+from .run_context import RunContext
+from .typing.io import LLMPrompt
+from .typing.message import Message
+
+
+class SetMemoryHandler(Protocol):
+    def __call__(
+        self,
+        prev_memory: "LLMAgentMemory",
+        in_args: Any | None,
+        sys_prompt: LLMPrompt | None,
+        ctx: RunContext[Any] | None,
+    ) -> "LLMAgentMemory": ...
+
+
+class LLMAgentMemory(Memory):
+    message_history: MessageHistory = Field(default_factory=MessageHistory)
+
+    def reset(
+        self, sys_prompt: LLMPrompt | None = None, ctx: RunContext[Any] | None = None
+    ):
+        self.message_history.reset(sys_prompt=sys_prompt)
+
+    def update(
+        self,
+        *,
+        message_batch: Sequence[Message] | None = None,
+        message_list: Sequence[Message] | None = None,
+        ctx: RunContext[Any] | None = None,
+    ):
+        if message_batch is not None and message_list is not None:
+            raise ValueError(
+                "Only one of message_batch or messages should be provided."
+            )
+        if message_batch is not None:
+            self.message_history.add_message_batch(message_batch)
+        elif message_list is not None:
+            self.message_history.add_message_list(message_list)
+        else:
+            raise ValueError("Either message_batch or messages must be provided.")
+
+    @property
+    def is_empty(self) -> bool:
+        return len(self.message_history) == 0
+
+    @property
+    def batch_size(self) -> int:
+        return self.message_history.batch_size
+
+    def __repr__(self) -> str:
+        return f"Message History: {len(self.message_history)}"
grasp_agents/llm_policy_executor.py
ADDED
@@ -0,0 +1,482 @@
+import asyncio
+import json
+from collections.abc import AsyncIterator, Coroutine, Sequence
+from itertools import starmap
+from logging import getLogger
+from typing import Any, ClassVar, Generic, Protocol, TypeVar
+
+from pydantic import BaseModel
+
+from .generics_utils import AutoInstanceAttributesMixin
+from .llm import LLM, LLMSettings
+from .llm_agent_memory import LLMAgentMemory
+from .run_context import CtxT, RunContext
+from .typing.completion import Completion
+from .typing.converters import Converters
+from .typing.events import (
+    CompletionChunkEvent,
+    CompletionEvent,
+    Event,
+    GenMessageEvent,
+    ToolCallEvent,
+    ToolMessageEvent,
+    UserMessageEvent,
+)
+from .typing.message import AssistantMessage, Messages, ToolMessage, UserMessage
+from .typing.tool import BaseTool, NamedToolChoice, ToolCall, ToolChoice
+
+logger = getLogger(__name__)
+
+
+_FinalAnswerT = TypeVar("_FinalAnswerT")
+
+
+class ExitToolCallLoopHandler(Protocol[CtxT]):
+    def __call__(
+        self,
+        conversation: Messages,
+        *,
+        ctx: RunContext[CtxT] | None,
+        **kwargs: Any,
+    ) -> bool: ...
+
+
+class ManageMemoryHandler(Protocol[CtxT]):
+    def __call__(
+        self,
+        memory: LLMAgentMemory,
+        *,
+        ctx: RunContext[CtxT] | None,
+        **kwargs: Any,
+    ) -> None: ...
+
+
+class LLMPolicyExecutor(AutoInstanceAttributesMixin, Generic[_FinalAnswerT, CtxT]):
+    _generic_arg_to_instance_attr_map: ClassVar[dict[int, str]] = {
+        0: "_final_answer_type",
+    }
+
+    def __init__(
+        self,
+        agent_name: str,
+        llm: LLM[LLMSettings, Converters],
+        tools: list[BaseTool[BaseModel, Any, CtxT]] | None,
+        max_turns: int,
+        react_mode: bool = False,
+        final_answer_as_tool_call: bool = False,
+    ) -> None:
+        self._final_answer_type: type[_FinalAnswerT]
+        super().__init__()
+
+        self._agent_name = agent_name
+
+        _tools: list[BaseTool[BaseModel, Any, CtxT]] | None = tools
+        self._final_answer_tool_name: str | None = None
+        if tools and final_answer_as_tool_call:
+            final_answer_tool = self.get_final_answer_tool()
+            self._final_answer_tool_name = final_answer_tool.name
+            _tools = tools + [final_answer_tool]
+
+        self._llm = llm
+        self._llm.tools = _tools
+
+        self._max_turns = max_turns
+        self._react_mode = react_mode
+
+        self.exit_tool_call_loop_impl: ExitToolCallLoopHandler[CtxT] | None = None
+        self.manage_memory_impl: ManageMemoryHandler[CtxT] | None = None
+
+    @property
+    def agent_name(self) -> str:
+        return self._agent_name
+
+    @property
+    def llm(self) -> LLM[LLMSettings, Converters]:
+        return self._llm
+
+    @property
+    def tools(self) -> dict[str, BaseTool[BaseModel, Any, CtxT]]:
+        return self._llm.tools or {}
+
+    @property
+    def max_turns(self) -> int:
+        return self._max_turns
+
+    def _exit_tool_call_loop_fn(
+        self,
+        conversation: Messages,
+        *,
+        ctx: RunContext[CtxT] | None = None,
+        **kwargs: Any,
+    ) -> bool:
+        if self.exit_tool_call_loop_impl:
+            return self.exit_tool_call_loop_impl(conversation, ctx=ctx, **kwargs)
+
+        assert conversation, "Conversation must not be empty"
+        assert isinstance(conversation[-1], AssistantMessage), (
+            "Last message in conversation must be an AssistantMessage"
+        )
+
+        return not bool(conversation[-1].tool_calls)
+
+    def _manage_memory_fn(
+        self,
+        memory: LLMAgentMemory,
+        *,
+        ctx: RunContext[CtxT] | None = None,
+        **kwargs: Any,
+    ) -> None:
+        if self.manage_memory_impl:
+            self.manage_memory_impl(memory=memory, ctx=ctx, **kwargs)
+
+    async def generate_message_batch(
+        self,
+        memory: LLMAgentMemory,
+        tool_choice: ToolChoice | None = None,
+        ctx: RunContext[CtxT] | None = None,
+    ) -> Sequence[AssistantMessage]:
+        completion_batch = await self.llm.generate_completion_batch(
+            memory.message_history, tool_choice=tool_choice
+        )
+        if (
+            len(completion_batch[0].messages) > 1
+            and memory.message_history.batch_size > 1
+        ):
+            raise ValueError(
+                "Batch size must be 1 when generating completions with n>1."
+            )
+        message_batch = [c.messages[0] for c in completion_batch]
+        memory.update(message_batch=message_batch)
+
+        if ctx is not None:
+            ctx.completions[self.agent_name].extend(completion_batch)
+            self._track_usage(completion_batch, ctx=ctx)
+            self._print_completions(completion_batch, ctx=ctx)
+
+        return message_batch
+
+    async def generate_message_stream(
+        self,
+        memory: LLMAgentMemory,
+        tool_choice: ToolChoice | None = None,
+        ctx: RunContext[CtxT] | None = None,
+    ) -> AsyncIterator[CompletionChunkEvent | CompletionEvent | GenMessageEvent]:
+        message_hist = memory.message_history
+        if memory.message_history.batch_size > 1:
+            raise ValueError("Batch size must be 1 when streaming completions.")
+        conversation = message_hist.conversations[0]
+
+        completion: Completion | None = None
+        async for event in await self.llm.generate_completion_stream(
+            conversation, tool_choice=tool_choice
+        ):
+            yield event
+            if isinstance(event, CompletionEvent):
+                completion = event.data
+
+        if completion is None:
+            raise RuntimeError("No completion generated during stream.")
+        if len(completion.messages) > 1:
+            raise ValueError("Streaming completion must have n=1")
+
+        message = completion.messages[0]
+        memory.update(message_batch=[message])
+
+        yield GenMessageEvent(name=self.agent_name, data=message)
+
+        if ctx is not None:
+            self._track_usage([completion], ctx=ctx)
+            ctx.completions[self.agent_name].append(completion)
+
+    async def call_tools(
+        self,
+        calls: Sequence[ToolCall],
+        memory: LLMAgentMemory,
+        ctx: RunContext[CtxT] | None = None,
+    ) -> Sequence[ToolMessage]:
+        corouts: list[Coroutine[Any, Any, BaseModel]] = []
+        for call in calls:
+            tool = self.tools[call.tool_name]
+            args = json.loads(call.tool_arguments)
+            corouts.append(tool(ctx=ctx, **args))
+
+        outs = await asyncio.gather(*corouts)
+        tool_messages = list(
+            starmap(ToolMessage.from_tool_output, zip(outs, calls, strict=False))
+        )
+        memory.update(message_list=tool_messages)
+
+        if ctx is not None:
+            ctx.printer.print_llm_messages(tool_messages, agent_name=self.agent_name)
+
+        return tool_messages
+
+    async def call_tools_stream(
+        self,
+        calls: Sequence[ToolCall],
+        memory: LLMAgentMemory,
+        ctx: RunContext[CtxT] | None = None,
+    ) -> AsyncIterator[ToolMessageEvent]:
+        tool_messages = await self.call_tools(calls, memory=memory, ctx=ctx)
+        for tool_message, call in zip(tool_messages, calls, strict=False):
+            yield ToolMessageEvent(name=call.tool_name, data=tool_message)
+
+    def _extract_final_answer_from_tool_calls(
+        self, gen_message: AssistantMessage, memory: LLMAgentMemory
+    ) -> AssistantMessage | None:
+        final_answer_message: AssistantMessage | None = None
+        for tool_call in gen_message.tool_calls or []:
+            if tool_call.tool_name == self._final_answer_tool_name:
+                final_answer_message = AssistantMessage(
+                    name=self.agent_name, content=tool_call.tool_arguments
+                )
+                gen_message.tool_calls = None
+                memory.update(message_list=[final_answer_message])
+                return final_answer_message
+
+        return final_answer_message
+
+    async def _generate_final_answer(
+        self, memory: LLMAgentMemory, ctx: RunContext[CtxT] | None = None
+    ) -> AssistantMessage:
+        assert self._final_answer_tool_name is not None
+
+        user_message = UserMessage.from_text(
+            "Exceeded the maximum number of turns: provide a final answer now!"
+        )
+        memory.update(message_list=[user_message])
+        if ctx is not None:
+            ctx.printer.print_llm_messages([user_message], agent_name=self.agent_name)
+
+        tool_choice = NamedToolChoice(name=self._final_answer_tool_name)
+        gen_message = (
+            await self.generate_message_batch(memory, tool_choice=tool_choice, ctx=ctx)
+        )[0]
+
+        final_answer_message = self._extract_final_answer_from_tool_calls(
+            gen_message, memory=memory
+        )
+        if final_answer_message is None:
+            raise RuntimeError(
+                "Final answer tool call did not return a final answer message."
+            )
+        return final_answer_message
+
+    async def _generate_final_answer_stream(
+        self, memory: LLMAgentMemory, ctx: RunContext[CtxT] | None = None
+    ) -> AsyncIterator[Event[Any]]:
+        assert self._final_answer_tool_name is not None
+
+        user_message = UserMessage.from_text(
+            "Exceeded the maximum number of turns: provide a final answer now!",
+        )
+        memory.update(message_list=[user_message])
+        yield UserMessageEvent(name=self.agent_name, data=user_message)
+
+        tool_choice = NamedToolChoice(name=self._final_answer_tool_name)
+        event: Event[Any] | None = None
+        async for event in self.generate_message_stream(
+            memory, tool_choice=tool_choice, ctx=ctx
+        ):
+            yield event
+
+        assert isinstance(event, GenMessageEvent)
+        gen_message = event.data
+        final_answer_message = self._extract_final_answer_from_tool_calls(
+            gen_message, memory=memory
+        )
+        if final_answer_message is None:
+            raise RuntimeError(
+                "Final answer tool call did not return a final answer message."
+            )
+        yield GenMessageEvent(name=self.agent_name, data=final_answer_message)
+
+    async def execute(
+        self, memory: LLMAgentMemory, ctx: RunContext[CtxT] | None = None
+    ) -> AssistantMessage | Sequence[AssistantMessage]:
+        # 1. Generate the first message:
+        # In ReAct mode, we generate the first message without tool calls
+        # to force the agent to plan its actions in a separate message.
+        tool_choice: ToolChoice | None = None
+        if self.tools:
+            tool_choice = "none" if self._react_mode else "auto"
+        gen_message_batch = await self.generate_message_batch(
+            memory, tool_choice=tool_choice, ctx=ctx
+        )
+        if not self.tools:
+            return gen_message_batch
+
+        if memory.message_history.batch_size > 1:
+            raise ValueError("Batch size must be 1 for tool call loop.")
+        gen_message = gen_message_batch[0]
+        turns = 0
+
+        while True:
+            conversation = memory.message_history.conversations[0]
+
+            # 2. Check if we should exit the tool call loop
+
+            # When final_answer_tool_name is None, we use exit_tool_call_loop_impl
+            # to determine whether to exit the loop.
+            if self._final_answer_tool_name is None and self._exit_tool_call_loop_fn(
+                conversation, ctx=ctx, num_turns=turns
+            ):
+                return gen_message
+
+            # When final_answer_tool_name is set, we check if the last message contains
+            # a tool call to the final answer tool. If it does, we exit the loop.
+            if self._final_answer_tool_name is not None:
+                final_answer = self._extract_final_answer_from_tool_calls(
+                    gen_message, memory=memory
+                )
+                if final_answer is not None:
+                    return final_answer
+
+            # Exit if the maximum number of turns is reached
+            if turns >= self.max_turns:
+                # When final_answer_tool_name is set, we force the agent to provide
+                # a final answer by generating a message with a final answer
+                # tool call.
+                # Otherwise, we simply return the last generated message.
+                if self._final_answer_tool_name is not None:
+                    final_answer = await self._generate_final_answer(memory, ctx=ctx)
+                else:
+                    final_answer = gen_message
+                logger.info(
+                    f"Max turns reached: {self.max_turns}. Exiting the tool call loop."
+                )
+                return final_answer
+
+            # 3. Call tools if there are any tool calls in the generated message.
+
+            if gen_message.tool_calls:
+                await self.call_tools(gen_message.tool_calls, memory=memory, ctx=ctx)
+
+            # Apply the memory management function if provided.
+            self._manage_memory_fn(memory, ctx=ctx, num_turns=turns)
+
+            # 4. Generate the next message based on the updated memory.
+            # In ReAct mode, we set tool_choice to "none" if we just called tools,
+            # so the next message will be an observation/planning message with
+            # no immediate tool calls.
+            # If we are not in ReAct mode, we set tool_choice to "auto" to allow
+            # the LLM to choose freely whether to call tools.

+            tool_choice = (
+                "none" if (self._react_mode and gen_message.tool_calls) else "required"
+            )
+            gen_message = (
+                await self.generate_message_batch(
+                    memory, tool_choice=tool_choice, ctx=ctx
+                )
+            )[0]
+
+            turns += 1
+
+    async def execute_stream(
+        self, memory: LLMAgentMemory, ctx: RunContext[CtxT] | None = None
+    ) -> AsyncIterator[Event[Any]]:
+        if memory.message_history.batch_size > 1:
+            raise ValueError("Batch size must be 1 when streaming.")
+
+        tool_choice: ToolChoice = "none" if self._react_mode else "auto"
+        gen_message: AssistantMessage | None = None
+        async for event in self.generate_message_stream(
+            memory, tool_choice=tool_choice, ctx=ctx
+        ):
+            yield event
+            if isinstance(event, GenMessageEvent):
+                gen_message = event.data
+        assert isinstance(gen_message, AssistantMessage)
+
+        turns = 0
+
+        while True:
+            conversation = memory.message_history.conversations[0]
+
+            if self._final_answer_tool_name is None and self._exit_tool_call_loop_fn(
+                conversation, ctx=ctx, num_turns=turns
+            ):
+                return
+
+            if self._final_answer_tool_name is not None:
+                final_answer_message = self._extract_final_answer_from_tool_calls(
+                    gen_message, memory=memory
+                )
+                if final_answer_message is not None:
+                    yield GenMessageEvent(
+                        name=self.agent_name, data=final_answer_message
+                    )
+                    return
+
+            if turns >= self.max_turns:
+                if self._final_answer_tool_name is not None:
+                    async for event in self._generate_final_answer_stream(
+                        memory, ctx=ctx
+                    ):
+                        yield event
+                logger.info(
+                    f"Max turns reached: {self.max_turns}. Exiting the tool call loop."
+                )
+                return
+
+            if gen_message.tool_calls:
+                for tool_call in gen_message.tool_calls:
+                    yield ToolCallEvent(name=self.agent_name, data=tool_call)
+
+                async for tool_message_event in self.call_tools_stream(
+                    gen_message.tool_calls, memory=memory, ctx=ctx
+                ):
+                    yield tool_message_event
+
+            self._manage_memory_fn(memory, ctx=ctx, num_turns=turns)
+
+            tool_choice = (
+                "none" if (self._react_mode and gen_message.tool_calls) else "required"
+            )
+            async for event in self.generate_message_stream(
+                memory, tool_choice=tool_choice, ctx=ctx
+            ):
+                yield event
+                if isinstance(event, GenMessageEvent):
+                    gen_message = event.data
+
+            turns += 1
+
+    def _track_usage(
+        self, completion_batch: Sequence[Completion], ctx: RunContext[CtxT]
+    ) -> None:
+        ctx.usage_tracker.update(
+            completions=completion_batch, model_name=self.llm.model_name
+        )
+
+    def get_final_answer_tool(self) -> BaseTool[BaseModel, None, Any]:
+        if not issubclass(self._final_answer_type, BaseModel):
+            raise TypeError(
+                "final_answer_type must be a subclass of BaseModel to create "
+                "a final answer tool."
+            )
+
+        class FinalAnswerTool(BaseTool[self._final_answer_type, None, Any]):
+            name: str = "final_answer"
+            description: str = (
+                "You must use this tool to provide the final answer. "
+                "Do not provide the final answer anywhere else. "
+                "Input arguments correspond to the final answer."
+            )
+
+            async def run(
+                self, inp: _FinalAnswerT, ctx: RunContext[Any] | None = None
+            ) -> None:
+                return None
+
+        return FinalAnswerTool()
+
+    def _print_completions(
+        self, completion_batch: Sequence[Completion], ctx: RunContext[CtxT]
+    ) -> None:
+        messages = [c.messages[0] for c in completion_batch]
+        usages = [c.usage for c in completion_batch]
+        ctx.printer.print_llm_messages(
+            messages, usages=usages, agent_name=self.agent_name
+        )
grasp_agents/memory.py
CHANGED
@@ -1,144 +1,30 @@
-import
-from
-from copy import deepcopy
+from abc import ABC, abstractmethod
+from typing import Any

-from
-from .typing.message import Conversation, Message, SystemMessage
+from pydantic import BaseModel, ConfigDict

-
+from .run_context import RunContext


-class
-
-
-self
-        self.reset()
-
-    @property
-    def sys_prompt(self) -> LLMPrompt | None:
-        return self._sys_prompt
-
-    def add_message_batch(self, message_batch: Sequence[Message]) -> None:
-        """
-        Adds a batch of messages to the current batched conversations.
-        This method verifies that the size of the input message batch matches
-        the expected batch size (self.batch_size).
-        If there is a mismatch, the method adjusts by duplicating either
-        the message or the conversation as necessary:
-
-        - If the message batch contains exactly one message and
-          self.batch_size > 1, the single message is duplicated to match
-          the batch size.
-        - If the message batch contains multiple messages but
-          self.batch_size == 1, the entire conversation is duplicated to
-          accommodate each message in the batch.
-        - If the message batch size does not match self.batch_size and none of
-          the above adjustments apply, a ValueError is raised.
-
-        Afterwards, each message in the batch is appended to its corresponding
-        conversation in the batched conversations.
-
-        Args:
-            message_batch: A sequence of Message objects
-                representing the batch of messages to be added. Must align with
-                or be adjusted to match the current batch size.
-
-        Raises:
-            ValueError: If the message batch size does not match the current
-                batch size and cannot be automatically adjusted.
-
-        """
-        message_batch_size = len(message_batch)
-
-        if message_batch_size == 1 and self.batch_size > 1:
-            logger.info(
-                "Message batch size is 1, current batch size is "
-                f"{self.batch_size}: duplicating the message to match the "
-                "current batch size"
-            )
-            message_batch = self._duplicate_message_to_current_batch_size(message_batch)
-            message_batch_size = self.batch_size
-        elif message_batch_size > 1 and self.batch_size == 1:
-            logger.info(
-                f"Message batch size is {len(message_batch)}, current batch "
-                "size is 1: duplicating the conversation to match the message "
-                "batch size"
-            )
-            self._duplicate_conversation_to_message_batch_size(message_batch_size)
-        elif message_batch_size != self.batch_size:
-            raise ValueError(
-                f"Message batch size {message_batch_size} does not match "
-                f"current batch size {self.batch_size}"
-            )
-
-        for batch_id in range(message_batch_size):
-            self._batched_conversations[batch_id].append(message_batch[batch_id])
-
-    def add_message_batches(self, message_batches: Sequence[Sequence[Message]]) -> None:
-        for message_batch in message_batches:
-            self.add_message_batch(message_batch)
-
-    def add_message(self, message: Message) -> None:
-        for conversation in self._batched_conversations:
-            conversation.append(message)
-
-    def add_messages(self, messages: Sequence[Message]) -> None:
-        for message in messages:
-            self.add_message(message)
-
-    def __len__(self) -> int:
-        return len(self._batched_conversations[0])
-
-    def __repr__(self) -> str:
-        return f"{self.__class__.__name__}(len={len(self)}; bs={self.batch_size})"
-
-    def __getitem__(self, idx: int) -> tuple[Message, ...]:
-        return tuple(conversation[idx] for conversation in self._batched_conversations)
-
-    def __iter__(self) -> Iterator[tuple[Message, ...]]:
-        for idx in range(len(self)):
-            yield tuple(
-                conversation[idx] for conversation in self._batched_conversations
-            )
-
-    def _duplicate_message_to_current_batch_size(
-        self, message_batch: Sequence[Message]
-    ) -> Sequence[Message]:
-        assert len(message_batch) == 1, (
-            "Message batch size must be 1 to duplicate to current batch size"
-        )
-
-        return [deepcopy(message_batch[0]) for _ in range(self.batch_size)]
-
-    def _duplicate_conversation_to_message_batch_size(
-        self, target_batch_size: int
+class Memory(BaseModel, ABC):
+    @abstractmethod
+    def reset(
+        self, *args: Any, ctx: RunContext[Any] | None = None, **kwargs: Any
     ) -> None:
-
-        self._batched_conversations = [
-            deepcopy(self._batched_conversations[0]) for _ in range(target_batch_size)
-        ]
+        pass

-    @
-    def
-
-
-    @property
-    def batch_size(self) -> int:
-        return len(self._batched_conversations)
-
-    def reset(
-        self, sys_prompt: LLMPrompt | None = None, *, batch_size: int = 1
+    @abstractmethod
+    def update(
+        self, *args: Any, ctx: RunContext[Any] | None = None, **kwargs: Any
     ) -> None:
-
-        self._sys_prompt = sys_prompt
+        pass

-
-
-
-
-        conv = []
+    @property
+    @abstractmethod
+    def is_empty(self) -> bool:
+        pass

-
+    def __repr__(self) -> str:
+        return f"{self.__class__.__name__}()"

-
-        self._batched_conversations = [[]]
+    model_config = ConfigDict(arbitrary_types_allowed=True)