grasp_agents 0.3.10__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between these versions as they appear in their public registries.
- grasp_agents/cloud_llm.py +70 -77
- grasp_agents/comm_processor.py +21 -11
- grasp_agents/errors.py +34 -0
- grasp_agents/http_client.py +7 -5
- grasp_agents/llm.py +3 -9
- grasp_agents/llm_agent.py +92 -103
- grasp_agents/llm_agent_memory.py +36 -27
- grasp_agents/llm_policy_executor.py +66 -63
- grasp_agents/memory.py +3 -1
- grasp_agents/openai/completion_chunk_converters.py +4 -3
- grasp_agents/openai/openai_llm.py +14 -20
- grasp_agents/openai/tool_converters.py +0 -1
- grasp_agents/packet_pool.py +1 -1
- grasp_agents/printer.py +6 -6
- grasp_agents/processor.py +182 -48
- grasp_agents/prompt_builder.py +41 -55
- grasp_agents/run_context.py +1 -5
- grasp_agents/typing/completion_chunk.py +10 -5
- grasp_agents/typing/content.py +2 -2
- grasp_agents/typing/io.py +4 -4
- grasp_agents/typing/message.py +3 -6
- grasp_agents/typing/tool.py +5 -23
- grasp_agents/usage_tracker.py +2 -4
- grasp_agents/utils.py +37 -15
- grasp_agents/workflow/looped_workflow.py +14 -9
- grasp_agents/workflow/sequential_workflow.py +11 -6
- grasp_agents/workflow/workflow_processor.py +30 -13
- {grasp_agents-0.3.10.dist-info → grasp_agents-0.4.0.dist-info}/METADATA +2 -1
- grasp_agents-0.4.0.dist-info/RECORD +50 -0
- grasp_agents/message_history.py +0 -140
- grasp_agents/workflow/parallel_processor.py +0 -95
- grasp_agents-0.3.10.dist-info/RECORD +0 -51
- {grasp_agents-0.3.10.dist-info → grasp_agents-0.4.0.dist-info}/WHEEL +0 -0
- {grasp_agents-0.3.10.dist-info → grasp_agents-0.4.0.dist-info}/licenses/LICENSE.md +0 -0
grasp_agents/llm_agent.py
CHANGED
@@ -1,12 +1,12 @@
 from collections.abc import AsyncIterator, Sequence
 from pathlib import Path
-from typing import Any, ClassVar, Generic, Protocol
+from typing import Any, ClassVar, Generic, Protocol, TypeVar

 from pydantic import BaseModel

 from .comm_processor import CommProcessor
 from .llm import LLM, LLMSettings
-from .llm_agent_memory import LLMAgentMemory,
+from .llm_agent_memory import LLMAgentMemory, MakeMemoryHandler
 from .llm_policy_executor import (
     ExitToolCallLoopHandler,
     LLMPolicyExecutor,
@@ -22,26 +22,28 @@ from .run_context import CtxT, RunContext
 from .typing.content import Content, ImageData
 from .typing.converters import Converters
 from .typing.events import Event, ProcOutputEvent, SystemMessageEvent, UserMessageEvent
-from .typing.io import
+from .typing.io import InT, LLMPrompt, LLMPromptArgs, OutT_co, ProcName
 from .typing.message import Message, Messages, SystemMessage, UserMessage
 from .typing.tool import BaseTool
 from .utils import get_prompt, validate_obj_from_json_or_py_string

+_InT_contra = TypeVar("_InT_contra", contravariant=True)
+_OutT_co = TypeVar("_OutT_co", covariant=True)

-
+
+class ParseOutputHandler(Protocol[_InT_contra, _OutT_co, CtxT]):
     def __call__(
         self,
         conversation: Messages,
         *,
-        in_args:
-        batch_idx: int,
+        in_args: _InT_contra | None,
         ctx: RunContext[CtxT] | None,
-    ) ->
+    ) -> _OutT_co: ...


 class LLMAgent(
-    CommProcessor[
-    Generic[
+    CommProcessor[InT, OutT_co, LLMAgentMemory, CtxT],
+    Generic[InT, OutT_co, CtxT],
 ):
     _generic_arg_to_instance_attr_map: ClassVar[dict[int, str]] = {
         0: "_in_type",
@@ -72,11 +74,18 @@ class LLMAgent(
         final_answer_as_tool_call: bool = False,
         # Agent memory management
         reset_memory_on_run: bool = False,
+        # Retries
+        num_par_run_retries: int = 0,
         # Multi-agent routing
         packet_pool: PacketPool[CtxT] | None = None,
         recipients: list[ProcName] | None = None,
     ) -> None:
-        super().__init__(
+        super().__init__(
+            name=name,
+            packet_pool=packet_pool,
+            recipients=recipients,
+            num_par_run_retries=num_par_run_retries,
+        )

         # Agent memory

@@ -105,7 +114,7 @@ class LLMAgent(

         sys_prompt = get_prompt(prompt_text=sys_prompt, prompt_path=sys_prompt_path)
         in_prompt = get_prompt(prompt_text=in_prompt, prompt_path=in_prompt_path)
-        self._prompt_builder: PromptBuilder[
+        self._prompt_builder: PromptBuilder[InT, CtxT] = PromptBuilder[
             self.in_type, CtxT
         ](
             agent_name=self._name,
@@ -115,12 +124,10 @@ class LLMAgent(
             usr_args_schema=usr_args_schema,
         )

-        self.no_tqdm = getattr(llm, "no_tqdm", False)
+        # self.no_tqdm = getattr(llm, "no_tqdm", False)

-        self.
-        self._parse_output_impl:
-            ParseOutputHandler[InT_contra, OutT_co, CtxT] | None
-        ) = None
+        self._make_memory_impl: MakeMemoryHandler | None = None
+        self._parse_output_impl: ParseOutputHandler[InT, OutT_co, CtxT] | None = None
         self._register_overridden_handlers()

     @property
@@ -155,10 +162,11 @@ class LLMAgent(
         self,
         chat_inputs: LLMPrompt | Sequence[str | ImageData] | None = None,
         *,
-        in_args:
+        in_args: InT | None = None,
+        memory: LLMAgentMemory,
         ctx: RunContext[CtxT] | None = None,
-    ) -> tuple[SystemMessage | None,
-        # Get run arguments
+    ) -> tuple[SystemMessage | None, UserMessage | None, LLMAgentMemory]:
+        # 1. Get run arguments
         sys_args: LLMPromptArgs | None = None
         usr_args: LLMPromptArgs | None = None
         if ctx is not None:
@@ -167,23 +175,22 @@ class LLMAgent(
             sys_args = run_args.sys
             usr_args = run_args.usr

-        #
+        # 2. Make system prompt (can be None)

         formatted_sys_prompt = self._prompt_builder.make_system_prompt(
             sys_args=sys_args, ctx=ctx
         )

-        #
+        # 3. Set agent memory

         system_message: SystemMessage | None = None
-
-
-            _memory.reset(formatted_sys_prompt)
+        if self._reset_memory_on_run or memory.is_empty:
+            memory.reset(formatted_sys_prompt)
             if formatted_sys_prompt is not None:
-                system_message =
-        elif self.
-
-                prev_memory=
+                system_message = memory.message_history[0]  # type: ignore[union-attr]
+        elif self._make_memory_impl:
+            memory = self._make_memory_impl(
+                prev_memory=memory,
                 in_args=in_args,
                 sys_prompt=formatted_sys_prompt,
                 ctx=ctx,
@@ -191,113 +198,96 @@ class LLMAgent(

         # 3. Make and add user messages

-
-            chat_inputs=chat_inputs,
+        user_message = self._prompt_builder.make_user_message(
+            chat_inputs=chat_inputs, in_args=in_args, usr_args=usr_args, ctx=ctx
         )
-        if
-
+        if user_message:
+            memory.update([user_message])

-        return system_message,
-
-    def _extract_outputs(
-        self,
-        memory: LLMAgentMemory,
-        in_args: Sequence[InT_contra] | None = None,
-        ctx: RunContext[CtxT] | None = None,
-    ) -> Sequence[OutT_co]:
-        outputs: list[OutT_co] = []
-        for i, _conv in enumerate(memory.message_history.conversations):
-            if in_args is not None:
-                _in_args_single = in_args[min(i, len(in_args) - 1)]
-            else:
-                _in_args_single = None
-
-            outputs.append(
-                self._parse_output(
-                    conversation=_conv, in_args=_in_args_single, batch_idx=i, ctx=ctx
-                )
-            )
-
-        return outputs
+        return system_message, user_message, memory

     def _parse_output(
         self,
         conversation: Messages,
         *,
-        in_args:
-        batch_idx: int = 0,
+        in_args: InT | None = None,
         ctx: RunContext[CtxT] | None = None,
     ) -> OutT_co:
         if self._parse_output_impl:
             return self._parse_output_impl(
-                conversation=conversation, in_args=in_args,
+                conversation=conversation, in_args=in_args, ctx=ctx
             )

         return validate_obj_from_json_or_py_string(
             str(conversation[-1].content or ""),
             adapter=self._out_type_adapter,
-            from_substring=
+            from_substring=False,
         )

     async def _process(
         self,
         chat_inputs: LLMPrompt | Sequence[str | ImageData] | None = None,
         *,
-        in_args:
-
+        in_args: InT | None = None,
+        memory: LLMAgentMemory,
+        run_id: str,
         ctx: RunContext[CtxT] | None = None,
     ) -> Sequence[OutT_co]:
-        system_message,
-            chat_inputs=chat_inputs, in_args=in_args, ctx=ctx
+        system_message, user_message, memory = self._memorize_inputs(
+            chat_inputs=chat_inputs, in_args=in_args, memory=memory, ctx=ctx
         )
+        if system_message:
+            self._print_messages([system_message], run_id=run_id, ctx=ctx)
+        if user_message:
+            self._print_messages([user_message], run_id=run_id, ctx=ctx)

-
-            self._print_messages([system_message], ctx=ctx)
-        if user_message_batch:
-            self._print_messages(user_message_batch, ctx=ctx)
+        await self._policy_executor.execute(memory, run_id=run_id, ctx=ctx)

-
-
-
-
-
-        return self._extract_outputs(memory=memory, in_args=in_args, ctx=ctx)
+        return [
+            self._parse_output(
+                conversation=memory.message_history, in_args=in_args, ctx=ctx
+            )
+        ]

     async def _process_stream(
         self,
         chat_inputs: LLMPrompt | Sequence[str | ImageData] | None = None,
         *,
-        in_args:
-
+        in_args: InT | None = None,
+        memory: LLMAgentMemory,
+        run_id: str,
         ctx: RunContext[CtxT] | None = None,
     ) -> AsyncIterator[Event[Any]]:
-        system_message,
-            chat_inputs=chat_inputs, in_args=in_args, ctx=ctx
+        system_message, user_message, memory = self._memorize_inputs(
+            chat_inputs=chat_inputs, in_args=in_args, memory=memory, ctx=ctx
         )
-
-        if system_message is not None:
+        if system_message:
             yield SystemMessageEvent(data=system_message)
-        if
-
-            yield UserMessageEvent(data=user_message)
+        if user_message:
+            yield UserMessageEvent(data=user_message)

         # 4. Run tool call loop (new messages are added to the message
         # history inside the loop)
-        async for event in self._policy_executor.execute_stream(
+        async for event in self._policy_executor.execute_stream(
+            memory, run_id=run_id, ctx=ctx
+        ):
             yield event

-
-
-
-
-        for output in outputs:
-            yield ProcOutputEvent(data=output, name=self.name)
+        output = self._parse_output(
+            conversation=memory.message_history, in_args=in_args, ctx=ctx
+        )
+        yield ProcOutputEvent(data=output, name=self.name)

     def _print_messages(
-        self,
+        self,
+        messages: Sequence[Message],
+        run_id: str,
+        ctx: RunContext[CtxT] | None = None,
     ) -> None:
         if ctx:
-            ctx.printer.print_llm_messages(
+            ctx.printer.print_llm_messages(
+                messages, agent_name=self.name, run_id=run_id
+            )

     # -- Decorators for custom implementations --

@@ -309,23 +299,23 @@ class LLMAgent(
         return func

     def make_input_content(
-        self, func: MakeInputContentHandler[
-    ) -> MakeInputContentHandler[
+        self, func: MakeInputContentHandler[InT, CtxT]
+    ) -> MakeInputContentHandler[InT, CtxT]:
         self._prompt_builder.make_input_content_impl = func

         return func

     def parse_output(
-        self, func: ParseOutputHandler[
-    ) -> ParseOutputHandler[
+        self, func: ParseOutputHandler[InT, OutT_co, CtxT]
+    ) -> ParseOutputHandler[InT, OutT_co, CtxT]:
         if self._used_default_llm_response_format:
             self._policy_executor.llm.response_format = None
         self._parse_output_impl = func

         return func

-    def
-        self.
+    def make_memory(self, func: MakeMemoryHandler) -> MakeMemoryHandler:
+        self._make_memory_impl = func

         return func

@@ -355,8 +345,8 @@ class LLMAgent(
         if cur_cls._make_input_content is not base_cls._make_input_content: # noqa: SLF001
             self._prompt_builder.make_input_content_impl = self._make_input_content

-        if cur_cls.
-            self.
+        if cur_cls._make_memory is not base_cls._make_memory: # noqa: SLF001
+            self._make_memory_impl = self._make_memory

         if cur_cls._manage_memory is not base_cls._manage_memory: # noqa: SLF001
             self._policy_executor.manage_memory_impl = self._manage_memory
@@ -383,24 +373,23 @@ class LLMAgent(
     def _make_input_content(
         self,
         *,
-        in_args:
+        in_args: InT | None = None,
         usr_args: LLMPromptArgs | None = None,
-        batch_idx: int = 0,
         ctx: RunContext[CtxT] | None = None,
     ) -> Content:
         raise NotImplementedError(
             "LLMAgent._format_in_args must be overridden by a subclass"
         )

-    def
+    def _make_memory(
         self,
         prev_memory: LLMAgentMemory,
-        in_args: Sequence[
+        in_args: Sequence[InT] | None = None,
         sys_prompt: LLMPrompt | None = None,
         ctx: RunContext[Any] | None = None,
     ) -> LLMAgentMemory:
         raise NotImplementedError(
-            "LLMAgent.
+            "LLMAgent._make_memory must be overridden by a subclass"
         )

     def _exit_tool_call_loop(
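For orientation, the sketch below shows handlers matching the reworked hooks visible in this diff: `ParseOutputHandler` now takes a single conversation with keyword-only `in_args`/`ctx` (no `batch_idx`) and returns one output, and the renamed make-memory hook returns the memory a run should start from instead of mutating it in place. This is a minimal illustration, not code from the package: the function names, the output type (`str`), and their registration on a hypothetical agent are assumptions.

```python
from typing import Any

from grasp_agents.llm_agent_memory import LLMAgentMemory
from grasp_agents.run_context import RunContext
from grasp_agents.typing.io import LLMPrompt
from grasp_agents.typing.message import Messages


# Shape of the new ParseOutputHandler: one conversation, keyword-only
# in_args/ctx, no batch_idx; the str output type is an assumption.
def parse_last_message(
    conversation: Messages,
    *,
    in_args: Any | None,
    ctx: RunContext[Any] | None,
) -> str:
    return str(conversation[-1].content or "")


# Shape of the make-memory hook (mirrors the default _make_memory signature):
# return the memory the run should start from rather than editing prev_memory.
def fresh_memory(
    prev_memory: LLMAgentMemory,
    in_args: Any | None = None,
    sys_prompt: LLMPrompt | None = None,
    ctx: RunContext[Any] | None = None,
) -> LLMAgentMemory:
    return LLMAgentMemory(sys_prompt)


# These would typically be attached to a constructed LLMAgent (hypothetical
# `agent`) via agent.parse_output(parse_last_message) and
# agent.make_memory(fresh_memory), the decorators shown in the diff above.
```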
grasp_agents/llm_agent_memory.py
CHANGED
@@ -1,16 +1,15 @@
 from collections.abc import Sequence
 from typing import Any, Protocol

-from pydantic import
+from pydantic import PrivateAttr

 from .memory import Memory
-from .message_history import MessageHistory
 from .run_context import RunContext
 from .typing.io import LLMPrompt
-from .typing.message import Message
+from .typing.message import Message, Messages, SystemMessage


-class
+class MakeMemoryHandler(Protocol):
     def __call__(
         self,
         prev_memory: "LLMAgentMemory",
@@ -21,38 +20,48 @@ class SetMemoryHandler(Protocol):


 class LLMAgentMemory(Memory):
-
+    _message_history: Messages = PrivateAttr(default_factory=list)  # type: ignore
+    _sys_prompt: LLMPrompt | None = PrivateAttr(default=None)
+
+    def __init__(self, sys_prompt: LLMPrompt | None = None) -> None:
+        super().__init__()
+        self._sys_prompt = sys_prompt
+        self.reset(sys_prompt)
+
+    @property
+    def sys_prompt(self) -> LLMPrompt | None:
+        return self._sys_prompt
+
+    @property
+    def message_history(self) -> Messages:
+        return self._message_history

     def reset(
         self, sys_prompt: LLMPrompt | None = None, ctx: RunContext[Any] | None = None
     ):
-
+        if sys_prompt is not None:
+            self._sys_prompt = sys_prompt
+
+        self._message_history = (
+            [SystemMessage(content=self._sys_prompt)]
+            if self._sys_prompt is not None
+            else []
+        )
+
+    def erase(self) -> None:
+        self._message_history = []

     def update(
-        self,
-        message_list: Sequence[Message] | None = None,
-        *,
-        message_batch: Sequence[Message] | None = None,
-        ctx: RunContext[Any] | None = None,
+        self, messages: Sequence[Message], *, ctx: RunContext[Any] | None = None
     ):
-
-            raise ValueError(
-                "Only one of message_batch or messages should be provided."
-            )
-        if message_batch is not None:
-            self.message_history.add_message_batch(message_batch)
-        elif message_list is not None:
-            self.message_history.add_message_list(message_list)
-        else:
-            raise ValueError("Either message_batch or messages must be provided.")
+        self._message_history.extend(messages)

     @property
     def is_empty(self) -> bool:
-        return len(self.
-
-    @property
-    def batch_size(self) -> int:
-        return self.message_history.batch_size
+        return len(self._message_history) == 0

     def __repr__(self) -> str:
-        return
+        return (
+            "LLMAgentMemory with message history of "
+            f"length {len(self._message_history)}"
+        )