grasp_agents 0.3.10__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
grasp_agents/llm_agent.py CHANGED
@@ -1,12 +1,12 @@
1
1
  from collections.abc import AsyncIterator, Sequence
2
2
  from pathlib import Path
3
- from typing import Any, ClassVar, Generic, Protocol
3
+ from typing import Any, ClassVar, Generic, Protocol, TypeVar
4
4
 
5
5
  from pydantic import BaseModel
6
6
 
7
7
  from .comm_processor import CommProcessor
8
8
  from .llm import LLM, LLMSettings
9
- from .llm_agent_memory import LLMAgentMemory, SetMemoryHandler
9
+ from .llm_agent_memory import LLMAgentMemory, MakeMemoryHandler
10
10
  from .llm_policy_executor import (
11
11
  ExitToolCallLoopHandler,
12
12
  LLMPolicyExecutor,
@@ -22,26 +22,28 @@ from .run_context import CtxT, RunContext
22
22
  from .typing.content import Content, ImageData
23
23
  from .typing.converters import Converters
24
24
  from .typing.events import Event, ProcOutputEvent, SystemMessageEvent, UserMessageEvent
25
- from .typing.io import InT_contra, LLMPrompt, LLMPromptArgs, OutT_co, ProcName
25
+ from .typing.io import InT, LLMPrompt, LLMPromptArgs, OutT_co, ProcName
26
26
  from .typing.message import Message, Messages, SystemMessage, UserMessage
27
27
  from .typing.tool import BaseTool
28
28
  from .utils import get_prompt, validate_obj_from_json_or_py_string
29
29
 
30
+ _InT_contra = TypeVar("_InT_contra", contravariant=True)
31
+ _OutT_co = TypeVar("_OutT_co", covariant=True)
30
32
 
31
- class ParseOutputHandler(Protocol[InT_contra, OutT_co, CtxT]):
33
+
34
+ class ParseOutputHandler(Protocol[_InT_contra, _OutT_co, CtxT]):
32
35
  def __call__(
33
36
  self,
34
37
  conversation: Messages,
35
38
  *,
36
- in_args: InT_contra | None,
37
- batch_idx: int,
39
+ in_args: _InT_contra | None,
38
40
  ctx: RunContext[CtxT] | None,
39
- ) -> OutT_co: ...
41
+ ) -> _OutT_co: ...
40
42
 
41
43
 
42
44
  class LLMAgent(
43
- CommProcessor[InT_contra, OutT_co, LLMAgentMemory, CtxT],
44
- Generic[InT_contra, OutT_co, CtxT],
45
+ CommProcessor[InT, OutT_co, LLMAgentMemory, CtxT],
46
+ Generic[InT, OutT_co, CtxT],
45
47
  ):
