grasp_agents 0.4.6__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. grasp_agents/cloud_llm.py +191 -218
  2. grasp_agents/comm_processor.py +101 -100
  3. grasp_agents/errors.py +69 -9
  4. grasp_agents/litellm/__init__.py +106 -0
  5. grasp_agents/litellm/completion_chunk_converters.py +68 -0
  6. grasp_agents/litellm/completion_converters.py +72 -0
  7. grasp_agents/litellm/converters.py +138 -0
  8. grasp_agents/litellm/lite_llm.py +210 -0
  9. grasp_agents/litellm/message_converters.py +66 -0
  10. grasp_agents/llm.py +84 -49
  11. grasp_agents/llm_agent.py +136 -120
  12. grasp_agents/llm_agent_memory.py +3 -3
  13. grasp_agents/llm_policy_executor.py +167 -174
  14. grasp_agents/memory.py +4 -0
  15. grasp_agents/openai/__init__.py +24 -9
  16. grasp_agents/openai/completion_chunk_converters.py +6 -6
  17. grasp_agents/openai/completion_converters.py +12 -14
  18. grasp_agents/openai/content_converters.py +1 -3
  19. grasp_agents/openai/converters.py +6 -8
  20. grasp_agents/openai/message_converters.py +21 -3
  21. grasp_agents/openai/openai_llm.py +155 -103
  22. grasp_agents/openai/tool_converters.py +4 -6
  23. grasp_agents/packet.py +5 -2
  24. grasp_agents/packet_pool.py +14 -13
  25. grasp_agents/printer.py +234 -72
  26. grasp_agents/processor.py +228 -88
  27. grasp_agents/prompt_builder.py +2 -2
  28. grasp_agents/run_context.py +11 -20
  29. grasp_agents/runner.py +42 -0
  30. grasp_agents/typing/completion.py +16 -9
  31. grasp_agents/typing/completion_chunk.py +51 -22
  32. grasp_agents/typing/events.py +95 -19
  33. grasp_agents/typing/message.py +25 -1
  34. grasp_agents/typing/tool.py +2 -0
  35. grasp_agents/usage_tracker.py +31 -37
  36. grasp_agents/utils.py +95 -84
  37. grasp_agents/workflow/looped_workflow.py +60 -11
  38. grasp_agents/workflow/sequential_workflow.py +43 -11
  39. grasp_agents/workflow/workflow_processor.py +25 -24
  40. {grasp_agents-0.4.6.dist-info → grasp_agents-0.5.0.dist-info}/METADATA +7 -6
  41. grasp_agents-0.5.0.dist-info/RECORD +57 -0
  42. grasp_agents-0.4.6.dist-info/RECORD +0 -50
  43. {grasp_agents-0.4.6.dist-info → grasp_agents-0.5.0.dist-info}/WHEEL +0 -0
  44. {grasp_agents-0.4.6.dist-info → grasp_agents-0.5.0.dist-info}/licenses/LICENSE.md +0 -0
grasp_agents/llm_policy_executor.py CHANGED
@@ -3,11 +3,11 @@ import json
 from collections.abc import AsyncIterator, Coroutine, Sequence
 from itertools import starmap
 from logging import getLogger
-from typing import Any, ClassVar, Generic, Protocol, TypeVar
+from typing import Any, Generic, Protocol
 
 from pydantic import BaseModel
 
-from .generics_utils import AutoInstanceAttributesMixin
+from .errors import AgentFinalAnswerError
 from .llm import LLM, LLMSettings
 from .llm_agent_memory import LLMAgentMemory
 from .run_context import CtxT, RunContext
@@ -18,6 +18,7 @@ from .typing.events import (
     CompletionEvent,
     Event,
     GenMessageEvent,
+    LLMStreamingErrorEvent,
     ToolCallEvent,
     ToolMessageEvent,
     UserMessageEvent,
@@ -28,12 +29,6 @@ from .typing.tool import BaseTool, NamedToolChoice, ToolCall, ToolChoice
 logger = getLogger(__name__)
 
 
-FINAL_ANSWER_TOOL_NAME = "final_answer"
-
-
-_FinalAnswerT = TypeVar("_FinalAnswerT")
-
-
 class ExitToolCallLoopHandler(Protocol[CtxT]):
     def __call__(
         self,
@@ -54,11 +49,7 @@ class ManageMemoryHandler(Protocol[CtxT]):
     ) -> None: ...
 
 
-class LLMPolicyExecutor(AutoInstanceAttributesMixin, Generic[_FinalAnswerT, CtxT]):
-    _generic_arg_to_instance_attr_map: ClassVar[dict[int, str]] = {
-        0: "_final_answer_type",
-    }
-
+class LLMPolicyExecutor(Generic[CtxT]):
     def __init__(
         self,
         agent_name: str,
@@ -66,19 +57,20 @@ class LLMPolicyExecutor(AutoInstanceAttributesMixin, Generic[_FinalAnswerT, CtxT
         tools: list[BaseTool[BaseModel, Any, CtxT]] | None,
         max_turns: int,
         react_mode: bool = False,
+        final_answer_type: type[BaseModel] = BaseModel,
         final_answer_as_tool_call: bool = False,
     ) -> None:
-        self._final_answer_type: type[_FinalAnswerT]
         super().__init__()
 
         self._agent_name = agent_name
 
+        self._final_answer_type = final_answer_type
+        self._final_answer_as_tool_call = final_answer_as_tool_call
+        self._final_answer_tool = self.get_final_answer_tool()
+
         _tools: list[BaseTool[BaseModel, Any, CtxT]] | None = tools
-        self._final_answer_tool_name: str | None = None
         if tools and final_answer_as_tool_call:
-            final_answer_tool = self.get_final_answer_tool()
-            self._final_answer_tool_name = final_answer_tool.name
-            _tools = tools + [final_answer_tool]
+            _tools = tools + [self._final_answer_tool]
 
         self._llm = llm
         self._llm.tools = _tools
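The constructor change above replaces the 0.4.6 generic final-answer parameter (previously resolved through AutoInstanceAttributesMixin) with an explicit final_answer_type argument, and the final-answer tool is now built eagerly in __init__. A minimal sketch of the new call site, assuming an LLM instance and a tool list constructed elsewhere (the agent name, schema, and values are illustrative, not taken from this diff):

    from typing import Any

    from pydantic import BaseModel

    from grasp_agents.llm_policy_executor import LLMPolicyExecutor

    class Answer(BaseModel):
        text: str

    def build_executor(llm: Any, tools: Any) -> Any:
        # 0.4.6: LLMPolicyExecutor[Answer, Any](agent_name=..., ...)
        # 0.5.0: the schema is a plain constructor argument:
        return LLMPolicyExecutor[Any](
            agent_name="example_agent",
            llm=llm,
            tools=tools,
            max_turns=10,
            react_mode=False,
            final_answer_type=Answer,
            final_answer_as_tool_call=True,  # appends the final_answer tool to tools
        )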
@@ -115,12 +107,7 @@ class LLMPolicyExecutor(AutoInstanceAttributesMixin, Generic[_FinalAnswerT, CtxT
         if self.exit_tool_call_loop_impl:
             return self.exit_tool_call_loop_impl(conversation, ctx=ctx, **kwargs)
 
-        assert conversation, "Conversation must not be empty"
-        assert isinstance(conversation[-1], AssistantMessage), (
-            "Last message in conversation must be an AssistantMessage"
-        )
-
-        return not bool(conversation[-1].tool_calls)
+        return False
 
     def _manage_memory(
         self,
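The default loop-exit check above changes from "stop when the last assistant message has no tool calls" to an unconditional False, so the loop now ends only via the final-answer tool or max_turns unless a custom handler is supplied. A sketch restoring the 0.4.6 behavior through exit_tool_call_loop_impl (the handler signature is inferred from the call sites in this file, and the direct attribute assignment is illustrative):

    from typing import Any

    from grasp_agents.typing.message import AssistantMessage

    def exit_when_no_tool_calls(conversation, ctx=None, **kwargs: Any) -> bool:
        # Mirrors the removed 0.4.6 default: exit once the model stops calling tools.
        if not conversation or not isinstance(conversation[-1], AssistantMessage):
            return False
        return not bool(conversation[-1].tool_calls)

    # executor.exit_tool_call_loop_impl = exit_when_no_tool_calls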
@@ -132,60 +119,71 @@ class LLMPolicyExecutor(AutoInstanceAttributesMixin, Generic[_FinalAnswerT, CtxT
         if self.manage_memory_impl:
             self.manage_memory_impl(memory=memory, ctx=ctx, **kwargs)
 
-    async def generate_messages(
+    async def generate_message(
         self,
         memory: LLMAgentMemory,
-        run_id: str,
+        call_id: str,
         tool_choice: ToolChoice | None = None,
         ctx: RunContext[CtxT] | None = None,
-    ) -> Sequence[AssistantMessage]:
+    ) -> AssistantMessage:
         completion = await self.llm.generate_completion(
-            memory.message_history, tool_choice=tool_choice
+            memory.message_history,
+            tool_choice=tool_choice,
+            n_choices=1,
+            proc_name=self.agent_name,
+            call_id=call_id,
         )
         memory.update(completion.messages)
+        self._process_completion(
+            completion, call_id=call_id, ctx=ctx, print_messages=True
+        )
 
-        if ctx is not None:
-            ctx.completions[self.agent_name].append(completion)
-            self._track_usage(self.agent_name, completion, ctx=ctx)
-            self._print_completion(completion, run_id=run_id, ctx=ctx)
-
-        return completion.messages
+        return completion.messages[0]
 
-    async def generate_messages_stream(
+    async def generate_message_stream(
         self,
         memory: LLMAgentMemory,
-        run_id: str,
+        call_id: str,
         tool_choice: ToolChoice | None = None,
         ctx: RunContext[CtxT] | None = None,
-    ) -> AsyncIterator[CompletionChunkEvent | CompletionEvent | GenMessageEvent]:
-        message_hist = memory.message_history
-
+    ) -> AsyncIterator[
+        CompletionChunkEvent
+        | CompletionEvent
+        | GenMessageEvent
+        | LLMStreamingErrorEvent
+    ]:
         completion: Completion | None = None
-        async for event in await self.llm.generate_completion_stream(
-            message_hist, tool_choice=tool_choice
+        async for event in self.llm.generate_completion_stream(  # type: ignore[no-untyped-call]
+            memory.message_history,
+            tool_choice=tool_choice,
+            n_choices=1,
+            proc_name=self.agent_name,
+            call_id=call_id,
         ):
-            yield event
             if isinstance(event, CompletionEvent):
                 completion = event.data
+            yield event
         if completion is None:
-            raise RuntimeError("No completion generated during stream.")
+            return
 
-        memory.update(completion.messages)
+        yield GenMessageEvent(
+            proc_name=self.agent_name, call_id=call_id, data=completion.messages[0]
+        )
 
-        for message in completion.messages:
-            yield GenMessageEvent(name=self.agent_name, data=message)
+        memory.update(completion.messages)
 
-        if ctx is not None:
-            self._track_usage(self.agent_name, completion, ctx=ctx)
-            ctx.completions[self.agent_name].append(completion)
+        self._process_completion(
+            completion, call_id=call_id, print_messages=True, ctx=ctx
+        )
 
     async def call_tools(
         self,
         calls: Sequence[ToolCall],
         memory: LLMAgentMemory,
-        run_id: str,
+        call_id: str,
         ctx: RunContext[CtxT] | None = None,
     ) -> Sequence[ToolMessage]:
+        # TODO: Add image support
         corouts: list[Coroutine[Any, Any, BaseModel]] = []
         for call in calls:
             tool = self.tools[call.tool_name]
@@ -196,11 +194,12 @@ class LLMPolicyExecutor(AutoInstanceAttributesMixin, Generic[_FinalAnswerT, CtxT
         tool_messages = list(
             starmap(ToolMessage.from_tool_output, zip(outs, calls, strict=True))
         )
+
         memory.update(tool_messages)
 
-        if ctx is not None:
-            ctx.printer.print_llm_messages(
-                tool_messages, agent_name=self.agent_name, run_id=run_id
+        if ctx and ctx.printer:
+            ctx.printer.print_messages(
+                tool_messages, agent_name=self.agent_name, call_id=call_id
             )
 
         return tool_messages
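call_tools builds one coroutine per tool call, looks each tool up by name, and wraps the awaited outputs in ToolMessage objects via ToolMessage.from_tool_output. A hypothetical tool following the same shape as the FinalAnswerTool defined later in this file (names and fields are illustrative):

    from typing import Any

    from pydantic import BaseModel

    from grasp_agents.run_context import RunContext
    from grasp_agents.typing.tool import BaseTool

    class AddArgs(BaseModel):
        a: int
        b: int

    class Sum(BaseModel):
        total: int

    class AddTool(BaseTool[AddArgs, Sum, Any]):
        name: str = "add"
        description: str = "Add two integers and return their sum."

        async def run(self, inp: AddArgs, ctx: RunContext[Any] | None = None) -> Sum:
            return Sum(total=inp.a + inp.b)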
@@ -209,138 +208,129 @@ class LLMPolicyExecutor(AutoInstanceAttributesMixin, Generic[_FinalAnswerT, CtxT
         self,
         calls: Sequence[ToolCall],
         memory: LLMAgentMemory,
-        run_id: str,
+        call_id: str,
         ctx: RunContext[CtxT] | None = None,
     ) -> AsyncIterator[ToolMessageEvent]:
         tool_messages = await self.call_tools(
-            calls, memory=memory, run_id=run_id, ctx=ctx
+            calls, memory=memory, call_id=call_id, ctx=ctx
         )
         for tool_message, call in zip(tool_messages, calls, strict=True):
-            yield ToolMessageEvent(name=call.tool_name, data=tool_message)
+            yield ToolMessageEvent(
+                proc_name=call.tool_name, call_id=call_id, data=tool_message
+            )
 
     def _extract_final_answer_from_tool_calls(
-        self, gen_message: AssistantMessage, memory: LLMAgentMemory
+        self, memory: LLMAgentMemory
     ) -> AssistantMessage | None:
-        final_answer_message: AssistantMessage | None = None
-        for tool_call in gen_message.tool_calls or []:
-            if tool_call.tool_name == self._final_answer_tool_name:
+        last_message = memory.message_history[-1]
+        if not isinstance(last_message, AssistantMessage):
+            return None
+
+        for tool_call in last_message.tool_calls or []:
+            if tool_call.tool_name == self._final_answer_tool.name:
                 final_answer_message = AssistantMessage(
                     name=self.agent_name, content=tool_call.tool_arguments
                 )
-                gen_message.tool_calls = None
+                last_message.tool_calls = None
                 memory.update([final_answer_message])
-                return final_answer_message
 
-        return final_answer_message
+                return final_answer_message
 
     async def _generate_final_answer(
-        self, memory: LLMAgentMemory, run_id: str, ctx: RunContext[CtxT] | None = None
+        self, memory: LLMAgentMemory, call_id: str, ctx: RunContext[CtxT] | None = None
     ) -> AssistantMessage:
-        assert self._final_answer_tool_name is not None
-
         user_message = UserMessage.from_text(
             "Exceeded the maximum number of turns: provide a final answer now!"
         )
         memory.update([user_message])
-        if ctx is not None:
-            ctx.printer.print_llm_messages(
-                [user_message], agent_name=self.agent_name, run_id=run_id
+        if ctx and ctx.printer:
+            ctx.printer.print_messages(
+                [user_message], agent_name=self.agent_name, call_id=call_id
             )
 
-        tool_choice = NamedToolChoice(name=self._final_answer_tool_name)
-        gen_message = (
-            await self.generate_messages(
-                memory, tool_choice=tool_choice, run_id=run_id, ctx=ctx
-            )
-        )[0]
-
-        final_answer_message = self._extract_final_answer_from_tool_calls(
-            gen_message, memory=memory
+        tool_choice = NamedToolChoice(name=self._final_answer_tool.name)
+        await self.generate_message(
+            memory, tool_choice=tool_choice, call_id=call_id, ctx=ctx
        )
+
+        final_answer_message = self._extract_final_answer_from_tool_calls(memory=memory)
         if final_answer_message is None:
-            raise RuntimeError(
-                "Final answer tool call did not return a final answer message."
-            )
+            raise AgentFinalAnswerError
 
         return final_answer_message
 
     async def _generate_final_answer_stream(
-        self, memory: LLMAgentMemory, run_id: str, ctx: RunContext[CtxT] | None = None
+        self, memory: LLMAgentMemory, call_id: str, ctx: RunContext[CtxT] | None = None
     ) -> AsyncIterator[Event[Any]]:
-        assert self._final_answer_tool_name is not None
-
         user_message = UserMessage.from_text(
             "Exceeded the maximum number of turns: provide a final answer now!",
         )
         memory.update([user_message])
-        yield UserMessageEvent(name=self.agent_name, data=user_message)
+        yield UserMessageEvent(
+            proc_name=self.agent_name, call_id=call_id, data=user_message
+        )
+        if ctx and ctx.printer:
+            ctx.printer.print_messages(
+                [user_message], agent_name=self.agent_name, call_id=call_id
+            )
 
-        tool_choice = NamedToolChoice(name=self._final_answer_tool_name)
-        event: Event[Any] | None = None
-        async for event in self.generate_messages_stream(
-            memory, tool_choice=tool_choice, run_id=run_id, ctx=ctx
+        tool_choice = NamedToolChoice(name=self._final_answer_tool.name)
+        async for event in self.generate_message_stream(
            memory, tool_choice=tool_choice, call_id=call_id, ctx=ctx
         ):
             yield event
 
-        assert isinstance(event, GenMessageEvent)
-        gen_message = event.data
-        final_answer_message = self._extract_final_answer_from_tool_calls(
-            gen_message, memory=memory
-        )
+        final_answer_message = self._extract_final_answer_from_tool_calls(memory)
         if final_answer_message is None:
-            raise RuntimeError(
-                "Final answer tool call did not return a final answer message."
-            )
-        yield GenMessageEvent(name=self.agent_name, data=final_answer_message)
+            raise AgentFinalAnswerError
+        yield GenMessageEvent(
+            proc_name=self.agent_name, call_id=call_id, data=final_answer_message
+        )
 
     async def execute(
-        self, memory: LLMAgentMemory, run_id: str, ctx: RunContext[CtxT] | None = None
+        self, memory: LLMAgentMemory, call_id: str, ctx: RunContext[CtxT] | None = None
     ) -> AssistantMessage | Sequence[AssistantMessage]:
         # 1. Generate the first message:
         # In ReAct mode, we generate the first message without tool calls
         # to force the agent to plan its actions in a separate message.
+
         tool_choice: ToolChoice | None = None
         if self.tools:
             tool_choice = "none" if self._react_mode else "auto"
-        gen_messages = await self.generate_messages(
-            memory, tool_choice=tool_choice, run_id=run_id, ctx=ctx
+        gen_message = await self.generate_message(
+            memory, tool_choice=tool_choice, call_id=call_id, ctx=ctx
         )
         if not self.tools:
-            return gen_messages
+            return gen_message
 
-        if len(gen_messages) > 1:
-            raise ValueError("n_choices must be 1 when executing the tool call loop.")
-        gen_message = gen_messages[0]
         turns = 0
 
         while True:
             # 2. Check if we should exit the tool call loop
 
-            # When final_answer_tool_name is None, we use exit_tool_call_loop_impl
-            # to determine whether to exit the loop.
-            if self._final_answer_tool_name is None and self._exit_tool_call_loop(
+            # If a final answer is not provided via a tool call, we use
+            # exit_tool_call_loop to determine whether to exit the loop.
+            if not self._final_answer_as_tool_call and self._exit_tool_call_loop(
                 memory.message_history, ctx=ctx, num_turns=turns
             ):
                 return gen_message
 
-            # When final_answer_tool_name is set, we check if the last message contains
-            # a tool call to the final answer tool. If it does, we exit the loop.
-            if self._final_answer_tool_name is not None:
-                final_answer = self._extract_final_answer_from_tool_calls(
-                    gen_message, memory=memory
-                )
+            # If a final answer is provided via a tool call, we check
+            # if the last message contains the corresponding tool call.
+            # If it does, we exit the loop.
+            if self._final_answer_as_tool_call:
+                final_answer = self._extract_final_answer_from_tool_calls(memory)
                 if final_answer is not None:
                     return final_answer
 
             # Exit if the maximum number of turns is reached
             if turns >= self.max_turns:
-                # When final_answer_tool_name is set, we force the agent to provide
-                # a final answer by generating a message with a final answer
-                # tool call.
+                # If a final answer is provided via a tool call, we force the
+                # agent to use the final answer tool.
                 # Otherwise, we simply return the last generated message.
-                if self._final_answer_tool_name is not None:
+                if self._final_answer_as_tool_call:
                     final_answer = await self._generate_final_answer(
-                        memory, run_id=run_id, ctx=ctx
+                        memory, call_id=call_id, ctx=ctx
                     )
                 else:
                     final_answer = gen_message
@@ -349,22 +339,20 @@ class LLMPolicyExecutor(AutoInstanceAttributesMixin, Generic[_FinalAnswerT, CtxT
                 )
                 return final_answer
 
-            # 3. Call tools if there are any tool calls in the generated message.
+            # 3. Call tools.
 
             if gen_message.tool_calls:
                 await self.call_tools(
-                    gen_message.tool_calls, memory=memory, run_id=run_id, ctx=ctx
+                    gen_message.tool_calls, memory=memory, call_id=call_id, ctx=ctx
                 )
 
-            # Apply the memory management function if provided.
+            # Apply memory management (e.g. compacting or pruning memory)
             self._manage_memory(memory, ctx=ctx, num_turns=turns)
 
             # 4. Generate the next message based on the updated memory.
             # In ReAct mode, we set tool_choice to "none" if we just called tools,
             # so the next message will be an observation/planning message with
             # no immediate tool calls.
-            # If we are not in ReAct mode, we set tool_choice to "auto" to allow
-            # the LLM to choose freely whether to call tools.
 
             if self._react_mode and gen_message.tool_calls:
                 tool_choice = "none"
@@ -373,49 +361,56 @@ class LLMPolicyExecutor(AutoInstanceAttributesMixin, Generic[_FinalAnswerT, CtxT
             else:
                 tool_choice = "required"
 
-            gen_message = (
-                await self.generate_messages(
-                    memory, tool_choice=tool_choice, run_id=run_id, ctx=ctx
-                )
-            )[0]
+            gen_message = await self.generate_message(
+                memory, tool_choice=tool_choice, call_id=call_id, ctx=ctx
+            )
 
             turns += 1
 
     async def execute_stream(
-        self, memory: LLMAgentMemory, run_id: str, ctx: RunContext[CtxT] | None = None
+        self,
+        memory: LLMAgentMemory,
+        call_id: str,
+        ctx: RunContext[CtxT] | None = None,
     ) -> AsyncIterator[Event[Any]]:
         tool_choice: ToolChoice = "none" if self._react_mode else "auto"
         gen_message: AssistantMessage | None = None
-        async for event in self.generate_messages_stream(
-            memory, tool_choice=tool_choice, run_id=run_id, ctx=ctx
+        async for event in self.generate_message_stream(
+            memory, tool_choice=tool_choice, call_id=call_id, ctx=ctx
         ):
-            yield event
             if isinstance(event, GenMessageEvent):
                 gen_message = event.data
-        assert isinstance(gen_message, AssistantMessage)
+            yield event
+        if gen_message is None:
+            return
+
+        if not self.tools:
+            return
 
         turns = 0
 
         while True:
-            if self._final_answer_tool_name is None and self._exit_tool_call_loop(
+            if not self._final_answer_as_tool_call and self._exit_tool_call_loop(
                 memory.message_history, ctx=ctx, num_turns=turns
             ):
                 return
 
-            if self._final_answer_tool_name is not None:
+            if self._final_answer_as_tool_call:
                 final_answer_message = self._extract_final_answer_from_tool_calls(
-                    gen_message, memory=memory
+                    memory
                 )
                 if final_answer_message is not None:
                     yield GenMessageEvent(
-                        name=self.agent_name, data=final_answer_message
+                        proc_name=self.agent_name,
+                        call_id=call_id,
+                        data=final_answer_message,
                     )
                     return
 
             if turns >= self.max_turns:
-                if self._final_answer_tool_name is not None:
+                if self._final_answer_as_tool_call:
                     async for event in self._generate_final_answer_stream(
-                        memory, run_id=run_id, ctx=ctx
+                        memory, call_id=call_id, ctx=ctx
                     ):
                         yield event
                     logger.info(
@@ -425,12 +420,14 @@ class LLMPolicyExecutor(AutoInstanceAttributesMixin, Generic[_FinalAnswerT, CtxT
 
             if gen_message.tool_calls:
                 for tool_call in gen_message.tool_calls:
-                    yield ToolCallEvent(name=self.agent_name, data=tool_call)
+                    yield ToolCallEvent(
+                        proc_name=self.agent_name, call_id=call_id, data=tool_call
+                    )
 
-                async for tool_message_event in self.call_tools_stream(
-                    gen_message.tool_calls, memory=memory, run_id=run_id, ctx=ctx
+                async for event in self.call_tools_stream(
+                    gen_message.tool_calls, memory=memory, call_id=call_id, ctx=ctx
                 ):
-                    yield tool_message_event
+                    yield event
 
             self._manage_memory(memory, ctx=ctx, num_turns=turns)
 
@@ -440,8 +437,9 @@ class LLMPolicyExecutor(AutoInstanceAttributesMixin, Generic[_FinalAnswerT, CtxT
                 tool_choice = "auto"
             else:
                 tool_choice = "required"
-            async for event in self.generate_messages_stream(
-                memory, tool_choice=tool_choice, run_id=run_id, ctx=ctx
+
+            async for event in self.generate_message_stream(
+                memory, tool_choice=tool_choice, call_id=call_id, ctx=ctx
             ):
                 yield event
                 if isinstance(event, GenMessageEvent):
@@ -449,45 +447,40 @@ class LLMPolicyExecutor(AutoInstanceAttributesMixin, Generic[_FinalAnswerT, CtxT
 
             turns += 1
 
-    def _track_usage(
-        self,
-        agent_name: str,
-        completion: Completion,
-        ctx: RunContext[CtxT],
-    ) -> None:
-        ctx.usage_tracker.update(
-            agent_name=agent_name,
-            completions=[completion],
-            model_name=self.llm.model_name,
-        )
-
     def get_final_answer_tool(self) -> BaseTool[BaseModel, None, Any]:
-        if not issubclass(self._final_answer_type, BaseModel):
-            raise TypeError(
-                "final_answer_type must be a subclass of BaseModel to create "
-                "a final answer tool."
-            )
-
         class FinalAnswerTool(BaseTool[self._final_answer_type, None, Any]):
-            name: str = FINAL_ANSWER_TOOL_NAME
+            name: str = "final_answer"
             description: str = (
                 "You must call this tool to provide the final answer. "
                 "DO NOT output your answer before calling the tool. "
             )
 
             async def run(
-                self, inp: _FinalAnswerT, ctx: RunContext[Any] | None = None
+                self, inp: BaseModel, ctx: RunContext[Any] | None = None
             ) -> None:
                 return None
 
         return FinalAnswerTool()
 
-    def _print_completion(
-        self, completion: Completion, run_id: str, ctx: RunContext[CtxT]
+    def _process_completion(
+        self,
+        completion: Completion,
+        call_id: str,
+        print_messages: bool = False,
+        ctx: RunContext[CtxT] | None = None,
     ) -> None:
-        ctx.printer.print_llm_messages(
-            completion.messages,
-            usages=[completion.usage],
-            agent_name=self.agent_name,
-            run_id=run_id,
-        )
+        if ctx is not None:
+            ctx.completions[self.agent_name].append(completion)
+            ctx.usage_tracker.update(
+                agent_name=self.agent_name,
+                completions=[completion],
+                model_name=self.llm.model_name,
+            )
+            if ctx.printer and print_messages:
+                usages = [None] * (len(completion.messages) - 1) + [completion.usage]
+                ctx.printer.print_messages(
+                    completion.messages,
+                    usages=usages,
+                    agent_name=self.agent_name,
+                    call_id=call_id,
+                )
grasp_agents/memory.py CHANGED
@@ -15,6 +15,10 @@ class Memory(BaseModel, ABC):
     ) -> None:
         pass
 
+    @abstractmethod
+    def erase(self) -> None:
+        pass
+
     @abstractmethod
     def update(
         self, *args: Any, ctx: RunContext[Any] | None = None, **kwargs: Any
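erase joins the Memory ABC alongside update, so every concrete memory in 0.5.0 must implement it. A minimal sketch of a conforming subclass (hypothetical class and field names; any other abstract members of Memory not visible in this hunk would also need implementations):

    from typing import Any

    from grasp_agents.memory import Memory
    from grasp_agents.run_context import RunContext

    class ScratchpadMemory(Memory):
        notes: list[str] = []

        def erase(self) -> None:
            # New in 0.5.0: drop all stored state.
            self.notes.clear()

        def update(
            self, *args: Any, ctx: RunContext[Any] | None = None, **kwargs: Any
        ) -> None:
            self.notes.extend(str(arg) for arg in args)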