grasp_agents 0.2.11__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff compares the contents of two publicly released package versions as they appear in their public registry. It is provided for informational purposes only.
Files changed (51)
  1. grasp_agents/__init__.py +15 -14
  2. grasp_agents/cloud_llm.py +118 -131
  3. grasp_agents/comm_processor.py +201 -0
  4. grasp_agents/generics_utils.py +15 -7
  5. grasp_agents/llm.py +60 -31
  6. grasp_agents/llm_agent.py +229 -273
  7. grasp_agents/llm_agent_memory.py +58 -0
  8. grasp_agents/llm_policy_executor.py +482 -0
  9. grasp_agents/memory.py +20 -134
  10. grasp_agents/message_history.py +140 -0
  11. grasp_agents/openai/__init__.py +54 -36
  12. grasp_agents/openai/completion_chunk_converters.py +78 -0
  13. grasp_agents/openai/completion_converters.py +53 -30
  14. grasp_agents/openai/content_converters.py +13 -14
  15. grasp_agents/openai/converters.py +44 -68
  16. grasp_agents/openai/message_converters.py +58 -72
  17. grasp_agents/openai/openai_llm.py +101 -42
  18. grasp_agents/openai/tool_converters.py +24 -19
  19. grasp_agents/packet.py +24 -0
  20. grasp_agents/packet_pool.py +91 -0
  21. grasp_agents/printer.py +29 -15
  22. grasp_agents/processor.py +194 -0
  23. grasp_agents/prompt_builder.py +175 -192
  24. grasp_agents/run_context.py +20 -37
  25. grasp_agents/typing/completion.py +58 -12
  26. grasp_agents/typing/completion_chunk.py +173 -0
  27. grasp_agents/typing/converters.py +8 -12
  28. grasp_agents/typing/events.py +86 -0
  29. grasp_agents/typing/io.py +4 -13
  30. grasp_agents/typing/message.py +12 -50
  31. grasp_agents/typing/tool.py +52 -26
  32. grasp_agents/usage_tracker.py +6 -6
  33. grasp_agents/utils.py +3 -3
  34. grasp_agents/workflow/looped_workflow.py +132 -0
  35. grasp_agents/workflow/parallel_processor.py +95 -0
  36. grasp_agents/workflow/sequential_workflow.py +66 -0
  37. grasp_agents/workflow/workflow_processor.py +78 -0
  38. {grasp_agents-0.2.11.dist-info → grasp_agents-0.3.1.dist-info}/METADATA +41 -50
  39. grasp_agents-0.3.1.dist-info/RECORD +51 -0
  40. grasp_agents/agent_message.py +0 -27
  41. grasp_agents/agent_message_pool.py +0 -92
  42. grasp_agents/base_agent.py +0 -51
  43. grasp_agents/comm_agent.py +0 -217
  44. grasp_agents/llm_agent_state.py +0 -79
  45. grasp_agents/tool_orchestrator.py +0 -203
  46. grasp_agents/workflow/looped_agent.py +0 -134
  47. grasp_agents/workflow/sequential_agent.py +0 -72
  48. grasp_agents/workflow/workflow_agent.py +0 -88
  49. grasp_agents-0.2.11.dist-info/RECORD +0 -46
  50. {grasp_agents-0.2.11.dist-info → grasp_agents-0.3.1.dist-info}/WHEEL +0 -0
  51. {grasp_agents-0.2.11.dist-info → grasp_agents-0.3.1.dist-info}/licenses/LICENSE.md +0 -0
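Taken together with the llm_agent.py diff below, the file list shows that the core agent plumbing was renamed at the module level in 0.3.1. A sketch of the corresponding import changes follows (absolute import paths are assumed from the wheel layout; the imported names are only those that appear in the diff below):

# 0.2.11 imports that disappear in 0.3.1 (modules removed above):
#   from grasp_agents.comm_agent import CommunicatingAgent
#   from grasp_agents.tool_orchestrator import ToolOrchestrator
#   from grasp_agents.llm_agent_state import LLMAgentState
#   from grasp_agents.agent_message_pool import AgentMessagePool
# 0.3.1 replacements (modules added above):
from grasp_agents.comm_processor import CommProcessor
from grasp_agents.llm_policy_executor import LLMPolicyExecutor
from grasp_agents.llm_agent_memory import LLMAgentMemory
from grasp_agents.packet_pool import PacketPool
from grasp_agents.run_context import RunContext  # replaces RunContextWrapper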
grasp_agents/llm_agent.py CHANGED
@@ -1,60 +1,47 @@
-from collections.abc import Sequence
+from collections.abc import AsyncIterator, Sequence
 from pathlib import Path
-from typing import Any, ClassVar, Generic, Protocol, cast, final
+from typing import Any, ClassVar, Generic, Protocol
 
 from pydantic import BaseModel
 
-from .agent_message import AgentMessage
-from .agent_message_pool import AgentMessagePool
-from .comm_agent import CommunicatingAgent
+from .comm_processor import CommProcessor
 from .llm import LLM, LLMSettings
-from .llm_agent_state import (
-    LLMAgentState,
-    SetAgentState,
-    SetAgentStateStrategy,
+from .llm_agent_memory import LLMAgentMemory, SetMemoryHandler
+from .llm_policy_executor import (
+    ExitToolCallLoopHandler,
+    LLMPolicyExecutor,
+    ManageMemoryHandler,
 )
+from .packet_pool import PacketPool
 from .prompt_builder import (
-    FormatInputArgsHandler,
-    FormatSystemArgsHandler,
+    MakeInputContentHandler,
+    MakeSystemPromptHandler,
     PromptBuilder,
 )
-from .run_context import CtxT, InteractionRecord, RunContextWrapper
-from .tool_orchestrator import (
-    ExitToolCallLoopHandler,
-    ManageAgentStateHandler,
-    ToolOrchestrator,
-)
-from .typing.content import ImageData
+from .run_context import CtxT, RunContext
+from .typing.content import Content, ImageData
 from .typing.converters import Converters
-from .typing.io import (
-    AgentID,
-    AgentState,
-    InT,
-    LLMFormattedArgs,
-    LLMFormattedSystemArgs,
-    LLMPrompt,
-    LLMPromptArgs,
-    OutT,
-)
-from .typing.message import Conversation, Message, SystemMessage
+from .typing.events import Event, ProcOutputEvent, SystemMessageEvent, UserMessageEvent
+from .typing.io import InT_contra, LLMPrompt, LLMPromptArgs, OutT_co, ProcName
+from .typing.message import Message, Messages, SystemMessage, UserMessage
 from .typing.tool import BaseTool
 from .utils import get_prompt, validate_obj_from_json_or_py_string
 
 
-class ParseOutputHandler(Protocol[InT, OutT, CtxT]):
+class ParseOutputHandler(Protocol[InT_contra, OutT_co, CtxT]):
     def __call__(
         self,
-        conversation: Conversation,
+        conversation: Messages,
         *,
-        in_args: InT | None,
+        in_args: InT_contra | None,
         batch_idx: int,
-        ctx: RunContextWrapper[CtxT] | None,
-    ) -> OutT: ...
+        ctx: RunContext[CtxT] | None,
+    ) -> OutT_co: ...
 
 
 class LLMAgent(
-    CommunicatingAgent[InT, OutT, LLMAgentState, CtxT],
-    Generic[InT, OutT, CtxT],
+    CommProcessor[InT_contra, OutT_co, LLMAgentMemory, CtxT],
+    Generic[InT_contra, OutT_co, CtxT],
 ):
     _generic_arg_to_instance_attr_map: ClassVar[dict[int, str]] = {
         0: "_in_type",
@@ -63,63 +50,68 @@ class LLMAgent(
 
     def __init__(
         self,
-        agent_id: AgentID,
+        name: ProcName,
         *,
         # LLM
         llm: LLM[LLMSettings, Converters],
+        # Tools
+        tools: list[BaseTool[Any, Any, CtxT]] | None = None,
         # Input prompt template (combines user and received arguments)
         in_prompt: LLMPrompt | None = None,
         in_prompt_path: str | Path | None = None,
         # System prompt template
         sys_prompt: LLMPrompt | None = None,
         sys_prompt_path: str | Path | None = None,
-        # System args (static args provided via RunContextWrapper)
+        # System args (static args provided via RunContext)
         sys_args_schema: type[LLMPromptArgs] = LLMPromptArgs,
-        # User args (static args provided via RunContextWrapper)
+        # User args (static args provided via RunContext)
         usr_args_schema: type[LLMPromptArgs] = LLMPromptArgs,
-        # Tools
-        tools: list[BaseTool[Any, Any, CtxT]] | None = None,
-        max_turns: int = 1000,
+        # Agent loop settings
+        max_turns: int = 100,
         react_mode: bool = False,
-        # Agent state management
-        set_state_strategy: SetAgentStateStrategy = "keep",
+        final_answer_as_tool_call: bool = False,
+        # Agent memory management
+        reset_memory_on_run: bool = False,
         # Multi-agent routing
-        message_pool: AgentMessagePool[CtxT] | None = None,
-        recipient_ids: list[AgentID] | None = None,
+        packet_pool: PacketPool[CtxT] | None = None,
+        recipients: list[ProcName] | None = None,
     ) -> None:
-        super().__init__(
-            agent_id=agent_id, message_pool=message_pool, recipient_ids=recipient_ids
-        )
+        super().__init__(name=name, packet_pool=packet_pool, recipients=recipients)
 
-        # Agent state
-        self._state: LLMAgentState = LLMAgentState()
-        self.set_state_strategy: SetAgentStateStrategy = set_state_strategy
-        self._set_agent_state_impl: SetAgentState | None = None
+        # Agent memory
 
-        # Tool orchestrator
+        self._memory: LLMAgentMemory = LLMAgentMemory()
+        self._reset_memory_on_run = reset_memory_on_run
+        self._set_memory_impl: SetMemoryHandler | None = None
+
+        # LLM policy executor
 
         self._using_default_llm_response_format: bool = False
         if llm.response_format is None and tools is None:
             llm.response_format = self.out_type
             self._using_default_llm_response_format = True
 
-        self._tool_orchestrator: ToolOrchestrator[CtxT] = ToolOrchestrator[CtxT](
-            agent_id=self.agent_id,
+        self._policy_executor: LLMPolicyExecutor[OutT_co, CtxT] = LLMPolicyExecutor[
+            self.out_type, CtxT
+        ](
+            agent_name=self.name,
             llm=llm,
             tools=tools,
             max_turns=max_turns,
             react_mode=react_mode,
+            final_answer_as_tool_call=final_answer_as_tool_call,
         )
 
         # Prompt builder
+
         sys_prompt = get_prompt(prompt_text=sys_prompt, prompt_path=sys_prompt_path)
         in_prompt = get_prompt(prompt_text=in_prompt, prompt_path=in_prompt_path)
-        self._prompt_builder: PromptBuilder[InT, CtxT] = PromptBuilder[
+        self._prompt_builder: PromptBuilder[InT_contra, CtxT] = PromptBuilder[
             self.in_type, CtxT
         ](
-            agent_id=self._agent_id,
-            sys_prompt=sys_prompt,
-            in_prompt=in_prompt,
+            agent_name=self._name,
+            sys_prompt_template=sys_prompt,
+            in_prompt_template=in_prompt,
             sys_args_schema=sys_args_schema,
             usr_args_schema=usr_args_schema,
         )
@@ -130,53 +122,50 @@ class LLMAgent(
 
     @property
     def llm(self) -> LLM[LLMSettings, Converters]:
-        return self._tool_orchestrator.llm
+        return self._policy_executor.llm
 
     @property
     def tools(self) -> dict[str, BaseTool[BaseModel, Any, CtxT]]:
-        return self._tool_orchestrator.tools
+        return self._policy_executor.tools
 
     @property
     def max_turns(self) -> int:
-        return self._tool_orchestrator.max_turns
+        return self._policy_executor.max_turns
 
     @property
-    def sys_args_schema(self) -> type[LLMPromptArgs]:
+    def sys_args_schema(self) -> type[LLMPromptArgs] | None:
         return self._prompt_builder.sys_args_schema
 
     @property
-    def usr_args_schema(self) -> type[LLMPromptArgs]:
+    def usr_args_schema(self) -> type[LLMPromptArgs] | None:
         return self._prompt_builder.usr_args_schema
 
     @property
     def sys_prompt(self) -> LLMPrompt | None:
-        return self._prompt_builder.sys_prompt
+        return self._prompt_builder.sys_prompt_template
 
     @property
     def in_prompt(self) -> LLMPrompt | None:
-        return self._prompt_builder.in_prompt
+        return self._prompt_builder.in_prompt_template
 
     def _parse_output(
         self,
-        conversation: Conversation,
+        conversation: Messages,
         *,
-        in_args: InT | None = None,
+        in_args: InT_contra | None = None,
         batch_idx: int = 0,
-        ctx: RunContextWrapper[CtxT] | None = None,
-    ) -> OutT:
+        ctx: RunContext[CtxT] | None = None,
+    ) -> OutT_co:
         if self._parse_output_impl:
             if self._using_default_llm_response_format:
                 # When using custom output parsing, the required LLM response format
                 # can differ from the final agent output type ->
                 # set it back to None unless it was specified explicitly at init.
-                self._tool_orchestrator.llm.response_format = None
-                self._using_default_llm_response_format = False
+                self._policy_executor.llm.response_format = None
+                # self._using_default_llm_response_format = False
 
             return self._parse_output_impl(
-                conversation=conversation,
-                in_args=in_args,
-                batch_idx=batch_idx,
-                ctx=ctx,
+                conversation=conversation, in_args=in_args, batch_idx=batch_idx, ctx=ctx
             )
 
         return validate_obj_from_json_or_py_string(
@@ -185,215 +174,182 @@ class LLMAgent(
             from_substring=True,
         )
 
-    @staticmethod
-    def _validate_run_inputs(
-        chat_inputs: LLMPrompt | Sequence[str | ImageData] | None = None,
-        in_args: InT | Sequence[InT] | None = None,
-        in_message: AgentMessage[InT, AgentState] | None = None,
-        entry_point: bool = False,
-    ) -> None:
-        multiple_inputs_err_message = (
-            "Only one of chat_inputs, in_args, or in_message must be provided."
-        )
-        if chat_inputs is not None and in_args is not None:
-            raise ValueError(multiple_inputs_err_message)
-        if chat_inputs is not None and in_message is not None:
-            raise ValueError(multiple_inputs_err_message)
-        if in_args is not None and in_message is not None:
-            raise ValueError(multiple_inputs_err_message)
-
-        if entry_point and in_message is not None:
-            raise ValueError(
-                "Entry point agent cannot receive messages from other agents."
-            )
-
-    @final
-    async def run(
+    def _memorize_inputs(
         self,
         chat_inputs: LLMPrompt | Sequence[str | ImageData] | None = None,
         *,
-        in_message: AgentMessage[InT, AgentState] | None = None,
-        in_args: InT | Sequence[InT] | None = None,
-        entry_point: bool = False,
-        ctx: RunContextWrapper[CtxT] | None = None,
-        forbid_state_change: bool = False,
-        **gen_kwargs: Any,  # noqa: ARG002
-    ) -> AgentMessage[OutT, LLMAgentState]:
+        in_args: Sequence[InT_contra] | None = None,
+        ctx: RunContext[CtxT] | None = None,
+    ) -> tuple[SystemMessage | None, Sequence[UserMessage], LLMAgentMemory]:
         # Get run arguments
-        sys_args: LLMPromptArgs = LLMPromptArgs()
-        usr_args: LLMPromptArgs | Sequence[LLMPromptArgs] = LLMPromptArgs()
+        sys_args: LLMPromptArgs | None = None
+        usr_args: LLMPromptArgs | None = None
         if ctx is not None:
-            run_args = ctx.run_args.get(self.agent_id)
+            run_args = ctx.run_args.get(self.name)
             if run_args is not None:
                 sys_args = run_args.sys
                 usr_args = run_args.usr
 
-        self._validate_run_inputs(
-            chat_inputs=chat_inputs,
-            in_args=in_args,
-            in_message=in_message,
-            entry_point=entry_point,
-        )
-        resolved_in_args = in_message.payloads if in_message else in_args
-
         # 1. Make system prompt (can be None)
+
         formatted_sys_prompt = self._prompt_builder.make_sys_prompt(
             sys_args=sys_args, ctx=ctx
         )
 
-        # 2. Set agent state
-
-        cur_state = self.state.model_copy(deep=True)
-        in_state = in_message.sender_state if in_message else None
-        prev_mh_len = len(cur_state.message_history)
+        # 2. Set agent memory
 
-        state = LLMAgentState.from_cur_and_in_states(
-            cur_state=cur_state,
-            in_state=in_state,
-            sys_prompt=formatted_sys_prompt,
-            strategy=self.set_state_strategy,
-            set_agent_state_impl=self._set_agent_state_impl,
-            ctx=ctx,
-        )
-
-        self._print_sys_msg(state=state, prev_mh_len=prev_mh_len, ctx=ctx)
+        _memory = self.memory.model_copy(deep=True)
+        prev_message_hist_length = len(_memory.message_history)
+        if self._reset_memory_on_run or _memory.is_empty:
+            _memory.reset(formatted_sys_prompt)
+        elif self._set_memory_impl:
+            _memory = self._set_memory_impl(
+                prev_memory=_memory,
+                in_args=in_args,
+                sys_prompt=formatted_sys_prompt,
+                ctx=ctx,
+            )
 
-        # 3. Make and add user messages (can be empty)
+        # 3. Make and add user messages
 
         user_message_batch = self._prompt_builder.make_user_messages(
-            chat_inputs=chat_inputs,
-            in_args=resolved_in_args,
-            usr_args=usr_args,
-            entry_point=entry_point,
-            ctx=ctx,
+            chat_inputs=chat_inputs, in_args_batch=in_args, usr_args=usr_args, ctx=ctx
         )
         if user_message_batch:
-            state.message_history.add_message_batch(user_message_batch)
-            self._print_msgs(user_message_batch, ctx=ctx)
-
-        if not self.tools:
-            # 4. Generate messages without tools
-            await self._tool_orchestrator.generate_once(state=state, ctx=ctx)
-        else:
-            # 4. Run tool call loop (new messages are added to the message
-            # history inside the loop)
-            await self._tool_orchestrator.run_loop(state=state, ctx=ctx)
-
-        # 5. Parse outputs
-
-        val_output_batch: list[OutT] = []
-        for i, _conv in enumerate(state.message_history.batched_conversations):
-            if isinstance(resolved_in_args, Sequence):
-                _resolved_in_args = cast("Sequence[InT]", resolved_in_args)
-                _in_args = _resolved_in_args[min(i, len(_resolved_in_args) - 1)]
-            else:
-                _resolved_in_args = cast("InT | None", resolved_in_args)
-                _in_args = _resolved_in_args
-
-            val_output_batch.append(
-                self._out_type_adapter.validate_python(
-                    self._parse_output(
-                        conversation=_conv, in_args=_in_args, batch_idx=i, ctx=ctx
-                    )
-                )
-            )
+            _memory.update(message_batch=user_message_batch)
 
-        # 6. Write interaction history to context
+        # 4. Extract system message if it was added
 
-        recipient_ids = self._validate_routing(val_output_batch)
+        system_message: SystemMessage | None = None
+        if (
+            len(_memory.message_history) == 1
+            and prev_message_hist_length == 0
+            and isinstance(_memory.message_history[0][0], SystemMessage)
+        ):
+            system_message = _memory.message_history[0][0]
 
-        if ctx:
-            interaction_record = InteractionRecord(
-                source_id=self.agent_id,
-                recipient_ids=recipient_ids,
-                chat_inputs=chat_inputs,
-                sys_prompt=self.sys_prompt,
-                in_prompt=self.in_prompt,
-                sys_args=sys_args,
-                usr_args=usr_args,
-                in_args=resolved_in_args,  # type: ignore[valid-type]
-                outputs=val_output_batch,
-                state=state,
-            )
-            ctx.interaction_history.append(
-                cast(
-                    "InteractionRecord[Any, Any, AgentState]",
-                    interaction_record,
+        return system_message, user_message_batch, _memory
+
+    def _extract_outputs(
+        self,
+        memory: LLMAgentMemory,
+        in_args: Sequence[InT_contra] | None = None,
+        ctx: RunContext[CtxT] | None = None,
+    ) -> Sequence[OutT_co]:
+        outputs: list[OutT_co] = []
+        for i, _conv in enumerate(memory.message_history.conversations):
+            if in_args is not None:
+                _in_args_single = in_args[min(i, len(in_args) - 1)]
+            else:
+                _in_args_single = None
+
+            outputs.append(
+                self._parse_output(
+                    conversation=_conv, in_args=_in_args_single, batch_idx=i, ctx=ctx
                 )
             )
 
-        agent_message = AgentMessage(
-            payloads=val_output_batch,
-            sender_id=self.agent_id,
-            sender_state=state,
-            recipient_ids=recipient_ids,
+        return outputs
+
+    async def _process(
+        self,
+        chat_inputs: LLMPrompt | Sequence[str | ImageData] | None = None,
+        *,
+        in_args: Sequence[InT_contra] | None = None,
+        forgetful: bool = False,
+        ctx: RunContext[CtxT] | None = None,
+    ) -> Sequence[OutT_co]:
+        system_message, user_message_batch, memory = self._memorize_inputs(
+            chat_inputs=chat_inputs, in_args=in_args, ctx=ctx
         )
 
-        if not forbid_state_change:
-            self._state = state
+        if system_message is not None:
+            self._print_messages([system_message], ctx=ctx)
+        if user_message_batch:
+            self._print_messages(user_message_batch, ctx=ctx)
 
-        return agent_message
+        await self._policy_executor.execute(memory, ctx=ctx)
 
-    def _print_msgs(
-        self,
-        messages: Sequence[Message],
-        ctx: RunContextWrapper[CtxT] | None = None,
-    ) -> None:
-        if ctx:
-            ctx.printer.print_llm_messages(messages, agent_id=self.agent_id)
+        if not forgetful:
+            self._memory = memory
+
+        return self._extract_outputs(memory=memory, in_args=in_args, ctx=ctx)
 
-    def _print_sys_msg(
+    async def _process_stream(
         self,
-        state: LLMAgentState,
-        prev_mh_len: int,
-        ctx: RunContextWrapper[CtxT] | None = None,
+        chat_inputs: LLMPrompt | Sequence[str | ImageData] | None = None,
+        *,
+        in_args: Sequence[InT_contra] | None = None,
+        forgetful: bool = False,
+        ctx: RunContext[CtxT] | None = None,
+    ) -> AsyncIterator[Event[Any]]:
+        system_message, user_message_batch, memory = self._memorize_inputs(
+            chat_inputs=chat_inputs, in_args=in_args, ctx=ctx
+        )
+
+        if system_message is not None:
+            yield SystemMessageEvent(data=system_message)
+        if user_message_batch:
+            for user_message in user_message_batch:
+                yield UserMessageEvent(data=user_message)
+
+        # 4. Run tool call loop (new messages are added to the message
+        # history inside the loop)
+        async for event in self._policy_executor.execute_stream(memory, ctx=ctx):
+            yield event
+
+        if not forgetful:
+            self._memory = memory
+
+        outputs = self._extract_outputs(memory=memory, in_args=in_args, ctx=ctx)
+        for output in outputs:
+            yield ProcOutputEvent(data=output, name=self.name)
+
+    def _print_messages(
+        self, messages: Sequence[Message], ctx: RunContext[CtxT] | None = None
     ) -> None:
-        if (
-            len(state.message_history) == 1
-            and prev_mh_len == 0
-            and isinstance(state.message_history[0][0], SystemMessage)
-        ):
-            self._print_msgs([state.message_history[0][0]], ctx=ctx)
+        if ctx:
+            ctx.printer.print_llm_messages(messages, agent_name=self.name)
 
-    # -- Handlers for custom implementations --
+    # -- Decorators for custom implementations --
 
-    def format_sys_args_handler(
-        self, func: FormatSystemArgsHandler[CtxT]
-    ) -> FormatSystemArgsHandler[CtxT]:
-        self._prompt_builder.format_sys_args_impl = func
+    def make_sys_prompt(
+        self, func: MakeSystemPromptHandler[CtxT]
+    ) -> MakeSystemPromptHandler[CtxT]:
+        self._prompt_builder.make_sys_prompt_impl = func
 
         return func
 
-    def format_in_args_handler(
-        self, func: FormatInputArgsHandler[InT, CtxT]
-    ) -> FormatInputArgsHandler[InT, CtxT]:
-        self._prompt_builder.format_in_args_impl = func
+    def make_in_content(
+        self, func: MakeInputContentHandler[InT_contra, CtxT]
+    ) -> MakeInputContentHandler[InT_contra, CtxT]:
+        self._prompt_builder.make_in_content_impl = func
 
         return func
 
-    def parse_output_handler(
-        self, func: ParseOutputHandler[InT, OutT, CtxT]
-    ) -> ParseOutputHandler[InT, OutT, CtxT]:
+    def parse_output(
+        self, func: ParseOutputHandler[InT_contra, OutT_co, CtxT]
+    ) -> ParseOutputHandler[InT_contra, OutT_co, CtxT]:
         self._parse_output_impl = func
 
         return func
 
-    def set_agent_state_handler(self, func: SetAgentState) -> SetAgentState:
-        self._make_custom_agent_state_impl = func
+    def set_memory(self, func: SetMemoryHandler) -> SetMemoryHandler:
+        self._set_memory_impl = func
 
         return func
 
-    def exit_tool_call_loop_handler(
-        self, func: ExitToolCallLoopHandler[CtxT]
-    ) -> ExitToolCallLoopHandler[CtxT]:
-        self._tool_orchestrator.exit_tool_call_loop_impl = func
+    def manage_memory(
+        self, func: ManageMemoryHandler[CtxT]
+    ) -> ManageMemoryHandler[CtxT]:
+        self._policy_executor.manage_memory_impl = func
 
         return func
 
-    def manage_agent_state_handler(
-        self, func: ManageAgentStateHandler[CtxT]
-    ) -> ManageAgentStateHandler[CtxT]:
-        self._tool_orchestrator.manage_agent_state_impl = func
+    def exit_tool_call_loop(
+        self, func: ExitToolCallLoopHandler[CtxT]
+    ) -> ExitToolCallLoopHandler[CtxT]:
+        self._policy_executor.exit_tool_call_loop_impl = func
 
         return func
 
@@ -403,78 +359,78 @@ class LLMAgent(
         cur_cls = type(self)
         base_cls = LLMAgent[Any, Any, Any]
 
-        if cur_cls._format_sys_args is not base_cls._format_sys_args:  # noqa: SLF001
-            self._prompt_builder.format_sys_args_impl = self._format_sys_args
+        if cur_cls._make_sys_prompt_fn is not base_cls._make_sys_prompt_fn:  # noqa: SLF001
+            self._prompt_builder.make_sys_prompt_impl = self._make_sys_prompt_fn
 
-        if cur_cls._format_in_args is not base_cls._format_in_args:  # noqa: SLF001
-            self._prompt_builder.format_in_args_impl = self._format_in_args
+        if cur_cls._make_in_content_fn is not base_cls._make_in_content_fn:  # noqa: SLF001
+            self._prompt_builder.make_in_content_impl = self._make_in_content_fn
 
-        if cur_cls._set_agent_state is not base_cls._set_agent_state:  # noqa: SLF001
-            self._set_agent_state_impl = self._set_agent_state
+        if cur_cls._set_memory_fn is not base_cls._set_memory_fn:  # noqa: SLF001
+            self._set_memory_impl = self._set_memory_fn
 
-        if cur_cls._manage_agent_state is not base_cls._manage_agent_state:  # noqa: SLF001
-            self._tool_orchestrator.manage_agent_state_impl = self._manage_agent_state
+        if cur_cls._manage_memory_fn is not base_cls._manage_memory_fn:  # noqa: SLF001
+            self._policy_executor.manage_memory_impl = self._manage_memory_fn
 
         if (
-            cur_cls._exit_tool_call_loop is not base_cls._exit_tool_call_loop  # noqa: SLF001
+            cur_cls._exit_tool_call_loop_fn is not base_cls._exit_tool_call_loop_fn  # noqa: SLF001
         ):
-            self._tool_orchestrator.exit_tool_call_loop_impl = self._exit_tool_call_loop
+            self._policy_executor.exit_tool_call_loop_impl = (
+                self._exit_tool_call_loop_fn
+            )
 
-        self._parse_output_impl: ParseOutputHandler[InT, OutT, CtxT] | None = None
+        self._parse_output_impl: (
+            ParseOutputHandler[InT_contra, OutT_co, CtxT] | None
+        ) = None
 
-    def _format_sys_args(
-        self,
-        sys_args: LLMPromptArgs,
-        *,
-        ctx: RunContextWrapper[CtxT] | None = None,
-    ) -> LLMFormattedSystemArgs:
+    def _make_sys_prompt_fn(
+        self, sys_args: LLMPromptArgs | None, *, ctx: RunContext[CtxT] | None = None
+    ) -> str:
         raise NotImplementedError(
             "LLMAgent._format_sys_args must be overridden by a subclass "
             "if it's intended to be used as the system arguments formatter."
         )
 
-    def _format_in_args(
+    def _make_in_content_fn(
         self,
         *,
-        usr_args: LLMPromptArgs,
-        in_args: InT,
+        in_args: InT_contra | None = None,
+        usr_args: LLMPromptArgs | None = None,
         batch_idx: int = 0,
-        ctx: RunContextWrapper[CtxT] | None = None,
-    ) -> LLMFormattedArgs:
+        ctx: RunContext[CtxT] | None = None,
+    ) -> Content:
         raise NotImplementedError(
             "LLMAgent._format_in_args must be overridden by a subclass"
         )
 
-    def _set_agent_state(
+    def _set_memory_fn(
         self,
-        cur_state: LLMAgentState,
-        *,
-        in_state: AgentState | None,
-        sys_prompt: LLMPrompt | None,
-        ctx: RunContextWrapper[Any] | None,
-    ) -> LLMAgentState:
+        prev_memory: LLMAgentMemory,
+        in_args: InT_contra | Sequence[InT_contra] | None = None,
+        sys_prompt: LLMPrompt | None = None,
+        ctx: RunContext[Any] | None = None,
+    ) -> LLMAgentMemory:
         raise NotImplementedError(
-            "LLMAgent._set_agent_state_handler must be overridden by a subclass"
+            "LLMAgent._set_memory must be overridden by a subclass"
         )
 
-    def _exit_tool_call_loop(
+    def _exit_tool_call_loop_fn(
         self,
-        conversation: Conversation,
+        conversation: Messages,
         *,
-        ctx: RunContextWrapper[CtxT] | None = None,
+        ctx: RunContext[CtxT] | None = None,
         **kwargs: Any,
     ) -> bool:
         raise NotImplementedError(
-            "LLMAgent._tool_call_loop_exit must be overridden by a subclass"
+            "LLMAgent._exit_tool_call_loop must be overridden by a subclass"
         )
 
-    def _manage_agent_state(
+    def _manage_memory_fn(
         self,
-        state: LLMAgentState,
+        memory: LLMAgentMemory,
         *,
-        ctx: RunContextWrapper[CtxT] | None = None,
+        ctx: RunContext[CtxT] | None = None,
         **kwargs: Any,
     ) -> None:
         raise NotImplementedError(
-            "LLMAgent._manage_agent_state must be overridden by a subclass"
+            "LLMAgent._manage_memory must be overridden by a subclass"
         )
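
In short, the hunks above rename the agent's "state" to "memory", replace the tool orchestrator with an LLM policy executor, and drop the _handler suffix from the customization decorators. A minimal usage sketch under 0.3.1, assuming an OpenAILLM constructor and the generic-subscription style implied by _generic_arg_to_instance_attr_map (neither is shown in this diff):

from pydantic import BaseModel

from grasp_agents.llm_agent import LLMAgent
from grasp_agents.openai.openai_llm import OpenAILLM  # class name and kwargs below are assumptions
from grasp_agents.typing.message import Messages


class Answer(BaseModel):
    text: str


# Constructor arguments follow the 0.3.1 signature above:
# name= replaces agent_id=, recipients= replaces recipient_ids=,
# packet_pool= replaces message_pool=; max_turns now defaults to 100.
agent = LLMAgent[str, Answer, None](
    name="assistant",
    llm=OpenAILLM(model_name="gpt-4o"),  # hypothetical model settings
    max_turns=100,
    react_mode=False,
)


# Hooks lose the *_handler suffix: parse_output_handler -> parse_output,
# set_agent_state_handler -> set_memory, exit_tool_call_loop_handler -> exit_tool_call_loop.
@agent.parse_output
def parse(conversation: Messages, *, in_args, batch_idx, ctx) -> Answer:
    # conversation is typed as Messages in 0.3.1 (was Conversation in 0.2.11)
    return Answer(text=str(conversation[-1]))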