grasp_agents 0.2.10__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. grasp_agents/__init__.py +15 -14
  2. grasp_agents/cloud_llm.py +118 -131
  3. grasp_agents/comm_processor.py +201 -0
  4. grasp_agents/generics_utils.py +15 -7
  5. grasp_agents/llm.py +60 -31
  6. grasp_agents/llm_agent.py +229 -278
  7. grasp_agents/llm_agent_memory.py +58 -0
  8. grasp_agents/llm_policy_executor.py +482 -0
  9. grasp_agents/memory.py +20 -134
  10. grasp_agents/message_history.py +140 -0
  11. grasp_agents/openai/__init__.py +54 -36
  12. grasp_agents/openai/completion_chunk_converters.py +78 -0
  13. grasp_agents/openai/completion_converters.py +53 -30
  14. grasp_agents/openai/content_converters.py +13 -14
  15. grasp_agents/openai/converters.py +44 -68
  16. grasp_agents/openai/message_converters.py +58 -72
  17. grasp_agents/openai/openai_llm.py +101 -42
  18. grasp_agents/openai/tool_converters.py +24 -19
  19. grasp_agents/packet.py +24 -0
  20. grasp_agents/packet_pool.py +91 -0
  21. grasp_agents/printer.py +29 -15
  22. grasp_agents/processor.py +194 -0
  23. grasp_agents/prompt_builder.py +173 -176
  24. grasp_agents/run_context.py +21 -41
  25. grasp_agents/typing/completion.py +58 -12
  26. grasp_agents/typing/completion_chunk.py +173 -0
  27. grasp_agents/typing/converters.py +8 -12
  28. grasp_agents/typing/events.py +86 -0
  29. grasp_agents/typing/io.py +4 -13
  30. grasp_agents/typing/message.py +12 -50
  31. grasp_agents/typing/tool.py +52 -26
  32. grasp_agents/usage_tracker.py +6 -6
  33. grasp_agents/utils.py +3 -3
  34. grasp_agents/workflow/looped_workflow.py +132 -0
  35. grasp_agents/workflow/parallel_processor.py +95 -0
  36. grasp_agents/workflow/sequential_workflow.py +66 -0
  37. grasp_agents/workflow/workflow_processor.py +78 -0
  38. {grasp_agents-0.2.10.dist-info → grasp_agents-0.3.1.dist-info}/METADATA +41 -50
  39. grasp_agents-0.3.1.dist-info/RECORD +51 -0
  40. grasp_agents/agent_message.py +0 -27
  41. grasp_agents/agent_message_pool.py +0 -92
  42. grasp_agents/base_agent.py +0 -51
  43. grasp_agents/comm_agent.py +0 -217
  44. grasp_agents/llm_agent_state.py +0 -79
  45. grasp_agents/tool_orchestrator.py +0 -203
  46. grasp_agents/workflow/looped_agent.py +0 -120
  47. grasp_agents/workflow/sequential_agent.py +0 -63
  48. grasp_agents/workflow/workflow_agent.py +0 -73
  49. grasp_agents-0.2.10.dist-info/RECORD +0 -46
  50. {grasp_agents-0.2.10.dist-info → grasp_agents-0.3.1.dist-info}/WHEEL +0 -0
  51. {grasp_agents-0.2.10.dist-info → grasp_agents-0.3.1.dist-info}/licenses/LICENSE.md +0 -0
grasp_agents/llm_agent.py CHANGED
@@ -1,66 +1,47 @@
1
- from collections.abc import Sequence
1
+ from collections.abc import AsyncIterator, Sequence
2
2
  from pathlib import Path
3
- from typing import Any, ClassVar, Generic, Protocol, cast, final
3
+ from typing import Any, ClassVar, Generic, Protocol
4
4
 
5
5
  from pydantic import BaseModel
6
6
 
7
- from .agent_message import AgentMessage
8
- from .agent_message_pool import AgentMessagePool
9
- from .comm_agent import CommunicatingAgent
7
+ from .comm_processor import CommProcessor
10
8
  from .llm import LLM, LLMSettings
