grasp_agents 0.4.7__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions exactly as they appear in their public registry.
- grasp_agents/cloud_llm.py +191 -224
- grasp_agents/comm_processor.py +101 -100
- grasp_agents/errors.py +69 -9
- grasp_agents/litellm/__init__.py +106 -0
- grasp_agents/litellm/completion_chunk_converters.py +68 -0
- grasp_agents/litellm/completion_converters.py +72 -0
- grasp_agents/litellm/converters.py +138 -0
- grasp_agents/litellm/lite_llm.py +210 -0
- grasp_agents/litellm/message_converters.py +66 -0
- grasp_agents/llm.py +84 -49
- grasp_agents/llm_agent.py +136 -120
- grasp_agents/llm_agent_memory.py +3 -3
- grasp_agents/llm_policy_executor.py +167 -174
- grasp_agents/memory.py +23 -0
- grasp_agents/openai/__init__.py +24 -9
- grasp_agents/openai/completion_chunk_converters.py +6 -6
- grasp_agents/openai/completion_converters.py +12 -14
- grasp_agents/openai/content_converters.py +1 -3
- grasp_agents/openai/converters.py +6 -8
- grasp_agents/openai/message_converters.py +21 -3
- grasp_agents/openai/openai_llm.py +155 -103
- grasp_agents/openai/tool_converters.py +4 -6
- grasp_agents/packet.py +5 -2
- grasp_agents/packet_pool.py +14 -13
- grasp_agents/printer.py +233 -73
- grasp_agents/processor.py +229 -91
- grasp_agents/prompt_builder.py +2 -2
- grasp_agents/run_context.py +11 -20
- grasp_agents/runner.py +42 -0
- grasp_agents/typing/completion.py +16 -9
- grasp_agents/typing/completion_chunk.py +51 -22
- grasp_agents/typing/events.py +95 -19
- grasp_agents/typing/message.py +25 -1
- grasp_agents/typing/tool.py +2 -0
- grasp_agents/usage_tracker.py +31 -37
- grasp_agents/utils.py +95 -84
- grasp_agents/workflow/looped_workflow.py +60 -11
- grasp_agents/workflow/sequential_workflow.py +43 -11
- grasp_agents/workflow/workflow_processor.py +25 -24
- {grasp_agents-0.4.7.dist-info → grasp_agents-0.5.1.dist-info}/METADATA +7 -6
- grasp_agents-0.5.1.dist-info/RECORD +57 -0
- grasp_agents-0.4.7.dist-info/RECORD +0 -50
- {grasp_agents-0.4.7.dist-info → grasp_agents-0.5.1.dist-info}/WHEEL +0 -0
- {grasp_agents-0.4.7.dist-info → grasp_agents-0.5.1.dist-info}/licenses/LICENSE.md +0 -0
grasp_agents/llm_policy_executor.py
CHANGED

```diff
@@ -3,11 +3,11 @@ import json
 from collections.abc import AsyncIterator, Coroutine, Sequence
 from itertools import starmap
 from logging import getLogger
-from typing import Any, ClassVar, Protocol, TypeVar
+from typing import Any, Generic, Protocol
 
 from pydantic import BaseModel
 
-from .
+from .errors import AgentFinalAnswerError
 from .llm import LLM, LLMSettings
 from .llm_agent_memory import LLMAgentMemory
 from .run_context import CtxT, RunContext
@@ -18,6 +18,7 @@ from .typing.events import (
     CompletionEvent,
     Event,
     GenMessageEvent,
+    LLMStreamingErrorEvent,
     ToolCallEvent,
     ToolMessageEvent,
     UserMessageEvent,
@@ -28,12 +29,6 @@ from .typing.tool import BaseTool, NamedToolChoice, ToolCall, ToolChoice
 logger = getLogger(__name__)
 
 
-FINAL_ANSWER_TOOL_NAME = "final_answer"
-
-
-_FinalAnswerT = TypeVar("_FinalAnswerT")
-
-
 class ExitToolCallLoopHandler(Protocol[CtxT]):
     def __call__(
         self,
@@ -54,11 +49,7 @@ class ManageMemoryHandler(Protocol[CtxT]):
     ) -> None: ...
 
 
-class LLMPolicyExecutor(AutoInstanceAttributesMixin, Generic[_FinalAnswerT, CtxT]):
-    _generic_arg_to_instance_attr_map: ClassVar[dict[int, str]] = {
-        0: "_final_answer_type",
-    }
-
+class LLMPolicyExecutor(Generic[CtxT]):
     def __init__(
         self,
         agent_name: str,
@@ -66,19 +57,20 @@ class LLMPolicyExecutor(AutoInstanceAttributesMixin, Generic[_FinalAnswerT, CtxT]):
         tools: list[BaseTool[BaseModel, Any, CtxT]] | None,
         max_turns: int,
         react_mode: bool = False,
+        final_answer_type: type[BaseModel] = BaseModel,
         final_answer_as_tool_call: bool = False,
     ) -> None:
-        self._final_answer_type: type[_FinalAnswerT]
         super().__init__()
 
         self._agent_name = agent_name
 
+        self._final_answer_type = final_answer_type
+        self._final_answer_as_tool_call = final_answer_as_tool_call
+        self._final_answer_tool = self.get_final_answer_tool()
+
         _tools: list[BaseTool[BaseModel, Any, CtxT]] | None = tools
-        self._final_answer_tool_name: str | None = None
         if tools and final_answer_as_tool_call:
-
-            self._final_answer_tool_name = final_answer_tool.name
-            _tools = tools + [final_answer_tool]
+            _tools = tools + [self._final_answer_tool]
 
         self._llm = llm
         self._llm.tools = _tools
```
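The constructor rework above replaces the `_FinalAnswerT` type parameter (previously resolved through `AutoInstanceAttributesMixin` and a `ClassVar` mapping) with a plain `final_answer_type` argument. A minimal standalone sketch of that pattern, using a toy class of my own rather than the package's code:

```python
from typing import Generic, TypeVar

from pydantic import BaseModel

CtxT = TypeVar("CtxT")


class PolicyExecutor(Generic[CtxT]):
    """Toy stand-in mirroring the 0.5.x constructor shape."""

    def __init__(
        self,
        final_answer_type: type[BaseModel] = BaseModel,
        final_answer_as_tool_call: bool = False,
    ) -> None:
        # The answer schema is plain instance state, not a generic parameter.
        self.final_answer_type = final_answer_type
        self.final_answer_as_tool_call = final_answer_as_tool_call


class Verdict(BaseModel):
    answer: str
    confidence: float


executor = PolicyExecutor[dict](
    final_answer_type=Verdict, final_answer_as_tool_call=True
)
print(executor.final_answer_type.__name__)  # -> Verdict
```

Passing the schema as a value leaves `CtxT` as the only type parameter and lets the final-answer tool be built eagerly in `__init__`.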
```diff
@@ -115,12 +107,7 @@ class LLMPolicyExecutor(AutoInstanceAttributesMixin, Generic[_FinalAnswerT, CtxT]):
         if self.exit_tool_call_loop_impl:
             return self.exit_tool_call_loop_impl(conversation, ctx=ctx, **kwargs)
 
-
-        assert isinstance(conversation[-1], AssistantMessage), (
-            "Last message in conversation must be an AssistantMessage"
-        )
-
-        return not bool(conversation[-1].tool_calls)
+        return False
 
     def _manage_memory(
         self,
@@ -132,60 +119,71 @@ class LLMPolicyExecutor(AutoInstanceAttributesMixin, Generic[_FinalAnswerT, CtxT]):
         if self.manage_memory_impl:
             self.manage_memory_impl(memory=memory, ctx=ctx, **kwargs)
 
-    async def
+    async def generate_message(
         self,
         memory: LLMAgentMemory,
-
+        call_id: str,
         tool_choice: ToolChoice | None = None,
         ctx: RunContext[CtxT] | None = None,
-    ) ->
+    ) -> AssistantMessage:
         completion = await self.llm.generate_completion(
-            memory.message_history,
+            memory.message_history,
+            tool_choice=tool_choice,
+            n_choices=1,
+            proc_name=self.agent_name,
+            call_id=call_id,
         )
         memory.update(completion.messages)
+        self._process_completion(
+            completion, call_id=call_id, ctx=ctx, print_messages=True
+        )
 
-
-        ctx.completions[self.agent_name].append(completion)
-        self._track_usage(self.agent_name, completion, ctx=ctx)
-        self._print_completion(completion, run_id=run_id, ctx=ctx)
-
-        return completion.messages
+        return completion.messages[0]
 
-    async def
+    async def generate_message_stream(
         self,
         memory: LLMAgentMemory,
-
+        call_id: str,
         tool_choice: ToolChoice | None = None,
         ctx: RunContext[CtxT] | None = None,
-    ) -> AsyncIterator[
-
-
+    ) -> AsyncIterator[
+        CompletionChunkEvent
+        | CompletionEvent
+        | GenMessageEvent
+        | LLMStreamingErrorEvent
+    ]:
         completion: Completion | None = None
-        async for event in
-
+        async for event in self.llm.generate_completion_stream(  # type: ignore[no-untyped-call]
+            memory.message_history,
+            tool_choice=tool_choice,
+            n_choices=1,
+            proc_name=self.agent_name,
+            call_id=call_id,
         ):
-            yield event
             if isinstance(event, CompletionEvent):
                 completion = event.data
+            yield event
         if completion is None:
-
+            return
 
-
+        yield GenMessageEvent(
+            proc_name=self.agent_name, call_id=call_id, data=completion.messages[0]
+        )
 
-
-        yield GenMessageEvent(name=self.agent_name, data=message)
+        memory.update(completion.messages)
 
-
-
-
+        self._process_completion(
+            completion, call_id=call_id, print_messages=True, ctx=ctx
+        )
 
     async def call_tools(
         self,
         calls: Sequence[ToolCall],
         memory: LLMAgentMemory,
-
+        call_id: str,
         ctx: RunContext[CtxT] | None = None,
     ) -> Sequence[ToolMessage]:
+        # TODO: Add image support
         corouts: list[Coroutine[Any, Any, BaseModel]] = []
         for call in calls:
             tool = self.tools[call.tool_name]
```
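The new `generate_message_stream` follows an accumulate-then-wrap shape: chunk events are re-yielded as they arrive, the terminal `CompletionEvent` is captured, and a message-level event is emitted once the stream finishes (or the method returns early if the stream never completes, e.g. on a streaming error). A self-contained sketch of that shape with toy event types, not the package's own:

```python
import asyncio
from dataclasses import dataclass


@dataclass
class ChunkEvent:
    text: str


@dataclass
class CompletionEvent:
    message: str


async def llm_stream():
    # Stand-in for an LLM completion stream.
    for word in ("Hello", " world"):
        yield ChunkEvent(word)
    yield CompletionEvent("Hello world")


async def message_stream():
    completion = None
    async for event in llm_stream():
        if isinstance(event, CompletionEvent):
            completion = event  # capture the terminal completion
        yield event
    if completion is None:
        return  # stream ended without completing
    # Wrap the finished completion in a final, message-level event.
    yield f"GenMessage: {completion.message}"


async def main():
    async for event in message_stream():
        print(event)


asyncio.run(main())
```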
```diff
@@ -196,11 +194,12 @@ class LLMPolicyExecutor(AutoInstanceAttributesMixin, Generic[_FinalAnswerT, CtxT]):
         tool_messages = list(
             starmap(ToolMessage.from_tool_output, zip(outs, calls, strict=True))
         )
+
         memory.update(tool_messages)
 
-        if ctx
-            ctx.printer.
-                tool_messages, agent_name=self.agent_name,
+        if ctx and ctx.printer:
+            ctx.printer.print_messages(
+                tool_messages, agent_name=self.agent_name, call_id=call_id
             )
 
         return tool_messages
@@ -209,138 +208,129 @@ class LLMPolicyExecutor(AutoInstanceAttributesMixin, Generic[_FinalAnswerT, CtxT]):
         self,
         calls: Sequence[ToolCall],
         memory: LLMAgentMemory,
-
+        call_id: str,
         ctx: RunContext[CtxT] | None = None,
     ) -> AsyncIterator[ToolMessageEvent]:
         tool_messages = await self.call_tools(
-            calls, memory=memory,
+            calls, memory=memory, call_id=call_id, ctx=ctx
         )
         for tool_message, call in zip(tool_messages, calls, strict=True):
-            yield ToolMessageEvent(
+            yield ToolMessageEvent(
+                proc_name=call.tool_name, call_id=call_id, data=tool_message
+            )
 
     def _extract_final_answer_from_tool_calls(
-        self,
+        self, memory: LLMAgentMemory
     ) -> AssistantMessage | None:
-
-
-
+        last_message = memory.message_history[-1]
+        if not isinstance(last_message, AssistantMessage):
+            return None
+
+        for tool_call in last_message.tool_calls or []:
+            if tool_call.tool_name == self._final_answer_tool.name:
                 final_answer_message = AssistantMessage(
                     name=self.agent_name, content=tool_call.tool_arguments
                 )
-
+                last_message.tool_calls = None
                 memory.update([final_answer_message])
-                return final_answer_message
 
-
+                return final_answer_message
 
     async def _generate_final_answer(
-        self, memory: LLMAgentMemory,
+        self, memory: LLMAgentMemory, call_id: str, ctx: RunContext[CtxT] | None = None
     ) -> AssistantMessage:
-        assert self._final_answer_tool_name is not None
-
         user_message = UserMessage.from_text(
             "Exceeded the maximum number of turns: provide a final answer now!"
         )
         memory.update([user_message])
-        if ctx
-            ctx.printer.
-                [user_message], agent_name=self.agent_name,
+        if ctx and ctx.printer:
+            ctx.printer.print_messages(
+                [user_message], agent_name=self.agent_name, call_id=call_id
             )
 
-        tool_choice = NamedToolChoice(name=self.
-
-
-            memory, tool_choice=tool_choice, run_id=run_id, ctx=ctx
-        )
-        )[0]
-
-        final_answer_message = self._extract_final_answer_from_tool_calls(
-            gen_message, memory=memory
+        tool_choice = NamedToolChoice(name=self._final_answer_tool.name)
+        await self.generate_message(
+            memory, tool_choice=tool_choice, call_id=call_id, ctx=ctx
         )
+
+        final_answer_message = self._extract_final_answer_from_tool_calls(memory=memory)
         if final_answer_message is None:
-            raise
-                "Final answer tool call did not return a final answer message."
-            )
+            raise AgentFinalAnswerError
 
         return final_answer_message
 
     async def _generate_final_answer_stream(
-        self, memory: LLMAgentMemory,
+        self, memory: LLMAgentMemory, call_id: str, ctx: RunContext[CtxT] | None = None
     ) -> AsyncIterator[Event[Any]]:
-        assert self._final_answer_tool_name is not None
-
         user_message = UserMessage.from_text(
             "Exceeded the maximum number of turns: provide a final answer now!",
         )
         memory.update([user_message])
-        yield UserMessageEvent(
+        yield UserMessageEvent(
+            proc_name=self.agent_name, call_id=call_id, data=user_message
+        )
+        if ctx and ctx.printer:
+            ctx.printer.print_messages(
+                [user_message], agent_name=self.agent_name, call_id=call_id
+            )
 
-        tool_choice = NamedToolChoice(name=self.
-
-
-            memory, tool_choice=tool_choice, run_id=run_id, ctx=ctx
+        tool_choice = NamedToolChoice(name=self._final_answer_tool.name)
+        async for event in self.generate_message_stream(
+            memory, tool_choice=tool_choice, call_id=call_id, ctx=ctx
         ):
             yield event
 
-
-            gen_message = event.data
-        final_answer_message = self._extract_final_answer_from_tool_calls(
-            gen_message, memory=memory
-        )
+        final_answer_message = self._extract_final_answer_from_tool_calls(memory)
         if final_answer_message is None:
-            raise
-
-
-
+            raise AgentFinalAnswerError
+        yield GenMessageEvent(
+            proc_name=self.agent_name, call_id=call_id, data=final_answer_message
+        )
 
     async def execute(
-        self, memory: LLMAgentMemory,
+        self, memory: LLMAgentMemory, call_id: str, ctx: RunContext[CtxT] | None = None
     ) -> AssistantMessage | Sequence[AssistantMessage]:
         # 1. Generate the first message:
         # In ReAct mode, we generate the first message without tool calls
         # to force the agent to plan its actions in a separate message.
+
         tool_choice: ToolChoice | None = None
         if self.tools:
             tool_choice = "none" if self._react_mode else "auto"
-
-            memory, tool_choice=tool_choice,
+        gen_message = await self.generate_message(
+            memory, tool_choice=tool_choice, call_id=call_id, ctx=ctx
         )
         if not self.tools:
-            return
+            return gen_message
 
-        if len(gen_messages) > 1:
-            raise ValueError("n_choices must be 1 when executing the tool call loop.")
-        gen_message = gen_messages[0]
         turns = 0
 
         while True:
             # 2. Check if we should exit the tool call loop
 
-            #
-            # to determine whether to exit the loop.
-            if self.
+            # If a final answer is not provided via a tool call, we use
+            # exit_tool_call_loop to determine whether to exit the loop.
+            if not self._final_answer_as_tool_call and self._exit_tool_call_loop(
                 memory.message_history, ctx=ctx, num_turns=turns
             ):
                 return gen_message
 
-            #
-            #
-
-
-
-            )
+            # If a final answer is provided via a tool call, we check
+            # if the last message contains the corresponding tool call.
+            # If it does, we exit the loop.
+            if self._final_answer_as_tool_call:
+                final_answer = self._extract_final_answer_from_tool_calls(memory)
                 if final_answer is not None:
                     return final_answer
 
             # Exit if the maximum number of turns is reached
             if turns >= self.max_turns:
-                #
-                #
-                # tool call.
+                # If a final answer is provided via a tool call, we force the
+                # agent to use the final answer tool.
                 # Otherwise, we simply return the last generated message.
-                if self.
+                if self._final_answer_as_tool_call:
                     final_answer = await self._generate_final_answer(
-                        memory,
+                        memory, call_id=call_id, ctx=ctx
                     )
                 else:
                     final_answer = gen_message
@@ -349,22 +339,20 @@ class LLMPolicyExecutor(AutoInstanceAttributesMixin, Generic[_FinalAnswerT, CtxT]):
                 )
                 return final_answer
 
-            # 3. Call tools
+            # 3. Call tools.
 
             if gen_message.tool_calls:
                 await self.call_tools(
-                    gen_message.tool_calls, memory=memory,
+                    gen_message.tool_calls, memory=memory, call_id=call_id, ctx=ctx
                 )
 
-            # Apply
+            # Apply memory management (e.g. compacting or pruning memory)
             self._manage_memory(memory, ctx=ctx, num_turns=turns)
 
             # 4. Generate the next message based on the updated memory.
             # In ReAct mode, we set tool_choice to "none" if we just called tools,
             # so the next message will be an observation/planning message with
             # no immediate tool calls.
-            # If we are not in ReAct mode, we set tool_choice to "auto" to allow
-            # the LLM to choose freely whether to call tools.
 
             if self._react_mode and gen_message.tool_calls:
                 tool_choice = "none"
```
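The numbered comments in `execute` describe a turn loop whose `tool_choice` toggles in ReAct mode. One plausible reading of that toggle, sketched as a free function (several branches are elided in this rendering, so treat the exact conditions as an assumption):

```python
def next_tool_choice(react_mode: bool, just_called_tools: bool) -> str:
    # ReAct: after a round of tool calls, force a plain observation/planning
    # message; otherwise demand an action. Non-ReAct: let the model decide.
    if react_mode and just_called_tools:
        return "none"
    if react_mode:
        return "required"
    return "auto"


for turn, called in enumerate([False, True, False, True]):
    print(turn, next_tool_choice(react_mode=True, just_called_tools=called))
```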
```diff
@@ -373,49 +361,56 @@ class LLMPolicyExecutor(AutoInstanceAttributesMixin, Generic[_FinalAnswerT, CtxT]):
             else:
                 tool_choice = "required"
 
-            gen_message = (
-
-
-            )
-            )[0]
+            gen_message = await self.generate_message(
+                memory, tool_choice=tool_choice, call_id=call_id, ctx=ctx
+            )
 
             turns += 1
 
     async def execute_stream(
-        self,
+        self,
+        memory: LLMAgentMemory,
+        call_id: str,
+        ctx: RunContext[CtxT] | None = None,
     ) -> AsyncIterator[Event[Any]]:
         tool_choice: ToolChoice = "none" if self._react_mode else "auto"
         gen_message: AssistantMessage | None = None
-        async for event in self.
-            memory, tool_choice=tool_choice,
+        async for event in self.generate_message_stream(
+            memory, tool_choice=tool_choice, call_id=call_id, ctx=ctx
         ):
-            yield event
             if isinstance(event, GenMessageEvent):
                 gen_message = event.data
-
+            yield event
+        if gen_message is None:
+            return
+
+        if not self.tools:
+            return
 
         turns = 0
 
         while True:
-            if self.
+            if not self._final_answer_as_tool_call and self._exit_tool_call_loop(
                 memory.message_history, ctx=ctx, num_turns=turns
             ):
                 return
 
-            if self.
+            if self._final_answer_as_tool_call:
                 final_answer_message = self._extract_final_answer_from_tool_calls(
-
+                    memory
                 )
                 if final_answer_message is not None:
                     yield GenMessageEvent(
-
+                        proc_name=self.agent_name,
+                        call_id=call_id,
+                        data=final_answer_message,
                     )
                     return
 
             if turns >= self.max_turns:
-                if self.
+                if self._final_answer_as_tool_call:
                     async for event in self._generate_final_answer_stream(
-                        memory,
+                        memory, call_id=call_id, ctx=ctx
                     ):
                         yield event
                     logger.info(
@@ -425,12 +420,14 @@ class LLMPolicyExecutor(AutoInstanceAttributesMixin, Generic[_FinalAnswerT, CtxT]):
 
             if gen_message.tool_calls:
                 for tool_call in gen_message.tool_calls:
-                    yield ToolCallEvent(
+                    yield ToolCallEvent(
+                        proc_name=self.agent_name, call_id=call_id, data=tool_call
+                    )
 
-                async for
-                    gen_message.tool_calls, memory=memory,
+                async for event in self.call_tools_stream(
+                    gen_message.tool_calls, memory=memory, call_id=call_id, ctx=ctx
                 ):
-                    yield
+                    yield event
 
             self._manage_memory(memory, ctx=ctx, num_turns=turns)
 
@@ -440,8 +437,9 @@ class LLMPolicyExecutor(AutoInstanceAttributesMixin, Generic[_FinalAnswerT, CtxT]):
                 tool_choice = "auto"
             else:
                 tool_choice = "required"
-
-
+
+            async for event in self.generate_message_stream(
+                memory, tool_choice=tool_choice, call_id=call_id, ctx=ctx
             ):
                 yield event
                 if isinstance(event, GenMessageEvent):
@@ -449,45 +447,40 @@ class LLMPolicyExecutor(AutoInstanceAttributesMixin, Generic[_FinalAnswerT, CtxT]):
 
             turns += 1
 
-    def _track_usage(
-        self,
-        agent_name: str,
-        completion: Completion,
-        ctx: RunContext[CtxT],
-    ) -> None:
-        ctx.usage_tracker.update(
-            agent_name=agent_name,
-            completions=[completion],
-            model_name=self.llm.model_name,
-        )
-
     def get_final_answer_tool(self) -> BaseTool[BaseModel, None, Any]:
-        if not issubclass(self._final_answer_type, BaseModel):
-            raise TypeError(
-                "final_answer_type must be a subclass of BaseModel to create "
-                "a final answer tool."
-            )
-
         class FinalAnswerTool(BaseTool[self._final_answer_type, None, Any]):
-            name: str =
+            name: str = "final_answer"
             description: str = (
                 "You must call this tool to provide the final answer. "
                 "DO NOT output your answer before calling the tool. "
             )
 
             async def run(
-                self, inp:
+                self, inp: BaseModel, ctx: RunContext[Any] | None = None
             ) -> None:
                 return None
 
         return FinalAnswerTool()
 
-    def
-        self,
+    def _process_completion(
+        self,
+        completion: Completion,
+        call_id: str,
+        print_messages: bool = False,
+        ctx: RunContext[CtxT] | None = None,
     ) -> None:
-        ctx
-        completion
-
-
-
-
+        if ctx is not None:
+            ctx.completions[self.agent_name].append(completion)
+            ctx.usage_tracker.update(
+                agent_name=self.agent_name,
+                completions=[completion],
+                model_name=self.llm.model_name,
+            )
+            if ctx.printer and print_messages:
+                usages = [None] * (len(completion.messages) - 1) + [completion.usage]
+                ctx.printer.print_messages(
+                    completion.messages,
+                    usages=usages,
+                    agent_name=self.agent_name,
+                    call_id=call_id,
+                )
```
grasp_agents/memory.py
CHANGED

```diff
@@ -15,6 +15,10 @@ class Memory(BaseModel, ABC):
     ) -> None:
         pass
 
+    @abstractmethod
+    def erase(self) -> None:
+        pass
+
     @abstractmethod
     def update(
         self, *args: Any, ctx: RunContext[Any] | None = None, **kwargs: Any
@@ -30,3 +34,22 @@
         return f"{self.__class__.__name__}()"
 
     model_config = ConfigDict(arbitrary_types_allowed=True)
+
+
+class DummyMemory(Memory):
+    def reset(
+        self, *args: Any, ctx: RunContext[Any] | None = None, **kwargs: Any
+    ) -> None:
+        pass
+
+    def erase(self) -> None:
+        pass
+
+    def update(
+        self, *args: Any, ctx: RunContext[Any] | None = None, **kwargs: Any
+    ) -> None:
+        pass
+
+    @property
+    def is_empty(self) -> bool:
+        return True
```