46
48
  _generic_arg_to_instance_attr_map: ClassVar[dict[int, str]] = {
47
49
  0: "_in_type",
@@ -72,11 +74,18 @@ class LLMAgent(
72
74
  final_answer_as_tool_call: bool = False,
73
75
  # Agent memory management
74
76
  reset_memory_on_run: bool = False,
77
+ # Retries
78
+ num_par_run_retries: int = 0,
75
79
  # Multi-agent routing
76
80
  packet_pool: PacketPool[CtxT] | None = None,
77
81
  recipients: list[ProcName] | None = None,
78
82
  ) -> None:
79
- super().__init__(name=name, packet_pool=packet_pool, recipients=recipients)
83
+ super().__init__(
84
+ name=name,
85
+ packet_pool=packet_pool,
86
+ recipients=recipients,
87
+ num_par_run_retries=num_par_run_retries,
88
+ )
80
89
 
81
90
  # Agent memory
82
91
 
@@ -105,7 +114,7 @@ class LLMAgent(
105
114
 
106
115
  sys_prompt = get_prompt(prompt_text=sys_prompt, prompt_path=sys_prompt_path)
107
116
  in_prompt = get_prompt(prompt_text=in_prompt, prompt_path=in_prompt_path)
108
- self._prompt_builder: PromptBuilder[InT_contra, CtxT] = PromptBuilder[
117
+ self._prompt_builder: PromptBuilder[InT, CtxT] = PromptBuilder[
109
118
  self.in_type, CtxT
110
119
  ](
111
120
  agent_name=self._name,
@@ -115,12 +124,10 @@ class LLMAgent(
115
124
  usr_args_schema=usr_args_schema,
116
125
  )
117
126
 
118
- self.no_tqdm = getattr(llm, "no_tqdm", False)
127
+ # self.no_tqdm = getattr(llm, "no_tqdm", False)
119
128
 
120
- self._set_memory_impl: SetMemoryHandler | None = None
121
- self._parse_output_impl: (
122
- ParseOutputHandler[InT_contra, OutT_co, CtxT] | None
123
- ) = None
129
+ self._make_memory_impl: MakeMemoryHandler | None = None
130
+ self._parse_output_impl: ParseOutputHandler[InT, OutT_co, CtxT] | None = None
124
131
  self._register_overridden_handlers()
125
132
 
126
133
  @property
@@ -155,10 +162,11 @@ class LLMAgent(
155
162
  self,
156
163
  chat_inputs: LLMPrompt | Sequence[str | ImageData] | None = None,
157
164
  *,
158
- in_args: Sequence[InT_contra] | None = None,
165
+ in_args: InT | None = None,
166
+ memory: LLMAgentMemory,
159
167
  ctx: RunContext[CtxT] | None = None,
160
- ) -> tuple[SystemMessage | None, Sequence[UserMessage], LLMAgentMemory]:
161
- # Get run arguments
168
+ ) -> tuple[SystemMessage | None, UserMessage | None, LLMAgentMemory]:
169
+ # 1. Get run arguments
162
170
  sys_args: LLMPromptArgs | None = None
163
171
  usr_args: LLMPromptArgs | None = None
164
172
  if ctx is not None:
@@ -167,23 +175,22 @@ class LLMAgent(
167
175
  sys_args = run_args.sys
168
176
  usr_args = run_args.usr
169
177
 
170
- # 1. Make system prompt (can be None)
178
+ # 2. Make system prompt (can be None)
171
179
 
172
180
  formatted_sys_prompt = self._prompt_builder.make_system_prompt(
173
181
  sys_args=sys_args, ctx=ctx
174
182
  )
175
183
 
176
- # 2. Set agent memory
184
+ # 3. Set agent memory
177
185
 
178
186
  system_message: SystemMessage | None = None
179
- _memory = self.memory.model_copy(deep=True)
180
- if self._reset_memory_on_run or _memory.is_empty:
181
- _memory.reset(formatted_sys_prompt)
187
+ if self._reset_memory_on_run or memory.is_empty:
188
+ memory.reset(formatted_sys_prompt)
182
189
  if formatted_sys_prompt is not None:
183
- system_message = _memory.message_history[0][0] # type: ignore[assignment]
184
- elif self._set_memory_impl:
185
- _memory = self._set_memory_impl(
186
- prev_memory=_memory,
190
+ system_message = memory.message_history[0] # type: ignore[union-attr]
191
+ elif self._make_memory_impl:
192
+ memory = self._make_memory_impl(
193
+ prev_memory=memory,
187
194
  in_args=in_args,
188
195
  sys_prompt=formatted_sys_prompt,
189
196
  ctx=ctx,
@@ -191,113 +198,96 @@ class LLMAgent(
191
198
 
192
199
  # 3. Make and add user messages
193
200
 
194
- user_message_batch = self._prompt_builder.make_user_messages(
195
- chat_inputs=chat_inputs, in_args_batch=in_args, usr_args=usr_args, ctx=ctx
201
+ user_message = self._prompt_builder.make_user_message(
202
+ chat_inputs=chat_inputs, in_args=in_args, usr_args=usr_args, ctx=ctx
196
203
  )
197
- if user_message_batch:
198
- _memory.update(message_batch=user_message_batch)
204
+ if user_message:
205
+ memory.update([user_message])
199
206
 
200
- return system_message, user_message_batch, _memory
201
-
202
- def _extract_outputs(
203
- self,
204
- memory: LLMAgentMemory,
205
- in_args: Sequence[InT_contra] | None = None,
206
- ctx: RunContext[CtxT] | None = None,
207
- ) -> Sequence[OutT_co]:
208
- outputs: list[OutT_co] = []
209
- for i, _conv in enumerate(memory.message_history.conversations):
210
- if in_args is not None:
211
- _in_args_single = in_args[min(i, len(in_args) - 1)]
212
- else:
213
- _in_args_single = None
214
-
215
- outputs.append(
216
- self._parse_output(
217
- conversation=_conv, in_args=_in_args_single, batch_idx=i, ctx=ctx
218
- )
219
- )
220
-
221
- return outputs
207
+ return system_message, user_message, memory
222
208
 
223
209
  def _parse_output(
224
210
  self,
225
211
  conversation: Messages,
226
212
  *,
227
- in_args: InT_contra | None = None,
228
- batch_idx: int = 0,
213
+ in_args: InT | None = None,
229
214
  ctx: RunContext[CtxT] | None = None,
230
215
  ) -> OutT_co:
231
216
  if self._parse_output_impl:
232
217
  return self._parse_output_impl(
233
- conversation=conversation, in_args=in_args, batch_idx=batch_idx, ctx=ctx
218
+ conversation=conversation, in_args=in_args, ctx=ctx
234
219
  )
235
220
 
236
221
  return validate_obj_from_json_or_py_string(
237
222
  str(conversation[-1].content or ""),
238
223
  adapter=self._out_type_adapter,
239
- from_substring=True,
224
+ from_substring=False,
240
225
  )
241
226
 
242
227
  async def _process(
243
228
  self,
244
229
  chat_inputs: LLMPrompt | Sequence[str | ImageData] | None = None,
245
230
  *,
246
- in_args: Sequence[InT_contra] | None = None,
247
- forgetful: bool = False,
231
+ in_args: InT | None = None,
232
+ memory: LLMAgentMemory,
233
+ run_id: str,
248
234
  ctx: RunContext[CtxT] | None = None,
249
235
  ) -> Sequence[OutT_co]:
250
- system_message, user_message_batch, memory = self._memorize_inputs(
251
- chat_inputs=chat_inputs, in_args=in_args, ctx=ctx
236
+ system_message, user_message, memory = self._memorize_inputs(
237
+ chat_inputs=chat_inputs, in_args=in_args, memory=memory, ctx=ctx
252
238
  )
239
+ if system_message:
240
+ self._print_messages([system_message], run_id=run_id, ctx=ctx)
241
+ if user_message:
242
+ self._print_messages([user_message], run_id=run_id, ctx=ctx)
253
243
 
254
- if system_message is not None:
255
- self._print_messages([system_message], ctx=ctx)
256
- if user_message_batch:
257
- self._print_messages(user_message_batch, ctx=ctx)
244
+ await self._policy_executor.execute(memory, run_id=run_id, ctx=ctx)
258
245
 
259
- await self._policy_executor.execute(memory, ctx=ctx)
260
-
261
- if not forgetful:
262
- self._memory = memory
263
-
264
- return self._extract_outputs(memory=memory, in_args=in_args, ctx=ctx)
246
+ return [
247
+ self._parse_output(
248
+ conversation=memory.message_history, in_args=in_args, ctx=ctx
249
+ )
250
+ ]
265
251
 
266
252
  async def _process_stream(
267
253
  self,
268
254
  chat_inputs: LLMPrompt | Sequence[str | ImageData] | None = None,
269
255
  *,
270
- in_args: Sequence[InT_contra] | None = None,
271
- forgetful: bool = False,
256
+ in_args: InT | None = None,
257
+ memory: LLMAgentMemory,
258
+ run_id: str,
272
259
  ctx: RunContext[CtxT] | None = None,
273
260
  ) -> AsyncIterator[Event[Any]]:
274
- system_message, user_message_batch, memory = self._memorize_inputs(
275
- chat_inputs=chat_inputs, in_args=in_args, ctx=ctx
261
+ system_message, user_message, memory = self._memorize_inputs(
262
+ chat_inputs=chat_inputs, in_args=in_args, memory=memory, ctx=ctx
276
263
  )
277
-
278
- if system_message is not None:
264
+ if system_message:
279
265
  yield SystemMessageEvent(data=system_message)
280
- if user_message_batch:
281
- for user_message in user_message_batch:
282
- yield UserMessageEvent(data=user_message)
266
+ if user_message:
267
+ yield UserMessageEvent(data=user_message)
283
268
 
284
269
  # 4. Run tool call loop (new messages are added to the message
285
270
  # history inside the loop)
286
- async for event in self._policy_executor.execute_stream(memory, ctx=ctx):
271
+ async for event in self._policy_executor.execute_stream(
272
+ memory, run_id=run_id, ctx=ctx
273
+ ):
287
274
  yield event
288
275
 
289
- if not forgetful:
290
- self._memory = memory
291
-
292
- outputs = self._extract_outputs(memory=memory, in_args=in_args, ctx=ctx)
293
- for output in outputs:
294
- yield ProcOutputEvent(data=output, name=self.name)
276
+ output = self._parse_output(
277
+ conversation=memory.message_history, in_args=in_args, ctx=ctx
278
+ )
279
+ yield ProcOutputEvent(data=output, name=self.name)
295
280
 
296
281
  def _print_messages(
297
- self, messages: Sequence[Message], ctx: RunContext[CtxT] | None = None
282
+ self,
283
+ messages: Sequence[Message],
284
+ run_id: str,
285
+ ctx: RunContext[CtxT] | None = None,
298
286
  ) -> None:
299
287
  if ctx:
300
- ctx.printer.print_llm_messages(messages, agent_name=self.name)
288
+ ctx.printer.print_llm_messages(
289
+ messages, agent_name=self.name, run_id=run_id
290
+ )
301
291
 
302
292
  # -- Decorators for custom implementations --
303
293
 
@@ -309,23 +299,23 @@ class LLMAgent(
309
299
  return func
310
300
 
311
301
  def make_input_content(
312
- self, func: MakeInputContentHandler[InT_contra, CtxT]
313
- ) -> MakeInputContentHandler[InT_contra, CtxT]:
302
+ self, func: MakeInputContentHandler[InT, CtxT]
303
+ ) -> MakeInputContentHandler[InT, CtxT]:
314
304
  self._prompt_builder.make_input_content_impl = func
315
305
 
316
306
  return func
317
307
 
318
308
  def parse_output(
319
- self, func: ParseOutputHandler[InT_contra, OutT_co, CtxT]
320
- ) -> ParseOutputHandler[InT_contra, OutT_co, CtxT]:
309
+ self, func: ParseOutputHandler[InT, OutT_co, CtxT]
310
+ ) -> ParseOutputHandler[InT, OutT_co, CtxT]:
321
311
  if self._used_default_llm_response_format:
322
312
  self._policy_executor.llm.response_format = None
323
313
  self._parse_output_impl = func
324
314
 
325
315
  return func
326
316
 
327
- def set_memory(self, func: SetMemoryHandler) -> SetMemoryHandler:
328
- self._set_memory_impl = func
317
+ def make_memory(self, func: MakeMemoryHandler) -> MakeMemoryHandler:
318
+ self._make_memory_impl = func
329
319
 
330
320
  return func
331
321
 
@@ -355,8 +345,8 @@ class LLMAgent(
355
345
  if cur_cls._make_input_content is not base_cls._make_input_content: # noqa: SLF001
356
346
  self._prompt_builder.make_input_content_impl = self._make_input_content
357
347
 
358
- if cur_cls._set_memory is not base_cls._set_memory: # noqa: SLF001
359
- self._set_memory_impl = self._set_memory
348
+ if cur_cls._make_memory is not base_cls._make_memory: # noqa: SLF001
349
+ self._make_memory_impl = self._make_memory
360
350
 
361
351
  if cur_cls._manage_memory is not base_cls._manage_memory: # noqa: SLF001
362
352
  self._policy_executor.manage_memory_impl = self._manage_memory
@@ -383,24 +373,23 @@ class LLMAgent(
383
373
  def _make_input_content(
384
374
  self,
385
375
  *,
386
- in_args: InT_contra | None = None,
376
+ in_args: InT | None = None,
387
377
  usr_args: LLMPromptArgs | None = None,
388
- batch_idx: int = 0,
389
378
  ctx: RunContext[CtxT] | None = None,
390
379
  ) -> Content:
391
380
  raise NotImplementedError(
392
381
  "LLMAgent._format_in_args must be overridden by a subclass"
393
382
  )
394
383
 
395
- def _set_memory(
384
+ def _make_memory(
396
385
  self,
397
386
  prev_memory: LLMAgentMemory,
398
- in_args: Sequence[InT_contra] | None = None,
387
+ in_args: Sequence[InT] | None = None,
399
388
  sys_prompt: LLMPrompt | None = None,
400
389
  ctx: RunContext[Any] | None = None,
401
390
  ) -> LLMAgentMemory:
402
391
  raise NotImplementedError(
403
- "LLMAgent._set_memory must be overridden by a subclass"
392
+ "LLMAgent._make_memory must be overridden by a subclass"
404
393
  )
405
394
 
406
395
  def _exit_tool_call_loop(
@@ -1,16 +1,15 @@
1
1
  from collections.abc import Sequence
2
2
  from typing import Any, Protocol
3
3
 
4
- from pydantic import Field
4
+ from pydantic import PrivateAttr
5
5
 
6
6
  from .memory import Memory
7
- from .message_history import MessageHistory
8
7
  from .run_context import RunContext
9
8
  from .typing.io import LLMPrompt
10
- from .typing.message import Message
9
+ from .typing.message import Message, Messages, SystemMessage
11
10
 
12
11
 
13
- class SetMemoryHandler(Protocol):
12
+ class MakeMemoryHandler(Protocol):
14
13
  def __call__(
15
14
  self,
16
15
  prev_memory: "LLMAgentMemory",
@@ -21,38 +20,48 @@ class SetMemoryHandler(Protocol):
21
20
 
22
21
 
23
22
  class LLMAgentMemory(Memory):
24
- message_history: MessageHistory = Field(default_factory=MessageHistory)
23
+ _message_history: Messages = PrivateAttr(default_factory=list) # type: ignore
24
+ _sys_prompt: LLMPrompt | None = PrivateAttr(default=None)
25
+
26
+ def __init__(self, sys_prompt: LLMPrompt | None = None) -> None:
27
+ super().__init__()
28
+ self._sys_prompt = sys_prompt
29
+ self.reset(sys_prompt)
30
+
31
+ @property
32
+ def sys_prompt(self) -> LLMPrompt | None:
33
+ return self._sys_prompt
34
+
35
+ @property
36
+ def message_history(self) -> Messages:
37
+ return self._message_history
25
38
 
26
39
  def reset(
27
40
  self, sys_prompt: LLMPrompt | None = None, ctx: RunContext[Any] | None = None
28
41
  ):
29
- self.message_history.reset(sys_prompt=sys_prompt)
42
+ if sys_prompt is not None:
43
+ self._sys_prompt = sys_prompt
44
+
45
+ self._message_history = (
46
+ [SystemMessage(content=self._sys_prompt)]
47
+ if self._sys_prompt is not None
48
+ else []
49
+ )
50
+
51
+ def erase(self) -> None:
52
+ self._message_history = []
30
53
 
31
54
  def update(
32
- self,
33
- message_list: Sequence[Message] | None = None,
34
- *,
35
- message_batch: Sequence[Message] | None = None,
36
- ctx: RunContext[Any] | None = None,
55
+ self, messages: Sequence[Message], *, ctx: RunContext[Any] | None = None
37
56
  ):
38
- if message_batch is not None and message_list is not None:
39
- raise ValueError(
40
- "Only one of message_batch or messages should be provided."
41
- )
42
- if message_batch is not None:
43
- self.message_history.add_message_batch(message_batch)
44
- elif message_list is not None:
45
- self.message_history.add_message_list(message_list)
46
- else:
47
- raise ValueError("Either message_batch or messages must be provided.")
57
+ self._message_history.extend(messages)
48
58
 
49
59
  @property
50
60
  def is_empty(self) -> bool:
51
- return len(self.message_history) == 0
52
-
53
- @property
54
- def batch_size(self) -> int:
55
- return self.message_history.batch_size
61
+ return len(self._message_history) == 0
56
62
 
57
63
  def __repr__(self) -> str:
58
- return f"Message History: {len(self.message_history)}"
64
+ return (
65
+ "LLMAgentMemory with message history of "
66
+ f"length {len(self._message_history)}"
67
+ )