grasp_agents 0.2.10__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- grasp_agents/__init__.py +15 -14
- grasp_agents/cloud_llm.py +118 -131
- grasp_agents/comm_processor.py +201 -0
- grasp_agents/generics_utils.py +15 -7
- grasp_agents/llm.py +60 -31
- grasp_agents/llm_agent.py +229 -278
- grasp_agents/llm_agent_memory.py +58 -0
- grasp_agents/llm_policy_executor.py +482 -0
- grasp_agents/memory.py +20 -134
- grasp_agents/message_history.py +140 -0
- grasp_agents/openai/__init__.py +54 -36
- grasp_agents/openai/completion_chunk_converters.py +78 -0
- grasp_agents/openai/completion_converters.py +53 -30
- grasp_agents/openai/content_converters.py +13 -14
- grasp_agents/openai/converters.py +44 -68
- grasp_agents/openai/message_converters.py +58 -72
- grasp_agents/openai/openai_llm.py +101 -42
- grasp_agents/openai/tool_converters.py +24 -19
- grasp_agents/packet.py +24 -0
- grasp_agents/packet_pool.py +91 -0
- grasp_agents/printer.py +29 -15
- grasp_agents/processor.py +194 -0
- grasp_agents/prompt_builder.py +173 -176
- grasp_agents/run_context.py +21 -41
- grasp_agents/typing/completion.py +58 -12
- grasp_agents/typing/completion_chunk.py +173 -0
- grasp_agents/typing/converters.py +8 -12
- grasp_agents/typing/events.py +86 -0
- grasp_agents/typing/io.py +4 -13
- grasp_agents/typing/message.py +12 -50
- grasp_agents/typing/tool.py +52 -26
- grasp_agents/usage_tracker.py +6 -6
- grasp_agents/utils.py +3 -3
- grasp_agents/workflow/looped_workflow.py +132 -0
- grasp_agents/workflow/parallel_processor.py +95 -0
- grasp_agents/workflow/sequential_workflow.py +66 -0
- grasp_agents/workflow/workflow_processor.py +78 -0
- {grasp_agents-0.2.10.dist-info → grasp_agents-0.3.1.dist-info}/METADATA +41 -50
- grasp_agents-0.3.1.dist-info/RECORD +51 -0
- grasp_agents/agent_message.py +0 -27
- grasp_agents/agent_message_pool.py +0 -92
- grasp_agents/base_agent.py +0 -51
- grasp_agents/comm_agent.py +0 -217
- grasp_agents/llm_agent_state.py +0 -79
- grasp_agents/tool_orchestrator.py +0 -203
- grasp_agents/workflow/looped_agent.py +0 -120
- grasp_agents/workflow/sequential_agent.py +0 -63
- grasp_agents/workflow/workflow_agent.py +0 -73
- grasp_agents-0.2.10.dist-info/RECORD +0 -46
- {grasp_agents-0.2.10.dist-info → grasp_agents-0.3.1.dist-info}/WHEEL +0 -0
- {grasp_agents-0.2.10.dist-info → grasp_agents-0.3.1.dist-info}/licenses/LICENSE.md +0 -0
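The renames in this list track a consistent shift from "agent" to "processor" terminology: comm_agent, agent_message_pool, tool_orchestrator, and llm_agent_state are removed, while comm_processor, packet_pool, llm_policy_executor, and llm_agent_memory appear in their place. A minimal, hypothetical import-migration sketch follows; the pairings are inferred from this file list and the llm_agent.py diff below, not taken from an official upgrade guide.

```python
# Hypothetical import migration from 0.2.10 to 0.3.1. The one-to-one pairings
# below are inferred from the file renames above and the llm_agent.py diff that
# follows; treat them as a sketch, not a documented migration path.

# 0.2.10:
# from grasp_agents.comm_agent import CommunicatingAgent
# from grasp_agents.agent_message_pool import AgentMessagePool
# from grasp_agents.tool_orchestrator import ToolOrchestrator
# from grasp_agents.llm_agent_state import LLMAgentState
# from grasp_agents.run_context import RunContextWrapper

# 0.3.1:
from grasp_agents.comm_processor import CommProcessor            # replaces CommunicatingAgent
from grasp_agents.packet_pool import PacketPool                  # replaces AgentMessagePool
from grasp_agents.llm_policy_executor import LLMPolicyExecutor   # replaces ToolOrchestrator
from grasp_agents.llm_agent_memory import LLMAgentMemory         # replaces LLMAgentState
from grasp_agents.run_context import RunContext                  # replaces RunContextWrapper
```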
grasp_agents/llm_agent.py (CHANGED)
@@ -1,66 +1,47 @@
-from collections.abc import Sequence
+from collections.abc import AsyncIterator, Sequence
 from pathlib import Path
-from typing import Any, ClassVar, Generic, Protocol
+from typing import Any, ClassVar, Generic, Protocol
 
 from pydantic import BaseModel
 
-from .
-from .agent_message_pool import AgentMessagePool
-from .comm_agent import CommunicatingAgent
+from .comm_processor import CommProcessor
 from .llm import LLM, LLMSettings
-from .
-
-
-
+from .llm_agent_memory import LLMAgentMemory, SetMemoryHandler
+from .llm_policy_executor import (
+    ExitToolCallLoopHandler,
+    LLMPolicyExecutor,
+    ManageMemoryHandler,
 )
+from .packet_pool import PacketPool
 from .prompt_builder import (
-
-
+    MakeInputContentHandler,
+    MakeSystemPromptHandler,
     PromptBuilder,
 )
-from .run_context import
-
-    InteractionRecord,
-    RunContextWrapper,
-    SystemRunArgs,
-    UserRunArgs,
-)
-from .tool_orchestrator import (
-    ExitToolCallLoopHandler,
-    ManageAgentStateHandler,
-    ToolOrchestrator,
-)
-from .typing.content import ImageData
+from .run_context import CtxT, RunContext
+from .typing.content import Content, ImageData
 from .typing.converters import Converters
-from .typing.
-
-
-    InT,
-    LLMFormattedArgs,
-    LLMFormattedSystemArgs,
-    LLMPrompt,
-    LLMPromptArgs,
-    OutT,
-)
-from .typing.message import Conversation, Message, SystemMessage
+from .typing.events import Event, ProcOutputEvent, SystemMessageEvent, UserMessageEvent
+from .typing.io import InT_contra, LLMPrompt, LLMPromptArgs, OutT_co, ProcName
+from .typing.message import Message, Messages, SystemMessage, UserMessage
 from .typing.tool import BaseTool
 from .utils import get_prompt, validate_obj_from_json_or_py_string
 
 
-class ParseOutputHandler(Protocol[
+class ParseOutputHandler(Protocol[InT_contra, OutT_co, CtxT]):
     def __call__(
         self,
-        conversation:
+        conversation: Messages,
         *,
-        in_args:
+        in_args: InT_contra | None,
         batch_idx: int,
-        ctx:
-    ) ->
+        ctx: RunContext[CtxT] | None,
+    ) -> OutT_co: ...
 
 
 class LLMAgent(
-
-    Generic[
+    CommProcessor[InT_contra, OutT_co, LLMAgentMemory, CtxT],
+    Generic[InT_contra, OutT_co, CtxT],
 ):
     _generic_arg_to_instance_attr_map: ClassVar[dict[int, str]] = {
         0: "_in_type",
@@ -69,63 +50,68 @@ class LLMAgent(
 
     def __init__(
         self,
-
+        name: ProcName,
         *,
         # LLM
         llm: LLM[LLMSettings, Converters],
+        # Tools
+        tools: list[BaseTool[Any, Any, CtxT]] | None = None,
         # Input prompt template (combines user and received arguments)
         in_prompt: LLMPrompt | None = None,
         in_prompt_path: str | Path | None = None,
         # System prompt template
         sys_prompt: LLMPrompt | None = None,
         sys_prompt_path: str | Path | None = None,
-        # System args (static args provided via
+        # System args (static args provided via RunContext)
         sys_args_schema: type[LLMPromptArgs] = LLMPromptArgs,
-        # User args (static args provided via
+        # User args (static args provided via RunContext)
         usr_args_schema: type[LLMPromptArgs] = LLMPromptArgs,
-        #
-
-        max_turns: int = 1000,
+        # Agent loop settings
+        max_turns: int = 100,
         react_mode: bool = False,
-
-
+        final_answer_as_tool_call: bool = False,
+        # Agent memory management
+        reset_memory_on_run: bool = False,
         # Multi-agent routing
-
-
+        packet_pool: PacketPool[CtxT] | None = None,
+        recipients: list[ProcName] | None = None,
     ) -> None:
-        super().__init__(
-
-
+        super().__init__(name=name, packet_pool=packet_pool, recipients=recipients)
+
+        # Agent memory
 
-
-        self.
-        self.
-        self._set_agent_state_impl: SetAgentState | None = None
+        self._memory: LLMAgentMemory = LLMAgentMemory()
+        self._reset_memory_on_run = reset_memory_on_run
+        self._set_memory_impl: SetMemoryHandler | None = None
 
-        #
+        # LLM policy executor
 
         self._using_default_llm_response_format: bool = False
         if llm.response_format is None and tools is None:
             llm.response_format = self.out_type
             self._using_default_llm_response_format = True
 
-        self.
-
+        self._policy_executor: LLMPolicyExecutor[OutT_co, CtxT] = LLMPolicyExecutor[
+            self.out_type, CtxT
+        ](
+            agent_name=self.name,
             llm=llm,
             tools=tools,
             max_turns=max_turns,
             react_mode=react_mode,
+            final_answer_as_tool_call=final_answer_as_tool_call,
         )
 
         # Prompt builder
+
         sys_prompt = get_prompt(prompt_text=sys_prompt, prompt_path=sys_prompt_path)
         in_prompt = get_prompt(prompt_text=in_prompt, prompt_path=in_prompt_path)
-        self._prompt_builder: PromptBuilder[
+        self._prompt_builder: PromptBuilder[InT_contra, CtxT] = PromptBuilder[
            self.in_type, CtxT
         ](
-
-
-
+            agent_name=self._name,
+            sys_prompt_template=sys_prompt,
+            in_prompt_template=in_prompt,
             sys_args_schema=sys_args_schema,
             usr_args_schema=usr_args_schema,
         )
@@ -136,53 +122,50 @@ class LLMAgent(
 
     @property
     def llm(self) -> LLM[LLMSettings, Converters]:
-        return self.
+        return self._policy_executor.llm
 
     @property
     def tools(self) -> dict[str, BaseTool[BaseModel, Any, CtxT]]:
-        return self.
+        return self._policy_executor.tools
 
     @property
     def max_turns(self) -> int:
-        return self.
+        return self._policy_executor.max_turns
 
     @property
-    def sys_args_schema(self) -> type[LLMPromptArgs]:
+    def sys_args_schema(self) -> type[LLMPromptArgs] | None:
         return self._prompt_builder.sys_args_schema
 
     @property
-    def usr_args_schema(self) -> type[LLMPromptArgs]:
+    def usr_args_schema(self) -> type[LLMPromptArgs] | None:
         return self._prompt_builder.usr_args_schema
 
     @property
     def sys_prompt(self) -> LLMPrompt | None:
-        return self._prompt_builder.
+        return self._prompt_builder.sys_prompt_template
 
     @property
     def in_prompt(self) -> LLMPrompt | None:
-        return self._prompt_builder.
+        return self._prompt_builder.in_prompt_template
 
     def _parse_output(
         self,
-        conversation:
+        conversation: Messages,
         *,
-        in_args:
+        in_args: InT_contra | None = None,
         batch_idx: int = 0,
-        ctx:
-    ) ->
+        ctx: RunContext[CtxT] | None = None,
+    ) -> OutT_co:
         if self._parse_output_impl:
             if self._using_default_llm_response_format:
                 # When using custom output parsing, the required LLM response format
                 # can differ from the final agent output type ->
                 # set it back to None unless it was specified explicitly at init.
-                self.
-                self._using_default_llm_response_format = False
+                self._policy_executor.llm.response_format = None
+                # self._using_default_llm_response_format = False
 
             return self._parse_output_impl(
-                conversation=conversation,
-                in_args=in_args,
-                batch_idx=batch_idx,
-                ctx=ctx,
+                conversation=conversation, in_args=in_args, batch_idx=batch_idx, ctx=ctx
             )
 
         return validate_obj_from_json_or_py_string(
@@ -191,214 +174,182 @@ class LLMAgent(
             from_substring=True,
         )
 
-
-    def _validate_run_inputs(
-        chat_inputs: LLMPrompt | Sequence[str | ImageData] | None = None,
-        in_args: InT | Sequence[InT] | None = None,
-        in_message: AgentMessage[InT, AgentState] | None = None,
-        entry_point: bool = False,
-    ) -> None:
-        multiple_inputs_err_message = (
-            "Only one of chat_inputs, in_args, or in_message must be provided."
-        )
-        if chat_inputs is not None and in_args is not None:
-            raise ValueError(multiple_inputs_err_message)
-        if chat_inputs is not None and in_message is not None:
-            raise ValueError(multiple_inputs_err_message)
-        if in_args is not None and in_message is not None:
-            raise ValueError(multiple_inputs_err_message)
-
-        if entry_point and in_message is not None:
-            raise ValueError(
-                "Entry point agent cannot receive messages from other agents."
-            )
-
-    @final
-    async def run(
+    def _memorize_inputs(
         self,
         chat_inputs: LLMPrompt | Sequence[str | ImageData] | None = None,
         *,
-
-
-
-        ctx: RunContextWrapper[CtxT] | None = None,
-        forbid_state_change: bool = False,
-        **gen_kwargs: Any,  # noqa: ARG002
-    ) -> AgentMessage[OutT, LLMAgentState]:
+        in_args: Sequence[InT_contra] | None = None,
+        ctx: RunContext[CtxT] | None = None,
+    ) -> tuple[SystemMessage | None, Sequence[UserMessage], LLMAgentMemory]:
         # Get run arguments
-        sys_args:
-        usr_args:
+        sys_args: LLMPromptArgs | None = None
+        usr_args: LLMPromptArgs | None = None
         if ctx is not None:
-            run_args = ctx.run_args.get(self.
+            run_args = ctx.run_args.get(self.name)
             if run_args is not None:
                 sys_args = run_args.sys
                 usr_args = run_args.usr
 
-        self._validate_run_inputs(
-            chat_inputs=chat_inputs,
-            in_args=in_args,
-            in_message=in_message,
-            entry_point=entry_point,
-        )
-
         # 1. Make system prompt (can be None)
+
         formatted_sys_prompt = self._prompt_builder.make_sys_prompt(
             sys_args=sys_args, ctx=ctx
         )
 
-        # 2. Set agent
-
-        cur_state = self.state.model_copy(deep=True)
-        in_state = in_message.sender_state if in_message else None
-        prev_mh_len = len(cur_state.message_history)
-
-        state = LLMAgentState.from_cur_and_in_states(
-            cur_state=cur_state,
-            in_state=in_state,
-            sys_prompt=formatted_sys_prompt,
-            strategy=self.set_state_strategy,
-            set_agent_state_impl=self._set_agent_state_impl,
-            ctx=ctx,
-        )
+        # 2. Set agent memory
 
-        self.
+        _memory = self.memory.model_copy(deep=True)
+        prev_message_hist_length = len(_memory.message_history)
+        if self._reset_memory_on_run or _memory.is_empty:
+            _memory.reset(formatted_sys_prompt)
+        elif self._set_memory_impl:
+            _memory = self._set_memory_impl(
+                prev_memory=_memory,
+                in_args=in_args,
+                sys_prompt=formatted_sys_prompt,
+                ctx=ctx,
+            )
 
-        # 3. Make and add user messages
-        _in_args_batch: Sequence[InT] | None = None
-        if in_message is not None:
-            _in_args_batch = in_message.payloads
-        elif in_args is not None:
-            _in_args_batch = in_args if isinstance(in_args, Sequence) else [in_args]  # type: ignore[assignment]
+        # 3. Make and add user messages
 
         user_message_batch = self._prompt_builder.make_user_messages(
-            chat_inputs=chat_inputs,
-            usr_args=usr_args,
-            in_args_batch=_in_args_batch,
-            entry_point=entry_point,
-            ctx=ctx,
+            chat_inputs=chat_inputs, in_args_batch=in_args, usr_args=usr_args, ctx=ctx
         )
         if user_message_batch:
-
-            self._print_msgs(user_message_batch, ctx=ctx)
-
-        if not self.tools:
-            # 4. Generate messages without tools
-            await self._tool_orchestrator.generate_once(state=state, ctx=ctx)
-        else:
-            # 4. Run tool call loop (new messages are added to the message
-            # history inside the loop)
-            await self._tool_orchestrator.run_loop(state=state, ctx=ctx)
-
-        # 5. Parse outputs
-        batch_size = state.message_history.batch_size
-        in_args_batch = in_message.payloads if in_message else batch_size * [None]
-        val_output_batch = [
-            self._out_type_adapter.validate_python(
-                self._parse_output(conversation=conv, in_args=in_args, ctx=ctx)
-            )
-            for conv, in_args in zip(
-                state.message_history.batched_conversations,
-                in_args_batch,
-                strict=False,
-            )
-        ]
+            _memory.update(message_batch=user_message_batch)
 
-        #
+        # 4. Extract system message if it was added
 
-
+        system_message: SystemMessage | None = None
+        if (
+            len(_memory.message_history) == 1
+            and prev_message_hist_length == 0
+            and isinstance(_memory.message_history[0][0], SystemMessage)
+        ):
+            system_message = _memory.message_history[0][0]
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        return system_message, user_message_batch, _memory
+
+    def _extract_outputs(
+        self,
+        memory: LLMAgentMemory,
+        in_args: Sequence[InT_contra] | None = None,
+        ctx: RunContext[CtxT] | None = None,
+    ) -> Sequence[OutT_co]:
+        outputs: list[OutT_co] = []
+        for i, _conv in enumerate(memory.message_history.conversations):
+            if in_args is not None:
+                _in_args_single = in_args[min(i, len(in_args) - 1)]
+            else:
+                _in_args_single = None
+
+            outputs.append(
+                self._parse_output(
+                    conversation=_conv, in_args=_in_args_single, batch_idx=i, ctx=ctx
                 )
             )
 
-
-
-
-
-
+        return outputs
+
+    async def _process(
+        self,
+        chat_inputs: LLMPrompt | Sequence[str | ImageData] | None = None,
+        *,
+        in_args: Sequence[InT_contra] | None = None,
+        forgetful: bool = False,
+        ctx: RunContext[CtxT] | None = None,
+    ) -> Sequence[OutT_co]:
+        system_message, user_message_batch, memory = self._memorize_inputs(
+            chat_inputs=chat_inputs, in_args=in_args, ctx=ctx
         )
 
-        if not
-            self.
+        if system_message is not None:
+            self._print_messages([system_message], ctx=ctx)
+        if user_message_batch:
+            self._print_messages(user_message_batch, ctx=ctx)
 
-
+        await self._policy_executor.execute(memory, ctx=ctx)
 
-
-
-        messages: Sequence[Message],
-        ctx: RunContextWrapper[CtxT] | None = None,
-    ) -> None:
-        if ctx:
-            ctx.printer.print_llm_messages(messages, agent_id=self.agent_id)
+        if not forgetful:
+            self._memory = memory
 
-
+        return self._extract_outputs(memory=memory, in_args=in_args, ctx=ctx)
+
+    async def _process_stream(
         self,
-
-
-
+        chat_inputs: LLMPrompt | Sequence[str | ImageData] | None = None,
+        *,
+        in_args: Sequence[InT_contra] | None = None,
+        forgetful: bool = False,
+        ctx: RunContext[CtxT] | None = None,
+    ) -> AsyncIterator[Event[Any]]:
+        system_message, user_message_batch, memory = self._memorize_inputs(
+            chat_inputs=chat_inputs, in_args=in_args, ctx=ctx
+        )
+
+        if system_message is not None:
+            yield SystemMessageEvent(data=system_message)
+        if user_message_batch:
+            for user_message in user_message_batch:
+                yield UserMessageEvent(data=user_message)
+
+        # 4. Run tool call loop (new messages are added to the message
+        # history inside the loop)
+        async for event in self._policy_executor.execute_stream(memory, ctx=ctx):
+            yield event
+
+        if not forgetful:
+            self._memory = memory
+
+        outputs = self._extract_outputs(memory=memory, in_args=in_args, ctx=ctx)
+        for output in outputs:
+            yield ProcOutputEvent(data=output, name=self.name)
+
+    def _print_messages(
+        self, messages: Sequence[Message], ctx: RunContext[CtxT] | None = None
     ) -> None:
-        if
-
-            and prev_mh_len == 0
-            and isinstance(state.message_history[0][0], SystemMessage)
-        ):
-            self._print_msgs([state.message_history[0][0]], ctx=ctx)
+        if ctx:
+            ctx.printer.print_llm_messages(messages, agent_name=self.name)
 
-    # --
+    # -- Decorators for custom implementations --
 
-    def
-        self, func:
-    ) ->
-        self._prompt_builder.
+    def make_sys_prompt(
+        self, func: MakeSystemPromptHandler[CtxT]
+    ) -> MakeSystemPromptHandler[CtxT]:
+        self._prompt_builder.make_sys_prompt_impl = func
 
         return func
 
-    def
-        self, func:
-    ) ->
-        self._prompt_builder.
+    def make_in_content(
+        self, func: MakeInputContentHandler[InT_contra, CtxT]
+    ) -> MakeInputContentHandler[InT_contra, CtxT]:
+        self._prompt_builder.make_in_content_impl = func
 
         return func
 
-    def
-        self, func: ParseOutputHandler[
-    ) -> ParseOutputHandler[
+    def parse_output(
+        self, func: ParseOutputHandler[InT_contra, OutT_co, CtxT]
+    ) -> ParseOutputHandler[InT_contra, OutT_co, CtxT]:
         self._parse_output_impl = func
 
         return func
 
-    def
-        self.
+    def set_memory(self, func: SetMemoryHandler) -> SetMemoryHandler:
+        self._set_memory_impl = func
 
         return func
 
-    def
-        self, func:
-    ) ->
-        self.
+    def manage_memory(
+        self, func: ManageMemoryHandler[CtxT]
+    ) -> ManageMemoryHandler[CtxT]:
+        self._policy_executor.manage_memory_impl = func
 
         return func
 
-    def
-        self, func:
-    ) ->
-        self.
+    def exit_tool_call_loop(
+        self, func: ExitToolCallLoopHandler[CtxT]
+    ) -> ExitToolCallLoopHandler[CtxT]:
+        self._policy_executor.exit_tool_call_loop_impl = func
 
         return func
 
@@ -408,78 +359,78 @@ class LLMAgent(
         cur_cls = type(self)
         base_cls = LLMAgent[Any, Any, Any]
 
-        if cur_cls.
-            self._prompt_builder.
+        if cur_cls._make_sys_prompt_fn is not base_cls._make_sys_prompt_fn:  # noqa: SLF001
+            self._prompt_builder.make_sys_prompt_impl = self._make_sys_prompt_fn
 
-        if cur_cls.
-            self._prompt_builder.
+        if cur_cls._make_in_content_fn is not base_cls._make_in_content_fn:  # noqa: SLF001
+            self._prompt_builder.make_in_content_impl = self._make_in_content_fn
 
-        if cur_cls.
-            self.
+        if cur_cls._set_memory_fn is not base_cls._set_memory_fn:  # noqa: SLF001
+            self._set_memory_impl = self._set_memory_fn
 
-        if cur_cls.
-            self.
+        if cur_cls._manage_memory_fn is not base_cls._manage_memory_fn:  # noqa: SLF001
+            self._policy_executor.manage_memory_impl = self._manage_memory_fn
 
         if (
-            cur_cls.
+            cur_cls._exit_tool_call_loop_fn is not base_cls._exit_tool_call_loop_fn  # noqa: SLF001
         ):
-            self.
+            self._policy_executor.exit_tool_call_loop_impl = (
+                self._exit_tool_call_loop_fn
+            )
 
-        self._parse_output_impl:
+        self._parse_output_impl: (
+            ParseOutputHandler[InT_contra, OutT_co, CtxT] | None
+        ) = None
 
-    def
-        self,
-
-        *,
-        ctx: RunContextWrapper[CtxT] | None = None,
-    ) -> LLMFormattedSystemArgs:
+    def _make_sys_prompt_fn(
+        self, sys_args: LLMPromptArgs | None, *, ctx: RunContext[CtxT] | None = None
+    ) -> str:
         raise NotImplementedError(
             "LLMAgent._format_sys_args must be overridden by a subclass "
             "if it's intended to be used as the system arguments formatter."
         )
 
-    def
+    def _make_in_content_fn(
         self,
         *,
-
-
+        in_args: InT_contra | None = None,
+        usr_args: LLMPromptArgs | None = None,
         batch_idx: int = 0,
-        ctx:
-    ) ->
+        ctx: RunContext[CtxT] | None = None,
+    ) -> Content:
         raise NotImplementedError(
             "LLMAgent._format_in_args must be overridden by a subclass"
         )
 
-    def
+    def _set_memory_fn(
         self,
-
-
-
-
-
-    ) -> LLMAgentState:
+        prev_memory: LLMAgentMemory,
+        in_args: InT_contra | Sequence[InT_contra] | None = None,
+        sys_prompt: LLMPrompt | None = None,
+        ctx: RunContext[Any] | None = None,
+    ) -> LLMAgentMemory:
         raise NotImplementedError(
-            "LLMAgent.
+            "LLMAgent._set_memory must be overridden by a subclass"
         )
 
-    def
+    def _exit_tool_call_loop_fn(
        self,
-        conversation:
+        conversation: Messages,
         *,
-        ctx:
+        ctx: RunContext[CtxT] | None = None,
         **kwargs: Any,
     ) -> bool:
         raise NotImplementedError(
-            "LLMAgent.
+            "LLMAgent._exit_tool_call_loop must be overridden by a subclass"
         )
 
-    def
+    def _manage_memory_fn(
         self,
-
+        memory: LLMAgentMemory,
         *,
-        ctx:
+        ctx: RunContext[CtxT] | None = None,
         **kwargs: Any,
     ) -> None:
         raise NotImplementedError(
-            "LLMAgent.
+            "LLMAgent._manage_memory must be overridden by a subclass"
         )
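The hunks above show that per-run behaviour is now customized by registering handlers on an LLMAgent instance (parse_output, set_memory, manage_memory, exit_tool_call_loop) or by overriding the corresponding _*_fn methods in a subclass. Below is a minimal usage sketch under stated assumptions: the LLM construction, the generic subscription values, and the Answer model are illustrative placeholders; only the constructor keywords and decorator names come from the diff itself.

```python
# Sketch of the 0.3.1 LLMAgent surface visible in the hunks above. The llm
# placeholder, the [InT, OutT, CtxT] subscription values, and Answer are
# assumptions for illustration, not documented API.
from pydantic import BaseModel

from grasp_agents.llm_agent import LLMAgent


class Answer(BaseModel):
    text: str


my_llm = ...  # placeholder: construct an LLM[LLMSettings, Converters] here, e.g. an OpenAILLM

agent = LLMAgent[str, Answer, None](        # [InT, OutT, CtxT]
    name="assistant",                        # ProcName; 0.2.x code referenced agent_id instead
    llm=my_llm,
    sys_prompt="You are a helpful assistant.",
    max_turns=100,                           # new default per the diff (0.2.10 defaulted to 1000)
    reset_memory_on_run=True,                # new memory-management flag
)


@agent.parse_output
def parse(conversation, *, in_args, batch_idx, ctx) -> Answer:
    # conversation is the Messages sequence for one batch item (see ParseOutputHandler)
    return Answer(text=str(conversation[-1].content))
```

Per the _parse_output hunk, registering a parse_output handler also clears an auto-assigned llm.response_format, so the handler, rather than the structured-output format, then defines the agent's final output type.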