11
- from .llm_agent_state import (
12
- LLMAgentState,
13
- SetAgentState,
14
- SetAgentStateStrategy,
9
+ from .llm_agent_memory import LLMAgentMemory, SetMemoryHandler
10
+ from .llm_policy_executor import (
11
+ ExitToolCallLoopHandler,
12
+ LLMPolicyExecutor,
13
+ ManageMemoryHandler,
15
14
  )
15
+ from .packet_pool import PacketPool
16
16
  from .prompt_builder import (
17
- FormatInputArgsHandler,
18
- FormatSystemArgsHandler,
17
+ MakeInputContentHandler,
18
+ MakeSystemPromptHandler,
19
19
  PromptBuilder,
20
20
  )
21
- from .run_context import (
22
- CtxT,
23
- InteractionRecord,
24
- RunContextWrapper,
25
- SystemRunArgs,
26
- UserRunArgs,
27
- )
28
- from .tool_orchestrator import (
29
- ExitToolCallLoopHandler,
30
- ManageAgentStateHandler,
31
- ToolOrchestrator,
32
- )
33
- from .typing.content import ImageData
21
+ from .run_context import CtxT, RunContext
22
+ from .typing.content import Content, ImageData
34
23
  from .typing.converters import Converters
35
- from .typing.io import (
36
- AgentID,
37
- AgentState,
38
- InT,
39
- LLMFormattedArgs,
40
- LLMFormattedSystemArgs,
41
- LLMPrompt,
42
- LLMPromptArgs,
43
- OutT,
44
- )
45
- from .typing.message import Conversation, Message, SystemMessage
24
+ from .typing.events import Event, ProcOutputEvent, SystemMessageEvent, UserMessageEvent
25
+ from .typing.io import InT_contra, LLMPrompt, LLMPromptArgs, OutT_co, ProcName
26
+ from .typing.message import Message, Messages, SystemMessage, UserMessage
46
27
  from .typing.tool import BaseTool
47
28
  from .utils import get_prompt, validate_obj_from_json_or_py_string
48
29
 
49
30
 
50
- class ParseOutputHandler(Protocol[InT, OutT, CtxT]):
31
+ class ParseOutputHandler(Protocol[InT_contra, OutT_co, CtxT]):
51
32
  def __call__(
52
33
  self,
53
- conversation: Conversation,
34
+ conversation: Messages,
54
35
  *,
55
- in_args: InT | None,
36
+ in_args: InT_contra | None,
56
37
  batch_idx: int,
57
- ctx: RunContextWrapper[CtxT] | None,
58
- ) -> OutT: ...
38
+ ctx: RunContext[CtxT] | None,
39
+ ) -> OutT_co: ...
59
40
 
60
41
 
61
42
  class LLMAgent(
62
- CommunicatingAgent[InT, OutT, LLMAgentState, CtxT],
63
- Generic[InT, OutT, CtxT],
43
+ CommProcessor[InT_contra, OutT_co, LLMAgentMemory, CtxT],
44
+ Generic[InT_contra, OutT_co, CtxT],
64
45
  ):
