grasp_agents 0.2.11__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff shows the changes between publicly released package versions as they appear in their respective public registries and is provided for informational purposes only.
- grasp_agents/__init__.py +15 -14
- grasp_agents/cloud_llm.py +118 -131
- grasp_agents/comm_processor.py +201 -0
- grasp_agents/generics_utils.py +15 -7
- grasp_agents/llm.py +60 -31
- grasp_agents/llm_agent.py +229 -273
- grasp_agents/llm_agent_memory.py +58 -0
- grasp_agents/llm_policy_executor.py +482 -0
- grasp_agents/memory.py +20 -134
- grasp_agents/message_history.py +140 -0
- grasp_agents/openai/__init__.py +54 -36
- grasp_agents/openai/completion_chunk_converters.py +78 -0
- grasp_agents/openai/completion_converters.py +53 -30
- grasp_agents/openai/content_converters.py +13 -14
- grasp_agents/openai/converters.py +44 -68
- grasp_agents/openai/message_converters.py +58 -72
- grasp_agents/openai/openai_llm.py +101 -42
- grasp_agents/openai/tool_converters.py +24 -19
- grasp_agents/packet.py +24 -0
- grasp_agents/packet_pool.py +91 -0
- grasp_agents/printer.py +29 -15
- grasp_agents/processor.py +194 -0
- grasp_agents/prompt_builder.py +175 -192
- grasp_agents/run_context.py +20 -37
- grasp_agents/typing/completion.py +58 -12
- grasp_agents/typing/completion_chunk.py +173 -0
- grasp_agents/typing/converters.py +8 -12
- grasp_agents/typing/events.py +86 -0
- grasp_agents/typing/io.py +4 -13
- grasp_agents/typing/message.py +12 -50
- grasp_agents/typing/tool.py +52 -26
- grasp_agents/usage_tracker.py +6 -6
- grasp_agents/utils.py +3 -3
- grasp_agents/workflow/looped_workflow.py +132 -0
- grasp_agents/workflow/parallel_processor.py +95 -0
- grasp_agents/workflow/sequential_workflow.py +66 -0
- grasp_agents/workflow/workflow_processor.py +78 -0
- {grasp_agents-0.2.11.dist-info → grasp_agents-0.3.1.dist-info}/METADATA +41 -50
- grasp_agents-0.3.1.dist-info/RECORD +51 -0
- grasp_agents/agent_message.py +0 -27
- grasp_agents/agent_message_pool.py +0 -92
- grasp_agents/base_agent.py +0 -51
- grasp_agents/comm_agent.py +0 -217
- grasp_agents/llm_agent_state.py +0 -79
- grasp_agents/tool_orchestrator.py +0 -203
- grasp_agents/workflow/looped_agent.py +0 -134
- grasp_agents/workflow/sequential_agent.py +0 -72
- grasp_agents/workflow/workflow_agent.py +0 -88
- grasp_agents-0.2.11.dist-info/RECORD +0 -46
- {grasp_agents-0.2.11.dist-info → grasp_agents-0.3.1.dist-info}/WHEEL +0 -0
- {grasp_agents-0.2.11.dist-info → grasp_agents-0.3.1.dist-info}/licenses/LICENSE.md +0 -0
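The file list reflects the 0.3 rename from "agents" to "processors": comm_agent.py, llm_agent_state.py, tool_orchestrator.py, and agent_message_pool.py are removed in favour of comm_processor.py, llm_agent_memory.py, llm_policy_executor.py, and packet_pool.py. Below is a minimal construction sketch against the new LLMAgent.__init__ signature visible in the llm_agent.py diff that follows; the OpenAILLM import path, its constructor arguments, and the model name are assumptions, not part of this diff.

```python
# Hypothetical usage sketch for grasp_agents 0.3.x, based on the new
# LLMAgent.__init__ signature shown in the llm_agent.py diff below.
# The OpenAILLM import and its constructor arguments are assumptions.
from typing import Any

from grasp_agents.llm_agent import LLMAgent
from grasp_agents.openai.openai_llm import OpenAILLM

agent = LLMAgent[str, str, Any](
    name="assistant",                     # ProcName replaces the old agent_id
    llm=OpenAILLM(model_name="gpt-4.1"),  # assumed signature; see openai_llm.py
    tools=None,
    sys_prompt="You are a helpful assistant.",
    max_turns=100,                        # default lowered from 1000 to 100
    final_answer_as_tool_call=False,      # new in 0.3
    reset_memory_on_run=False,            # new in 0.3
    packet_pool=None,                     # replaces message_pool
    recipients=None,                      # replaces recipient_ids
)
```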
grasp_agents/llm_agent.py
CHANGED
@@ -1,60 +1,47 @@
-from collections.abc import Sequence
+from collections.abc import AsyncIterator, Sequence
 from pathlib import Path
-from typing import Any, ClassVar, Generic, Protocol
+from typing import Any, ClassVar, Generic, Protocol

 from pydantic import BaseModel

-from .
-from .agent_message_pool import AgentMessagePool
-from .comm_agent import CommunicatingAgent
+from .comm_processor import CommProcessor
 from .llm import LLM, LLMSettings
-from .
-
-
-
+from .llm_agent_memory import LLMAgentMemory, SetMemoryHandler
+from .llm_policy_executor import (
+    ExitToolCallLoopHandler,
+    LLMPolicyExecutor,
+    ManageMemoryHandler,
 )
+from .packet_pool import PacketPool
 from .prompt_builder import (
-
-
+    MakeInputContentHandler,
+    MakeSystemPromptHandler,
     PromptBuilder,
 )
-from .run_context import CtxT,
-from .
-    ExitToolCallLoopHandler,
-    ManageAgentStateHandler,
-    ToolOrchestrator,
-)
-from .typing.content import ImageData
+from .run_context import CtxT, RunContext
+from .typing.content import Content, ImageData
 from .typing.converters import Converters
-from .typing.
-
-
-    InT,
-    LLMFormattedArgs,
-    LLMFormattedSystemArgs,
-    LLMPrompt,
-    LLMPromptArgs,
-    OutT,
-)
-from .typing.message import Conversation, Message, SystemMessage
+from .typing.events import Event, ProcOutputEvent, SystemMessageEvent, UserMessageEvent
+from .typing.io import InT_contra, LLMPrompt, LLMPromptArgs, OutT_co, ProcName
+from .typing.message import Message, Messages, SystemMessage, UserMessage
 from .typing.tool import BaseTool
 from .utils import get_prompt, validate_obj_from_json_or_py_string


-class ParseOutputHandler(Protocol[
+class ParseOutputHandler(Protocol[InT_contra, OutT_co, CtxT]):
     def __call__(
         self,
-        conversation:
+        conversation: Messages,
         *,
-        in_args:
+        in_args: InT_contra | None,
         batch_idx: int,
-        ctx:
-    ) ->
+        ctx: RunContext[CtxT] | None,
+    ) -> OutT_co: ...


 class LLMAgent(
-
-    Generic[
+    CommProcessor[InT_contra, OutT_co, LLMAgentMemory, CtxT],
+    Generic[InT_contra, OutT_co, CtxT],
 ):
     _generic_arg_to_instance_attr_map: ClassVar[dict[int, str]] = {
         0: "_in_type",
@@ -63,63 +50,68 @@ class LLMAgent(

     def __init__(
         self,
-
+        name: ProcName,
         *,
         # LLM
         llm: LLM[LLMSettings, Converters],
+        # Tools
+        tools: list[BaseTool[Any, Any, CtxT]] | None = None,
         # Input prompt template (combines user and received arguments)
         in_prompt: LLMPrompt | None = None,
         in_prompt_path: str | Path | None = None,
         # System prompt template
         sys_prompt: LLMPrompt | None = None,
         sys_prompt_path: str | Path | None = None,
-        # System args (static args provided via
+        # System args (static args provided via RunContext)
         sys_args_schema: type[LLMPromptArgs] = LLMPromptArgs,
-        # User args (static args provided via
+        # User args (static args provided via RunContext)
         usr_args_schema: type[LLMPromptArgs] = LLMPromptArgs,
-        #
-
-        max_turns: int = 1000,
+        # Agent loop settings
+        max_turns: int = 100,
         react_mode: bool = False,
-
-
+        final_answer_as_tool_call: bool = False,
+        # Agent memory management
+        reset_memory_on_run: bool = False,
         # Multi-agent routing
-
-
+        packet_pool: PacketPool[CtxT] | None = None,
+        recipients: list[ProcName] | None = None,
     ) -> None:
-        super().__init__(
-            agent_id=agent_id, message_pool=message_pool, recipient_ids=recipient_ids
-        )
+        super().__init__(name=name, packet_pool=packet_pool, recipients=recipients)

-        # Agent
-        self._state: LLMAgentState = LLMAgentState()
-        self.set_state_strategy: SetAgentStateStrategy = set_state_strategy
-        self._set_agent_state_impl: SetAgentState | None = None
+        # Agent memory

-
+        self._memory: LLMAgentMemory = LLMAgentMemory()
+        self._reset_memory_on_run = reset_memory_on_run
+        self._set_memory_impl: SetMemoryHandler | None = None
+
+        # LLM policy executor

         self._using_default_llm_response_format: bool = False
         if llm.response_format is None and tools is None:
             llm.response_format = self.out_type
             self._using_default_llm_response_format = True

-        self.
-
+        self._policy_executor: LLMPolicyExecutor[OutT_co, CtxT] = LLMPolicyExecutor[
+            self.out_type, CtxT
+        ](
+            agent_name=self.name,
             llm=llm,
             tools=tools,
             max_turns=max_turns,
             react_mode=react_mode,
+            final_answer_as_tool_call=final_answer_as_tool_call,
         )

         # Prompt builder
+
         sys_prompt = get_prompt(prompt_text=sys_prompt, prompt_path=sys_prompt_path)
         in_prompt = get_prompt(prompt_text=in_prompt, prompt_path=in_prompt_path)
-        self._prompt_builder: PromptBuilder[
+        self._prompt_builder: PromptBuilder[InT_contra, CtxT] = PromptBuilder[
             self.in_type, CtxT
         ](
-
-
-
+            agent_name=self._name,
+            sys_prompt_template=sys_prompt,
+            in_prompt_template=in_prompt,
             sys_args_schema=sys_args_schema,
             usr_args_schema=usr_args_schema,
         )
@@ -130,53 +122,50 @@ class LLMAgent(

     @property
     def llm(self) -> LLM[LLMSettings, Converters]:
-        return self.
+        return self._policy_executor.llm

     @property
     def tools(self) -> dict[str, BaseTool[BaseModel, Any, CtxT]]:
-        return self.
+        return self._policy_executor.tools

     @property
     def max_turns(self) -> int:
-        return self.
+        return self._policy_executor.max_turns

     @property
-    def sys_args_schema(self) -> type[LLMPromptArgs]:
+    def sys_args_schema(self) -> type[LLMPromptArgs] | None:
         return self._prompt_builder.sys_args_schema

     @property
-    def usr_args_schema(self) -> type[LLMPromptArgs]:
+    def usr_args_schema(self) -> type[LLMPromptArgs] | None:
         return self._prompt_builder.usr_args_schema

     @property
     def sys_prompt(self) -> LLMPrompt | None:
-        return self._prompt_builder.
+        return self._prompt_builder.sys_prompt_template

     @property
     def in_prompt(self) -> LLMPrompt | None:
-        return self._prompt_builder.
+        return self._prompt_builder.in_prompt_template

     def _parse_output(
         self,
-        conversation:
+        conversation: Messages,
         *,
-        in_args:
+        in_args: InT_contra | None = None,
         batch_idx: int = 0,
-        ctx:
-    ) ->
+        ctx: RunContext[CtxT] | None = None,
+    ) -> OutT_co:
         if self._parse_output_impl:
             if self._using_default_llm_response_format:
                 # When using custom output parsing, the required LLM response format
                 # can differ from the final agent output type ->
                 # set it back to None unless it was specified explicitly at init.
-                self.
-                self._using_default_llm_response_format = False
+                self._policy_executor.llm.response_format = None
+                # self._using_default_llm_response_format = False

             return self._parse_output_impl(
-                conversation=conversation,
-                in_args=in_args,
-                batch_idx=batch_idx,
-                ctx=ctx,
+                conversation=conversation, in_args=in_args, batch_idx=batch_idx, ctx=ctx
             )

         return validate_obj_from_json_or_py_string(
@@ -185,215 +174,182 @@
             from_substring=True,
         )

-
-    def _validate_run_inputs(
-        chat_inputs: LLMPrompt | Sequence[str | ImageData] | None = None,
-        in_args: InT | Sequence[InT] | None = None,
-        in_message: AgentMessage[InT, AgentState] | None = None,
-        entry_point: bool = False,
-    ) -> None:
-        multiple_inputs_err_message = (
-            "Only one of chat_inputs, in_args, or in_message must be provided."
-        )
-        if chat_inputs is not None and in_args is not None:
-            raise ValueError(multiple_inputs_err_message)
-        if chat_inputs is not None and in_message is not None:
-            raise ValueError(multiple_inputs_err_message)
-        if in_args is not None and in_message is not None:
-            raise ValueError(multiple_inputs_err_message)
-
-        if entry_point and in_message is not None:
-            raise ValueError(
-                "Entry point agent cannot receive messages from other agents."
-            )
-
-    @final
-    async def run(
+    def _memorize_inputs(
         self,
         chat_inputs: LLMPrompt | Sequence[str | ImageData] | None = None,
         *,
-
-
-
-        ctx: RunContextWrapper[CtxT] | None = None,
-        forbid_state_change: bool = False,
-        **gen_kwargs: Any,  # noqa: ARG002
-    ) -> AgentMessage[OutT, LLMAgentState]:
+        in_args: Sequence[InT_contra] | None = None,
+        ctx: RunContext[CtxT] | None = None,
+    ) -> tuple[SystemMessage | None, Sequence[UserMessage], LLMAgentMemory]:
         # Get run arguments
-        sys_args: LLMPromptArgs =
-        usr_args: LLMPromptArgs |
+        sys_args: LLMPromptArgs | None = None
+        usr_args: LLMPromptArgs | None = None
         if ctx is not None:
-            run_args = ctx.run_args.get(self.
+            run_args = ctx.run_args.get(self.name)
             if run_args is not None:
                 sys_args = run_args.sys
                 usr_args = run_args.usr

-        self._validate_run_inputs(
-            chat_inputs=chat_inputs,
-            in_args=in_args,
-            in_message=in_message,
-            entry_point=entry_point,
-        )
-        resolved_in_args = in_message.payloads if in_message else in_args
-
         # 1. Make system prompt (can be None)
+
         formatted_sys_prompt = self._prompt_builder.make_sys_prompt(
             sys_args=sys_args, ctx=ctx
         )

-        # 2. Set agent
-
-        cur_state = self.state.model_copy(deep=True)
-        in_state = in_message.sender_state if in_message else None
-        prev_mh_len = len(cur_state.message_history)
+        # 2. Set agent memory

-
-
-
-
-
-
-
-
-
-
+        _memory = self.memory.model_copy(deep=True)
+        prev_message_hist_length = len(_memory.message_history)
+        if self._reset_memory_on_run or _memory.is_empty:
+            _memory.reset(formatted_sys_prompt)
+        elif self._set_memory_impl:
+            _memory = self._set_memory_impl(
+                prev_memory=_memory,
+                in_args=in_args,
+                sys_prompt=formatted_sys_prompt,
+                ctx=ctx,
+            )

-        # 3. Make and add user messages
+        # 3. Make and add user messages

         user_message_batch = self._prompt_builder.make_user_messages(
-            chat_inputs=chat_inputs,
-            in_args=resolved_in_args,
-            usr_args=usr_args,
-            entry_point=entry_point,
-            ctx=ctx,
+            chat_inputs=chat_inputs, in_args_batch=in_args, usr_args=usr_args, ctx=ctx
         )
         if user_message_batch:
-
-            self._print_msgs(user_message_batch, ctx=ctx)
-
-        if not self.tools:
-            # 4. Generate messages without tools
-            await self._tool_orchestrator.generate_once(state=state, ctx=ctx)
-        else:
-            # 4. Run tool call loop (new messages are added to the message
-            # history inside the loop)
-            await self._tool_orchestrator.run_loop(state=state, ctx=ctx)
-
-        # 5. Parse outputs
-
-        val_output_batch: list[OutT] = []
-        for i, _conv in enumerate(state.message_history.batched_conversations):
-            if isinstance(resolved_in_args, Sequence):
-                _resolved_in_args = cast("Sequence[InT]", resolved_in_args)
-                _in_args = _resolved_in_args[min(i, len(_resolved_in_args) - 1)]
-            else:
-                _resolved_in_args = cast("InT | None", resolved_in_args)
-                _in_args = _resolved_in_args
-
-            val_output_batch.append(
-                self._out_type_adapter.validate_python(
-                    self._parse_output(
-                        conversation=_conv, in_args=_in_args, batch_idx=i, ctx=ctx
-                    )
-                )
-            )
+            _memory.update(message_batch=user_message_batch)

-        #
+        # 4. Extract system message if it was added

-
+        system_message: SystemMessage | None = None
+        if (
+            len(_memory.message_history) == 1
+            and prev_message_hist_length == 0
+            and isinstance(_memory.message_history[0][0], SystemMessage)
+        ):
+            system_message = _memory.message_history[0][0]

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        return system_message, user_message_batch, _memory
+
+    def _extract_outputs(
+        self,
+        memory: LLMAgentMemory,
+        in_args: Sequence[InT_contra] | None = None,
+        ctx: RunContext[CtxT] | None = None,
+    ) -> Sequence[OutT_co]:
+        outputs: list[OutT_co] = []
+        for i, _conv in enumerate(memory.message_history.conversations):
+            if in_args is not None:
+                _in_args_single = in_args[min(i, len(in_args) - 1)]
+            else:
+                _in_args_single = None
+
+            outputs.append(
+                self._parse_output(
+                    conversation=_conv, in_args=_in_args_single, batch_idx=i, ctx=ctx
                 )
             )

-
-
-
-
-
+        return outputs
+
+    async def _process(
+        self,
+        chat_inputs: LLMPrompt | Sequence[str | ImageData] | None = None,
+        *,
+        in_args: Sequence[InT_contra] | None = None,
+        forgetful: bool = False,
+        ctx: RunContext[CtxT] | None = None,
+    ) -> Sequence[OutT_co]:
+        system_message, user_message_batch, memory = self._memorize_inputs(
+            chat_inputs=chat_inputs, in_args=in_args, ctx=ctx
         )

-        if not
-        self.
+        if system_message is not None:
+            self._print_messages([system_message], ctx=ctx)
+        if user_message_batch:
+            self._print_messages(user_message_batch, ctx=ctx)

-
+        await self._policy_executor.execute(memory, ctx=ctx)

-
-
-
-
-    ) -> None:
-        if ctx:
-            ctx.printer.print_llm_messages(messages, agent_id=self.agent_id)
+        if not forgetful:
+            self._memory = memory
+
+        return self._extract_outputs(memory=memory, in_args=in_args, ctx=ctx)

-    def
+    async def _process_stream(
         self,
-
-
-
+        chat_inputs: LLMPrompt | Sequence[str | ImageData] | None = None,
+        *,
+        in_args: Sequence[InT_contra] | None = None,
+        forgetful: bool = False,
+        ctx: RunContext[CtxT] | None = None,
+    ) -> AsyncIterator[Event[Any]]:
+        system_message, user_message_batch, memory = self._memorize_inputs(
+            chat_inputs=chat_inputs, in_args=in_args, ctx=ctx
+        )
+
+        if system_message is not None:
+            yield SystemMessageEvent(data=system_message)
+        if user_message_batch:
+            for user_message in user_message_batch:
+                yield UserMessageEvent(data=user_message)
+
+        # 4. Run tool call loop (new messages are added to the message
+        # history inside the loop)
+        async for event in self._policy_executor.execute_stream(memory, ctx=ctx):
+            yield event
+
+        if not forgetful:
+            self._memory = memory
+
+        outputs = self._extract_outputs(memory=memory, in_args=in_args, ctx=ctx)
+        for output in outputs:
+            yield ProcOutputEvent(data=output, name=self.name)
+
+    def _print_messages(
+        self, messages: Sequence[Message], ctx: RunContext[CtxT] | None = None
     ) -> None:
-        if
-
-            and prev_mh_len == 0
-            and isinstance(state.message_history[0][0], SystemMessage)
-        ):
-            self._print_msgs([state.message_history[0][0]], ctx=ctx)
+        if ctx:
+            ctx.printer.print_llm_messages(messages, agent_name=self.name)

-    # --
+    # -- Decorators for custom implementations --

-    def
-        self, func:
-    ) ->
-        self._prompt_builder.
+    def make_sys_prompt(
+        self, func: MakeSystemPromptHandler[CtxT]
+    ) -> MakeSystemPromptHandler[CtxT]:
+        self._prompt_builder.make_sys_prompt_impl = func

         return func

-    def
-        self, func:
-    ) ->
-        self._prompt_builder.
+    def make_in_content(
+        self, func: MakeInputContentHandler[InT_contra, CtxT]
+    ) -> MakeInputContentHandler[InT_contra, CtxT]:
+        self._prompt_builder.make_in_content_impl = func

         return func

-    def
-        self, func: ParseOutputHandler[
-    ) -> ParseOutputHandler[
+    def parse_output(
+        self, func: ParseOutputHandler[InT_contra, OutT_co, CtxT]
+    ) -> ParseOutputHandler[InT_contra, OutT_co, CtxT]:
         self._parse_output_impl = func

         return func

-    def
-        self.
+    def set_memory(self, func: SetMemoryHandler) -> SetMemoryHandler:
+        self._set_memory_impl = func

         return func

-    def
-        self, func:
-    ) ->
-        self.
+    def manage_memory(
+        self, func: ManageMemoryHandler[CtxT]
+    ) -> ManageMemoryHandler[CtxT]:
+        self._policy_executor.manage_memory_impl = func

         return func

-    def
-        self, func:
-    ) ->
-        self.
+    def exit_tool_call_loop(
+        self, func: ExitToolCallLoopHandler[CtxT]
+    ) -> ExitToolCallLoopHandler[CtxT]:
+        self._policy_executor.exit_tool_call_loop_impl = func

         return func

@@ -403,78 +359,78 @@ class LLMAgent(
         cur_cls = type(self)
         base_cls = LLMAgent[Any, Any, Any]

-        if cur_cls.
-            self._prompt_builder.
+        if cur_cls._make_sys_prompt_fn is not base_cls._make_sys_prompt_fn:  # noqa: SLF001
+            self._prompt_builder.make_sys_prompt_impl = self._make_sys_prompt_fn

-        if cur_cls.
-            self._prompt_builder.
+        if cur_cls._make_in_content_fn is not base_cls._make_in_content_fn:  # noqa: SLF001
+            self._prompt_builder.make_in_content_impl = self._make_in_content_fn

-        if cur_cls.
-            self.
+        if cur_cls._set_memory_fn is not base_cls._set_memory_fn:  # noqa: SLF001
+            self._set_memory_impl = self._set_memory_fn

-        if cur_cls.
-            self.
+        if cur_cls._manage_memory_fn is not base_cls._manage_memory_fn:  # noqa: SLF001
+            self._policy_executor.manage_memory_impl = self._manage_memory_fn

         if (
-            cur_cls.
+            cur_cls._exit_tool_call_loop_fn is not base_cls._exit_tool_call_loop_fn  # noqa: SLF001
         ):
-            self.
+            self._policy_executor.exit_tool_call_loop_impl = (
+                self._exit_tool_call_loop_fn
+            )

-        self._parse_output_impl:
+        self._parse_output_impl: (
+            ParseOutputHandler[InT_contra, OutT_co, CtxT] | None
+        ) = None

-    def
-        self,
-
-        *,
-        ctx: RunContextWrapper[CtxT] | None = None,
-    ) -> LLMFormattedSystemArgs:
+    def _make_sys_prompt_fn(
+        self, sys_args: LLMPromptArgs | None, *, ctx: RunContext[CtxT] | None = None
+    ) -> str:
         raise NotImplementedError(
             "LLMAgent._format_sys_args must be overridden by a subclass "
             "if it's intended to be used as the system arguments formatter."
         )

-    def
+    def _make_in_content_fn(
         self,
         *,
-
-
+        in_args: InT_contra | None = None,
+        usr_args: LLMPromptArgs | None = None,
         batch_idx: int = 0,
-        ctx:
-    ) ->
+        ctx: RunContext[CtxT] | None = None,
+    ) -> Content:
         raise NotImplementedError(
             "LLMAgent._format_in_args must be overridden by a subclass"
         )

-    def
+    def _set_memory_fn(
         self,
-
-
-
-
-
-    ) -> LLMAgentState:
+        prev_memory: LLMAgentMemory,
+        in_args: InT_contra | Sequence[InT_contra] | None = None,
+        sys_prompt: LLMPrompt | None = None,
+        ctx: RunContext[Any] | None = None,
+    ) -> LLMAgentMemory:
         raise NotImplementedError(
-            "LLMAgent.
+            "LLMAgent._set_memory must be overridden by a subclass"
         )

-    def
+    def _exit_tool_call_loop_fn(
         self,
-        conversation:
+        conversation: Messages,
         *,
-        ctx:
+        ctx: RunContext[CtxT] | None = None,
         **kwargs: Any,
     ) -> bool:
         raise NotImplementedError(
-            "LLMAgent.
+            "LLMAgent._exit_tool_call_loop must be overridden by a subclass"
         )

-    def
+    def _manage_memory_fn(
         self,
-
+        memory: LLMAgentMemory,
         *,
-        ctx:
+        ctx: RunContext[CtxT] | None = None,
         **kwargs: Any,
     ) -> None:
         raise NotImplementedError(
-            "LLMAgent.
+            "LLMAgent._manage_memory must be overridden by a subclass"
         )
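Per the override checks near the end of the diff, a subclass that defines _make_sys_prompt_fn, _make_in_content_fn, _set_memory_fn, _manage_memory_fn, or _exit_tool_call_loop_fn gets those methods wired into the prompt builder and policy executor automatically; the same hooks are also exposed as decorators (make_sys_prompt, make_in_content, parse_output, set_memory, manage_memory, exit_tool_call_loop). A sketch of the subclass route follows; the hook names and signatures come from the diff above, but the type parameters and method bodies are assumptions.

```python
# Sketch only: the hook names and signatures are taken from the llm_agent.py
# diff above; the generic parameters and the bodies are assumptions.
from typing import Any

from grasp_agents.llm_agent import LLMAgent
from grasp_agents.llm_agent_memory import LLMAgentMemory


class MyAgent(LLMAgent[str, str, Any]):
    def _exit_tool_call_loop_fn(self, conversation, *, ctx=None, **kwargs: Any) -> bool:
        # Stop the tool-call loop once the conversation grows long (assumed policy).
        return len(conversation) > 20

    def _manage_memory_fn(self, memory: LLMAgentMemory, *, ctx=None, **kwargs: Any) -> None:
        # Prune or summarize memory between turns (body left to the subclass).
        ...
```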