65
46
  _generic_arg_to_instance_attr_map: ClassVar[dict[int, str]] = {
66
47
  0: "_in_type",
@@ -69,63 +50,68 @@ class LLMAgent(
69
50
 
70
51
  def __init__(
71
52
  self,
72
- agent_id: AgentID,
53
+ name: ProcName,
73
54
  *,
74
55
  # LLM
75
56
  llm: LLM[LLMSettings, Converters],
57
+ # Tools
58
+ tools: list[BaseTool[Any, Any, CtxT]] | None = None,
76
59
  # Input prompt template (combines user and received arguments)
77
60
  in_prompt: LLMPrompt | None = None,
78
61
  in_prompt_path: str | Path | None = None,
79
62
  # System prompt template
80
63
  sys_prompt: LLMPrompt | None = None,
81
64
  sys_prompt_path: str | Path | None = None,
82
- # System args (static args provided via RunContextWrapper)
65
+ # System args (static args provided via RunContext)
83
66
  sys_args_schema: type[LLMPromptArgs] = LLMPromptArgs,
84
- # User args (static args provided via RunContextWrapper)
67
+ # User args (static args provided via RunContext)
85
68
  usr_args_schema: type[LLMPromptArgs] = LLMPromptArgs,
86
- # Tools
87
- tools: list[BaseTool[Any, Any, CtxT]] | None = None,
88
- max_turns: int = 1000,
69
+ # Agent loop settings
70
+ max_turns: int = 100,
89
71
  react_mode: bool = False,
90
- # Agent state management
91
- set_state_strategy: SetAgentStateStrategy = "keep",
72
+ final_answer_as_tool_call: bool = False,
73
+ # Agent memory management
74
+ reset_memory_on_run: bool = False,
92
75
  # Multi-agent routing
93
- message_pool: AgentMessagePool[CtxT] | None = None,
94
- recipient_ids: list[AgentID] | None = None,
76
+ packet_pool: PacketPool[CtxT] | None = None,
77
+ recipients: list[ProcName] | None = None,
95
78
  ) -> None:
96
- super().__init__(
97
- agent_id=agent_id, message_pool=message_pool, recipient_ids=recipient_ids
98
- )
79
+ super().__init__(name=name, packet_pool=packet_pool, recipients=recipients)
80
+
81
+ # Agent memory
99
82
 
100
- # Agent state
101
- self._state: LLMAgentState = LLMAgentState()
102
- self.set_state_strategy: SetAgentStateStrategy = set_state_strategy
103
- self._set_agent_state_impl: SetAgentState | None = None
83
+ self._memory: LLMAgentMemory = LLMAgentMemory()
84
+ self._reset_memory_on_run = reset_memory_on_run
85
+ self._set_memory_impl: SetMemoryHandler | None = None
104
86
 
105
- # Tool orchestrator
87
+ # LLM policy executor
106
88
 
107
89
  self._using_default_llm_response_format: bool = False
108
90
  if llm.response_format is None and tools is None:
109
91
  llm.response_format = self.out_type
110
92
  self._using_default_llm_response_format = True
111
93
 
112
- self._tool_orchestrator: ToolOrchestrator[CtxT] = ToolOrchestrator[CtxT](
113
- agent_id=self.agent_id,
94
+ self._policy_executor: LLMPolicyExecutor[OutT_co, CtxT] = LLMPolicyExecutor[
95
+ self.out_type, CtxT
96
+ ](
97
+ agent_name=self.name,
114
98
  llm=llm,
115
99
  tools=tools,
116
100
  max_turns=max_turns,
117
101
  react_mode=react_mode,
102
+ final_answer_as_tool_call=final_answer_as_tool_call,
118
103
  )
119
104
 
120
105
  # Prompt builder
106
+
121
107
  sys_prompt = get_prompt(prompt_text=sys_prompt, prompt_path=sys_prompt_path)
122
108
  in_prompt = get_prompt(prompt_text=in_prompt, prompt_path=in_prompt_path)
123
- self._prompt_builder: PromptBuilder[InT, CtxT] = PromptBuilder[
109
+ self._prompt_builder: PromptBuilder[InT_contra, CtxT] = PromptBuilder[
124
110
  self.in_type, CtxT
125
111
  ](
126
- agent_id=self._agent_id,
127
- sys_prompt=sys_prompt,
128
- in_prompt=in_prompt,
112
+ agent_name=self._name,
113
+ sys_prompt_template=sys_prompt,
114
+ in_prompt_template=in_prompt,
129
115
  sys_args_schema=sys_args_schema,
130
116
  usr_args_schema=usr_args_schema,
131
117
  )
@@ -136,53 +122,50 @@ class LLMAgent(
136
122
 
137
123
  @property
138
124
  def llm(self) -> LLM[LLMSettings, Converters]:
139
- return self._tool_orchestrator.llm
125
+ return self._policy_executor.llm
140
126
 
141
127
  @property
142
128
  def tools(self) -> dict[str, BaseTool[BaseModel, Any, CtxT]]:
143
- return self._tool_orchestrator.tools
129
+ return self._policy_executor.tools
144
130
 
145
131
  @property
146
132
  def max_turns(self) -> int:
147
- return self._tool_orchestrator.max_turns
133
+ return self._policy_executor.max_turns
148
134
 
149
135
  @property
150
- def sys_args_schema(self) -> type[LLMPromptArgs]:
136
+ def sys_args_schema(self) -> type[LLMPromptArgs] | None:
151
137
  return self._prompt_builder.sys_args_schema
152
138
 
153
139
  @property
154
- def usr_args_schema(self) -> type[LLMPromptArgs]:
140
+ def usr_args_schema(self) -> type[LLMPromptArgs] | None:
155
141
  return self._prompt_builder.usr_args_schema
156
142
 
157
143
  @property
158
144
  def sys_prompt(self) -> LLMPrompt | None:
159
- return self._prompt_builder.sys_prompt
145
+ return self._prompt_builder.sys_prompt_template
160
146
 
161
147
  @property
162
148
  def in_prompt(self) -> LLMPrompt | None:
163
- return self._prompt_builder.in_prompt
149
+ return self._prompt_builder.in_prompt_template
164
150
 
165
151
  def _parse_output(
166
152
  self,
167
- conversation: Conversation,
153
+ conversation: Messages,
168
154
  *,
169
- in_args: InT | None = None,
155
+ in_args: InT_contra | None = None,
170
156
  batch_idx: int = 0,
171
- ctx: RunContextWrapper[CtxT] | None = None,
172
- ) -> OutT:
157
+ ctx: RunContext[CtxT] | None = None,
158
+ ) -> OutT_co:
173
159
  if self._parse_output_impl:
174
160
  if self._using_default_llm_response_format:
175
161
  # When using custom output parsing, the required LLM response format
176
162
  # can differ from the final agent output type ->
177
163
  # set it back to None unless it was specified explicitly at init.
178
- self._tool_orchestrator.llm.response_format = None
179
- self._using_default_llm_response_format = False
164
+ self._policy_executor.llm.response_format = None
165
+ # self._using_default_llm_response_format = False
180
166
 
181
167
  return self._parse_output_impl(
182
- conversation=conversation,
183
- in_args=in_args,
184
- batch_idx=batch_idx,
185
- ctx=ctx,
168
+ conversation=conversation, in_args=in_args, batch_idx=batch_idx, ctx=ctx
186
169
  )
187
170
 
188
171
  return validate_obj_from_json_or_py_string(
@@ -191,214 +174,182 @@ class LLMAgent(
191
174
  from_substring=True,
192
175
  )
193
176
 
194
- @staticmethod
195
- def _validate_run_inputs(
196
- chat_inputs: LLMPrompt | Sequence[str | ImageData] | None = None,
197
- in_args: InT | Sequence[InT] | None = None,
198
- in_message: AgentMessage[InT, AgentState] | None = None,
199
- entry_point: bool = False,
200
- ) -> None:
201
- multiple_inputs_err_message = (
202
- "Only one of chat_inputs, in_args, or in_message must be provided."
203
- )
204
- if chat_inputs is not None and in_args is not None:
205
- raise ValueError(multiple_inputs_err_message)
206
- if chat_inputs is not None and in_message is not None:
207
- raise ValueError(multiple_inputs_err_message)
208
- if in_args is not None and in_message is not None:
209
- raise ValueError(multiple_inputs_err_message)
210
-
211
- if entry_point and in_message is not None:
212
- raise ValueError(
213
- "Entry point agent cannot receive messages from other agents."
214
- )
215
-
216
- @final
217
- async def run(
177
+ def _memorize_inputs(
218
178
  self,
219
179
  chat_inputs: LLMPrompt | Sequence[str | ImageData] | None = None,
220
180
  *,
221
- in_message: AgentMessage[InT, AgentState] | None = None,
222
- in_args: InT | Sequence[InT] | None = None,
223
- entry_point: bool = False,
224
- ctx: RunContextWrapper[CtxT] | None = None,
225
- forbid_state_change: bool = False,
226
- **gen_kwargs: Any, # noqa: ARG002
227
- ) -> AgentMessage[OutT, LLMAgentState]:
181
+ in_args: Sequence[InT_contra] | None = None,
182
+ ctx: RunContext[CtxT] | None = None,
183
+ ) -> tuple[SystemMessage | None, Sequence[UserMessage], LLMAgentMemory]:
228
184
  # Get run arguments
229
- sys_args: SystemRunArgs = LLMPromptArgs()
230
- usr_args: UserRunArgs = LLMPromptArgs()
185
+ sys_args: LLMPromptArgs | None = None
186
+ usr_args: LLMPromptArgs | None = None
231
187
  if ctx is not None:
232
- run_args = ctx.run_args.get(self.agent_id)
188
+ run_args = ctx.run_args.get(self.name)
233
189
  if run_args is not None:
234
190
  sys_args = run_args.sys
235
191
  usr_args = run_args.usr
236
192
 
237
- self._validate_run_inputs(
238
- chat_inputs=chat_inputs,
239
- in_args=in_args,
240
- in_message=in_message,
241
- entry_point=entry_point,
242
- )
243
-
244
193
  # 1. Make system prompt (can be None)
194
+
245
195
  formatted_sys_prompt = self._prompt_builder.make_sys_prompt(
246
196
  sys_args=sys_args, ctx=ctx
247
197
  )
248
198
 
249
- # 2. Set agent state
250
-
251
- cur_state = self.state.model_copy(deep=True)
252
- in_state = in_message.sender_state if in_message else None
253
- prev_mh_len = len(cur_state.message_history)
254
-
255
- state = LLMAgentState.from_cur_and_in_states(
256
- cur_state=cur_state,
257
- in_state=in_state,
258
- sys_prompt=formatted_sys_prompt,
259
- strategy=self.set_state_strategy,
260
- set_agent_state_impl=self._set_agent_state_impl,
261
- ctx=ctx,
262
- )
199
+ # 2. Set agent memory
263
200
 
264
- self._print_sys_msg(state=state, prev_mh_len=prev_mh_len, ctx=ctx)
201
+ _memory = self.memory.model_copy(deep=True)
202
+ prev_message_hist_length = len(_memory.message_history)
203
+ if self._reset_memory_on_run or _memory.is_empty:
204
+ _memory.reset(formatted_sys_prompt)
205
+ elif self._set_memory_impl:
206
+ _memory = self._set_memory_impl(
207
+ prev_memory=_memory,
208
+ in_args=in_args,
209
+ sys_prompt=formatted_sys_prompt,
210
+ ctx=ctx,
211
+ )
265
212
 
266
- # 3. Make and add user messages (can be empty)
267
- _in_args_batch: Sequence[InT] | None = None
268
- if in_message is not None:
269
- _in_args_batch = in_message.payloads
270
- elif in_args is not None:
271
- _in_args_batch = in_args if isinstance(in_args, Sequence) else [in_args] # type: ignore[assignment]
213
+ # 3. Make and add user messages
272
214
 
273
215
  user_message_batch = self._prompt_builder.make_user_messages(
274
- chat_inputs=chat_inputs,
275
- usr_args=usr_args,
276
- in_args_batch=_in_args_batch,
277
- entry_point=entry_point,
278
- ctx=ctx,
216
+ chat_inputs=chat_inputs, in_args_batch=in_args, usr_args=usr_args, ctx=ctx
279
217
  )
280
218
  if user_message_batch:
281
- state.message_history.add_message_batch(user_message_batch)
282
- self._print_msgs(user_message_batch, ctx=ctx)
283
-
284
- if not self.tools:
285
- # 4. Generate messages without tools
286
- await self._tool_orchestrator.generate_once(state=state, ctx=ctx)
287
- else:
288
- # 4. Run tool call loop (new messages are added to the message
289
- # history inside the loop)
290
- await self._tool_orchestrator.run_loop(state=state, ctx=ctx)
291
-
292
- # 5. Parse outputs
293
- batch_size = state.message_history.batch_size
294
- in_args_batch = in_message.payloads if in_message else batch_size * [None]
295
- val_output_batch = [
296
- self._out_type_adapter.validate_python(
297
- self._parse_output(conversation=conv, in_args=in_args, ctx=ctx)
298
- )
299
- for conv, in_args in zip(
300
- state.message_history.batched_conversations,
301
- in_args_batch,
302
- strict=False,
303
- )
304
- ]
219
+ _memory.update(message_batch=user_message_batch)
305
220
 
306
- # 6. Write interaction history to context
221
+ # 4. Extract system message if it was added
307
222
 
308
- recipient_ids = self._validate_routing(val_output_batch)
223
+ system_message: SystemMessage | None = None
224
+ if (
225
+ len(_memory.message_history) == 1
226
+ and prev_message_hist_length == 0
227
+ and isinstance(_memory.message_history[0][0], SystemMessage)
228
+ ):
229
+ system_message = _memory.message_history[0][0]
309
230
 
310
- if ctx:
311
- interaction_record = InteractionRecord(
312
- source_id=self.agent_id,
313
- recipient_ids=recipient_ids,
314
- chat_inputs=chat_inputs,
315
- sys_prompt=self.sys_prompt,
316
- in_prompt=self.in_prompt,
317
- sys_args=sys_args,
318
- usr_args=usr_args,
319
- in_args=(in_message.payloads if in_message is not None else None),
320
- outputs=val_output_batch,
321
- state=state,
322
- )
323
- ctx.interaction_history.append(
324
- cast(
325
- "InteractionRecord[Any, Any, AgentState]",
326
- interaction_record,
231
+ return system_message, user_message_batch, _memory
232
+
233
+ def _extract_outputs(
234
+ self,
235
+ memory: LLMAgentMemory,
236
+ in_args: Sequence[InT_contra] | None = None,
237
+ ctx: RunContext[CtxT] | None = None,
238
+ ) -> Sequence[OutT_co]:
239
+ outputs: list[OutT_co] = []
240
+ for i, _conv in enumerate(memory.message_history.conversations):
241
+ if in_args is not None:
242
+ _in_args_single = in_args[min(i, len(in_args) - 1)]
243
+ else:
244
+ _in_args_single = None
245
+
246
+ outputs.append(
247
+ self._parse_output(
248
+ conversation=_conv, in_args=_in_args_single, batch_idx=i, ctx=ctx
327
249
  )
328
250
  )
329
251
 
330
- agent_message = AgentMessage(
331
- payloads=val_output_batch,
332
- sender_id=self.agent_id,
333
- sender_state=state,
334
- recipient_ids=recipient_ids,
252
+ return outputs
253
+
254
+ async def _process(
255
+ self,
256
+ chat_inputs: LLMPrompt | Sequence[str | ImageData] | None = None,
257
+ *,
258
+ in_args: Sequence[InT_contra] | None = None,
259
+ forgetful: bool = False,
260
+ ctx: RunContext[CtxT] | None = None,
261
+ ) -> Sequence[OutT_co]:
262
+ system_message, user_message_batch, memory = self._memorize_inputs(
263
+ chat_inputs=chat_inputs, in_args=in_args, ctx=ctx
335
264
  )
336
265
 
337
- if not forbid_state_change:
338
- self._state = state
266
+ if system_message is not None:
267
+ self._print_messages([system_message], ctx=ctx)
268
+ if user_message_batch:
269
+ self._print_messages(user_message_batch, ctx=ctx)
339
270
 
340
- return agent_message
271
+ await self._policy_executor.execute(memory, ctx=ctx)
341
272
 
342
- def _print_msgs(
343
- self,
344
- messages: Sequence[Message],
345
- ctx: RunContextWrapper[CtxT] | None = None,
346
- ) -> None:
347
- if ctx:
348
- ctx.printer.print_llm_messages(messages, agent_id=self.agent_id)
273
+ if not forgetful:
274
+ self._memory = memory
349
275
 
350
- def _print_sys_msg(
276
+ return self._extract_outputs(memory=memory, in_args=in_args, ctx=ctx)
277
+
278
+ async def _process_stream(
351
279
  self,
352
- state: LLMAgentState,
353
- prev_mh_len: int,
354
- ctx: RunContextWrapper[CtxT] | None = None,
280
+ chat_inputs: LLMPrompt | Sequence[str | ImageData] | None = None,
281
+ *,
282
+ in_args: Sequence[InT_contra] | None = None,
283
+ forgetful: bool = False,
284
+ ctx: RunContext[CtxT] | None = None,
285
+ ) -> AsyncIterator[Event[Any]]:
286
+ system_message, user_message_batch, memory = self._memorize_inputs(
287
+ chat_inputs=chat_inputs, in_args=in_args, ctx=ctx
288
+ )
289
+
290
+ if system_message is not None:
291
+ yield SystemMessageEvent(data=system_message)
292
+ if user_message_batch:
293
+ for user_message in user_message_batch:
294
+ yield UserMessageEvent(data=user_message)
295
+
296
+ # 4. Run tool call loop (new messages are added to the message
297
+ # history inside the loop)
298
+ async for event in self._policy_executor.execute_stream(memory, ctx=ctx):
299
+ yield event
300
+
301
+ if not forgetful:
302
+ self._memory = memory
303
+
304
+ outputs = self._extract_outputs(memory=memory, in_args=in_args, ctx=ctx)
305
+ for output in outputs:
306
+ yield ProcOutputEvent(data=output, name=self.name)
307
+
308
+ def _print_messages(
309
+ self, messages: Sequence[Message], ctx: RunContext[CtxT] | None = None
355
310
  ) -> None:
356
- if (
357
- len(state.message_history) == 1
358
- and prev_mh_len == 0
359
- and isinstance(state.message_history[0][0], SystemMessage)
360
- ):
361
- self._print_msgs([state.message_history[0][0]], ctx=ctx)
311
+ if ctx:
312
+ ctx.printer.print_llm_messages(messages, agent_name=self.name)
362
313
 
363
- # -- Handlers for custom implementations --
314
+ # -- Decorators for custom implementations --
364
315
 
365
- def format_sys_args_handler(
366
- self, func: FormatSystemArgsHandler[CtxT]
367
- ) -> FormatSystemArgsHandler[CtxT]:
368
- self._prompt_builder.format_sys_args_impl = func
316
+ def make_sys_prompt(
317
+ self, func: MakeSystemPromptHandler[CtxT]
318
+ ) -> MakeSystemPromptHandler[CtxT]:
319
+ self._prompt_builder.make_sys_prompt_impl = func
369
320
 
370
321
  return func
371
322
 
372
- def format_in_args_handler(
373
- self, func: FormatInputArgsHandler[InT, CtxT]
374
- ) -> FormatInputArgsHandler[InT, CtxT]:
375
- self._prompt_builder.format_in_args_impl = func
323
+ def make_in_content(
324
+ self, func: MakeInputContentHandler[InT_contra, CtxT]
325
+ ) -> MakeInputContentHandler[InT_contra, CtxT]:
326
+ self._prompt_builder.make_in_content_impl = func
376
327
 
377
328
  return func
378
329
 
379
- def parse_output_handler(
380
- self, func: ParseOutputHandler[InT, OutT, CtxT]
381
- ) -> ParseOutputHandler[InT, OutT, CtxT]:
330
+ def parse_output(
331
+ self, func: ParseOutputHandler[InT_contra, OutT_co, CtxT]
332
+ ) -> ParseOutputHandler[InT_contra, OutT_co, CtxT]:
382
333
  self._parse_output_impl = func
383
334
 
384
335
  return func
385
336
 
386
- def set_agent_state_handler(self, func: SetAgentState) -> SetAgentState:
387
- self._make_custom_agent_state_impl = func
337
+ def set_memory(self, func: SetMemoryHandler) -> SetMemoryHandler:
338
+ self._set_memory_impl = func
388
339
 
389
340
  return func
390
341
 
391
- def exit_tool_call_loop_handler(
392
- self, func: ExitToolCallLoopHandler[CtxT]
393
- ) -> ExitToolCallLoopHandler[CtxT]:
394
- self._tool_orchestrator.exit_tool_call_loop_impl = func
342
+ def manage_memory(
343
+ self, func: ManageMemoryHandler[CtxT]
344
+ ) -> ManageMemoryHandler[CtxT]:
345
+ self._policy_executor.manage_memory_impl = func
395
346
 
396
347
  return func
397
348
 
398
- def manage_agent_state_handler(
399
- self, func: ManageAgentStateHandler[CtxT]
400
- ) -> ManageAgentStateHandler[CtxT]:
401
- self._tool_orchestrator.manage_agent_state_impl = func
349
+ def exit_tool_call_loop(
350
+ self, func: ExitToolCallLoopHandler[CtxT]
351
+ ) -> ExitToolCallLoopHandler[CtxT]:
352
+ self._policy_executor.exit_tool_call_loop_impl = func
402
353
 
403
354
  return func
404
355
 
@@ -408,78 +359,78 @@ class LLMAgent(
408
359
  cur_cls = type(self)
409
360
  base_cls = LLMAgent[Any, Any, Any]
410
361
 
411
- if cur_cls._format_sys_args is not base_cls._format_sys_args: # noqa: SLF001
412
- self._prompt_builder.format_sys_args_impl = self._format_sys_args
362
+ if cur_cls._make_sys_prompt_fn is not base_cls._make_sys_prompt_fn: # noqa: SLF001
363
+ self._prompt_builder.make_sys_prompt_impl = self._make_sys_prompt_fn
413
364
 
414
- if cur_cls._format_in_args is not base_cls._format_in_args: # noqa: SLF001
415
- self._prompt_builder.format_in_args_impl = self._format_in_args
365
+ if cur_cls._make_in_content_fn is not base_cls._make_in_content_fn: # noqa: SLF001
366
+ self._prompt_builder.make_in_content_impl = self._make_in_content_fn
416
367
 
417
- if cur_cls._set_agent_state is not base_cls._set_agent_state: # noqa: SLF001
418
- self._set_agent_state_impl = self._set_agent_state
368
+ if cur_cls._set_memory_fn is not base_cls._set_memory_fn: # noqa: SLF001
369
+ self._set_memory_impl = self._set_memory_fn
419
370
 
420
- if cur_cls._manage_agent_state is not base_cls._manage_agent_state: # noqa: SLF001
421
- self._tool_orchestrator.manage_agent_state_impl = self._manage_agent_state
371
+ if cur_cls._manage_memory_fn is not base_cls._manage_memory_fn: # noqa: SLF001
372
+ self._policy_executor.manage_memory_impl = self._manage_memory_fn
422
373
 
423
374
  if (
424
- cur_cls._exit_tool_call_loop is not base_cls._exit_tool_call_loop # noqa: SLF001
375
+ cur_cls._exit_tool_call_loop_fn is not base_cls._exit_tool_call_loop_fn # noqa: SLF001
425
376
  ):
426
- self._tool_orchestrator.exit_tool_call_loop_impl = self._exit_tool_call_loop
377
+ self._policy_executor.exit_tool_call_loop_impl = (
378
+ self._exit_tool_call_loop_fn
379
+ )
427
380
 
428
- self._parse_output_impl: ParseOutputHandler[InT, OutT, CtxT] | None = None
381
+ self._parse_output_impl: (
382
+ ParseOutputHandler[InT_contra, OutT_co, CtxT] | None
383
+ ) = None
429
384
 
430
- def _format_sys_args(
431
- self,
432
- sys_args: LLMPromptArgs,
433
- *,
434
- ctx: RunContextWrapper[CtxT] | None = None,
435
- ) -> LLMFormattedSystemArgs:
385
+ def _make_sys_prompt_fn(
386
+ self, sys_args: LLMPromptArgs | None, *, ctx: RunContext[CtxT] | None = None
387
+ ) -> str:
436
388
  raise NotImplementedError(
437
389
  "LLMAgent._format_sys_args must be overridden by a subclass "
438
390
  "if it's intended to be used as the system arguments formatter."
439
391
  )
440
392
 
441
- def _format_in_args(
393
+ def _make_in_content_fn(
442
394
  self,
443
395
  *,
444
- usr_args: LLMPromptArgs,
445
- in_args: InT,
396
+ in_args: InT_contra | None = None,
397
+ usr_args: LLMPromptArgs | None = None,
446
398
  batch_idx: int = 0,
447
- ctx: RunContextWrapper[CtxT] | None = None,
448
- ) -> LLMFormattedArgs:
399
+ ctx: RunContext[CtxT] | None = None,
400
+ ) -> Content:
449
401
  raise NotImplementedError(
450
402
  "LLMAgent._format_in_args must be overridden by a subclass"
451
403
  )
452
404
 
453
- def _set_agent_state(
405
+ def _set_memory_fn(
454
406
  self,
455
- cur_state: LLMAgentState,
456
- *,
457
- in_state: AgentState | None,
458
- sys_prompt: LLMPrompt | None,
459
- ctx: RunContextWrapper[Any] | None,
460
- ) -> LLMAgentState:
407
+ prev_memory: LLMAgentMemory,
408
+ in_args: InT_contra | Sequence[InT_contra] | None = None,
409
+ sys_prompt: LLMPrompt | None = None,
410
+ ctx: RunContext[Any] | None = None,
411
+ ) -> LLMAgentMemory:
461
412
  raise NotImplementedError(
462
- "LLMAgent._set_agent_state_handler must be overridden by a subclass"
413
+ "LLMAgent._set_memory must be overridden by a subclass"
463
414
  )
464
415
 
465
- def _exit_tool_call_loop(
416
+ def _exit_tool_call_loop_fn(
466
417
  self,
467
- conversation: Conversation,
418
+ conversation: Messages,
468
419
  *,
469
- ctx: RunContextWrapper[CtxT] | None = None,
420
+ ctx: RunContext[CtxT] | None = None,
470
421
  **kwargs: Any,
471
422
  ) -> bool:
472
423
  raise NotImplementedError(
473
- "LLMAgent._tool_call_loop_exit must be overridden by a subclass"
424
+ "LLMAgent._exit_tool_call_loop must be overridden by a subclass"
474
425
  )
475
426
 
476
- def _manage_agent_state(
427
+ def _manage_memory_fn(
477
428
  self,
478
- state: LLMAgentState,
429
+ memory: LLMAgentMemory,
479
430
  *,
480
- ctx: RunContextWrapper[CtxT] | None = None,
431
+ ctx: RunContext[CtxT] | None = None,
481
432
  **kwargs: Any,
482
433
  ) -> None:
483
434
  raise NotImplementedError(
484
- "LLMAgent._manage_agent_state must be overridden by a subclass"
435
+ "LLMAgent._manage_memory must be overridden by a subclass"
485
436
  )