grasp_agents 0.5.3__py3-none-any.whl → 0.5.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -3,7 +3,7 @@ import json
3
3
  from collections.abc import AsyncIterator, Coroutine, Sequence
4
4
  from itertools import starmap
5
5
  from logging import getLogger
6
- from typing import Any, Generic, Protocol
6
+ from typing import Any, Generic, Protocol, final
7
7
 
8
8
  from pydantic import BaseModel
9
9
 
@@ -29,7 +29,7 @@ from .typing.tool import BaseTool, NamedToolChoice, ToolCall, ToolChoice
29
29
  logger = getLogger(__name__)
30
30
 
31
31
 
32
- class ExitToolCallLoopHandler(Protocol[CtxT]):
32
+ class ToolCallLoopTerminator(Protocol[CtxT]):
33
33
  def __call__(
34
34
  self,
35
35
  conversation: Messages,
@@ -39,7 +39,7 @@ class ExitToolCallLoopHandler(Protocol[CtxT]):
39
39
  ) -> bool: ...
40
40
 
41
41
 
42
- class ManageMemoryHandler(Protocol[CtxT]):
42
+ class MemoryManager(Protocol[CtxT]):
43
43
  def __call__(
44
44
  self,
45
45
  memory: LLMAgentMemory,
@@ -78,8 +78,8 @@ class LLMPolicyExecutor(Generic[CtxT]):
78
78
  self._max_turns = max_turns
79
79
  self._react_mode = react_mode
80
80
 
81
- self.exit_tool_call_loop_impl: ExitToolCallLoopHandler[CtxT] | None = None
82
- self.manage_memory_impl: ManageMemoryHandler[CtxT] | None = None
81
+ self.tool_call_loop_terminator: ToolCallLoopTerminator[CtxT] | None = None
82
+ self.memory_manager: MemoryManager[CtxT] | None = None
83
83
 
84
84
  @property
85
85
  def agent_name(self) -> str:
@@ -97,18 +97,20 @@ class LLMPolicyExecutor(Generic[CtxT]):
97
97
  def max_turns(self) -> int:
98
98
  return self._max_turns
99
99
 
100
- def _exit_tool_call_loop(
100
+ @final
101
+ def _terminate_tool_call_loop(
101
102
  self,
102
103
  conversation: Messages,
103
104
  *,
104
105
  ctx: RunContext[CtxT] | None = None,
105
106
  **kwargs: Any,
106
107
  ) -> bool:
107
- if self.exit_tool_call_loop_impl:
108
- return self.exit_tool_call_loop_impl(conversation, ctx=ctx, **kwargs)
108
+ if self.tool_call_loop_terminator:
109
+ return self.tool_call_loop_terminator(conversation, ctx=ctx, **kwargs)
109
110
 
110
111
  return False
111
112
 
113
+ @final
112
114
  def _manage_memory(
113
115
  self,
114
116
  memory: LLMAgentMemory,
@@ -116,8 +118,8 @@ class LLMPolicyExecutor(Generic[CtxT]):
116
118
  ctx: RunContext[CtxT] | None = None,
117
119
  **kwargs: Any,
118
120
  ) -> None:
119
- if self.manage_memory_impl:
120
- self.manage_memory_impl(memory=memory, ctx=ctx, **kwargs)
121
+ if self.memory_manager:
122
+ self.memory_manager(memory=memory, ctx=ctx, **kwargs)
121
123
 
122
124
  async def generate_message(
123
125
  self,
@@ -255,7 +257,7 @@ class LLMPolicyExecutor(Generic[CtxT]):
255
257
 
256
258
  final_answer_message = self._extract_final_answer_from_tool_calls(memory=memory)
257
259
  if final_answer_message is None:
258
- raise AgentFinalAnswerError
260
+ raise AgentFinalAnswerError(proc_name=self.agent_name, call_id=call_id)
259
261
 
260
262
  return final_answer_message
261
263
 
@@ -282,7 +284,7 @@ class LLMPolicyExecutor(Generic[CtxT]):
282
284
 
283
285
  final_answer_message = self._extract_final_answer_from_tool_calls(memory)
284
286
  if final_answer_message is None:
285
- raise AgentFinalAnswerError
287
+ raise AgentFinalAnswerError(proc_name=self.agent_name, call_id=call_id)
286
288
  yield GenMessageEvent(
287
289
  proc_name=self.agent_name, call_id=call_id, data=final_answer_message
288
290
  )
@@ -309,8 +311,8 @@ class LLMPolicyExecutor(Generic[CtxT]):
309
311
  # 2. Check if we should exit the tool call loop
310
312
 
311
313
  # If a final answer is not provided via a tool call, we use
312
- # exit_tool_call_loop to determine whether to exit the loop.
313
- if not self._final_answer_as_tool_call and self._exit_tool_call_loop(
314
+ # _terminate_tool_call_loop to determine whether to exit the loop.
315
+ if not self._final_answer_as_tool_call and self._terminate_tool_call_loop(
314
316
  memory.message_history, ctx=ctx, num_turns=turns
315
317
  ):
316
318
  return gen_message
@@ -390,7 +392,7 @@ class LLMPolicyExecutor(Generic[CtxT]):
390
392
  turns = 0
391
393
 
392
394
  while True:
393
- if not self._final_answer_as_tool_call and self._exit_tool_call_loop(
395
+ if not self._final_answer_as_tool_call and self._terminate_tool_call_loop(
394
396
  memory.message_history, ctx=ctx, num_turns=turns
395
397
  ):
396
398
  return
grasp_agents/packet.py CHANGED
@@ -1,11 +1,13 @@
1
1
  from collections.abc import Sequence
2
- from typing import Generic, TypeVar
2
+ from typing import Annotated, Any, Generic, Literal, TypeVar
3
3
  from uuid import uuid4
4
4
 
5
- from pydantic import BaseModel, ConfigDict, Field
5
+ from pydantic import AfterValidator, BaseModel, ConfigDict, Field
6
6
 
7
7
  from .typing.io import ProcName
8
8
 
9
+ START_PROC_NAME: Literal["*START*"] = "*START*"
10
+
9
11
  _PayloadT_co = TypeVar("_PayloadT_co", covariant=True)
10
12
 
11
13
 
@@ -13,15 +15,32 @@ class Packet(BaseModel, Generic[_PayloadT_co]):
13
15
  id: str = Field(default_factory=lambda: str(uuid4())[:8])
14
16
  payloads: Sequence[_PayloadT_co]
15
17
  sender: ProcName
16
- recipients: Sequence[ProcName] = Field(default_factory=list)
18
+ recipients: Sequence[ProcName] | None = None
17
19
 
18
20
  model_config = ConfigDict(extra="forbid")
19
21
 
20
22
  def __repr__(self) -> str:
23
+ _to = ", ".join(self.recipients) if self.recipients else "None"
21
24
  return (
22
25
  f"{self.__class__.__name__}:\n"
23
26
  f"ID: {self.id}\n"
24
27
  f"From: {self.sender}\n"
25
- f"To: {', '.join(self.recipients)}\n"
28
+ f"To: {_to}\n"
26
29
  f"Payloads: {len(self.payloads)}"
27
30
  )
31
+
32
+
33
+ def _check_recipients_length(v: Sequence[ProcName] | None) -> Sequence[ProcName] | None:
34
+ if v is not None and len(v) != 1:
35
+ raise ValueError("recipients must contain exactly one item")
36
+ return v
37
+
38
+
39
+ class StartPacket(Packet[_PayloadT_co]):
40
+ chat_inputs: Any | None = "start"
41
+ sender: ProcName = Field(default=START_PROC_NAME, frozen=True)
42
+ payloads: Sequence[_PayloadT_co] = Field(default=(), frozen=True)
43
+ recipients: Annotated[
44
+ Sequence[ProcName] | None,
45
+ AfterValidator(_check_recipients_length),
46
+ ] = None
@@ -1,7 +1,8 @@
1
1
  import asyncio
2
2
  import logging
3
3
  from collections.abc import AsyncIterator
4
- from typing import Any, Generic, Protocol, TypeVar
4
+ from types import TracebackType
5
+ from typing import Any, Generic, Literal, Protocol, TypeVar
5
6
 
6
7
  from .packet import Packet
7
8
  from .run_context import CtxT, RunContext
@@ -11,6 +12,9 @@ from .typing.io import ProcName
11
12
  logger = logging.getLogger(__name__)
12
13
 
13
14
 
15
+ END_PROC_NAME: Literal["*END*"] = "*END*"
16
+
17
+
14
18
  _PayloadT_contra = TypeVar("_PayloadT_contra", contravariant=True)
15
19
 
16
20
 
@@ -20,73 +24,136 @@ class PacketHandler(Protocol[_PayloadT_contra, CtxT]):
20
24
  packet: Packet[_PayloadT_contra],
21
25
  ctx: RunContext[CtxT],
22
26
  **kwargs: Any,
23
- ) -> AsyncIterator[Event[Any]] | None: ...
27
+ ) -> None: ...
24
28
 
25
29
 
26
30
  class PacketPool(Generic[CtxT]):
27
31
  def __init__(self) -> None:
28
- self._queues: dict[ProcName, asyncio.Queue[Packet[Any]]] = {}
32
+ self._packet_queues: dict[ProcName, asyncio.Queue[Packet[Any] | None]] = {}
29
33
  self._packet_handlers: dict[ProcName, PacketHandler[Any, CtxT]] = {}
30
- self._tasks: dict[ProcName, asyncio.Task[AsyncIterator[Event[Any]] | None]] = {}
34
+ self._task_group: asyncio.TaskGroup | None = None
35
+
36
+ self._event_queue: asyncio.Queue[Event[Any] | None] = asyncio.Queue()
37
+
38
+ self._final_result_fut: asyncio.Future[Packet[Any]] | None = None
39
+
40
+ self._stopping = False
41
+ self._stopped_evt = asyncio.Event()
42
+
43
+ self._errors: list[Exception] = []
31
44
 
32
45
  async def post(self, packet: Packet[Any]) -> None:
33
- for recipient_id in packet.recipients:
34
- queue = self._queues.setdefault(recipient_id, asyncio.Queue())
46
+ if packet.recipients == [END_PROC_NAME]:
47
+ fut = self._ensure_final_future()
48
+ if not fut.done():
49
+ fut.set_result(packet)
50
+ await self.shutdown()
51
+ return
52
+
53
+ for recipient_id in packet.recipients or []:
54
+ queue = self._packet_queues.setdefault(recipient_id, asyncio.Queue())
35
55
  await queue.put(packet)
36
56
 
57
+ def _ensure_final_future(self) -> asyncio.Future[Packet[Any]]:
58
+ fut = self._final_result_fut
59
+ if fut is None:
60
+ fut = asyncio.get_running_loop().create_future()
61
+ self._final_result_fut = fut
62
+ return fut
63
+
64
+ async def final_result(self) -> Packet[Any]:
65
+ fut = self._ensure_final_future()
66
+ try:
67
+ return await fut
68
+ finally:
69
+ await self.shutdown()
70
+
37
71
  def register_packet_handler(
38
72
  self,
39
- processor_name: ProcName,
73
+ proc_name: ProcName,
40
74
  handler: PacketHandler[Any, CtxT],
41
75
  ctx: RunContext[CtxT],
42
76
  **run_kwargs: Any,
43
77
  ) -> None:
44
- self._packet_handlers[processor_name] = handler
45
- self._queues.setdefault(processor_name, asyncio.Queue())
46
- if processor_name not in self._tasks:
47
- self._tasks[processor_name] = asyncio.create_task(
48
- self._handle_packets(processor_name, ctx=ctx, **run_kwargs)
78
+ if self._stopping:
79
+ raise RuntimeError("PacketPool is stopping/stopped")
80
+
81
+ self._packet_handlers[proc_name] = handler
82
+ self._packet_queues.setdefault(proc_name, asyncio.Queue())
83
+
84
+ if self._task_group is not None:
85
+ self._task_group.create_task(
86
+ self._handle_packets(proc_name, ctx=ctx, **run_kwargs),
87
+ name=f"packet-handler:{proc_name}",
49
88
  )
50
89
 
90
+ async def push_event(self, event: Event[Any]) -> None:
91
+ await self._event_queue.put(event)
92
+
93
+ async def __aenter__(self) -> "PacketPool[CtxT]":
94
+ self._task_group = asyncio.TaskGroup()
95
+ await self._task_group.__aenter__()
96
+
97
+ return self
98
+
99
+ async def __aexit__(
100
+ self,
101
+ exc_type: type[BaseException] | None,
102
+ exc: BaseException | None,
103
+ tb: TracebackType | None,
104
+ ) -> bool | None:
105
+ await self.shutdown()
106
+
107
+ if self._task_group is not None:
108
+ try:
109
+ return await self._task_group.__aexit__(exc_type, exc, tb)
110
+ finally:
111
+ self._task_group = None
112
+
113
+ if self._errors:
114
+ raise ExceptionGroup("PacketPool worker errors", self._errors)
115
+
116
+ return False
117
+
51
118
  async def _handle_packets(
52
- self, processor_name: ProcName, ctx: RunContext[CtxT], **run_kwargs: Any
53
- ) -> AsyncIterator[Event[Any]] | None:
54
- queue = self._queues[processor_name]
119
+ self, proc_name: ProcName, ctx: RunContext[CtxT], **run_kwargs: Any
120
+ ) -> None:
121
+ queue = self._packet_queues[proc_name]
122
+ handler = self._packet_handlers[proc_name]
123
+
55
124
  while True:
125
+ packet = await queue.get()
126
+ if packet is None:
127
+ break
56
128
  try:
57
- packet = await queue.get()
58
- handler = self._packet_handlers.get(processor_name)
59
- if handler is None:
60
- break
61
- try:
62
- if ctx.is_streaming:
63
- async for event in handler(packet, ctx=ctx, **run_kwargs): # type: ignore[return-value]
64
- yield event
65
- else:
66
- await handler(packet, ctx=ctx, **run_kwargs)
67
-
68
- except Exception:
69
- logger.exception(f"Error handling packet for {processor_name}")
70
-
71
- queue.task_done()
72
-
73
- except Exception:
74
- logger.exception(
75
- f"Unexpected error in processing loop for {processor_name}"
76
- )
77
-
78
- async def unregister_packet_handler(self, processor_name: ProcName) -> None:
79
- if task := self._tasks.get(processor_name):
80
- task.cancel()
81
- try:
82
- await task
129
+ await handler(packet, ctx=ctx, **run_kwargs)
83
130
  except asyncio.CancelledError:
84
- logger.info(f"{processor_name} exited")
85
-
86
- self._tasks.pop(processor_name, None)
87
- self._queues.pop(processor_name, None)
88
- self._packet_handlers.pop(processor_name, None)
89
-
90
- async def stop_all(self) -> None:
91
- for processor_name in list(self._tasks):
92
- await self.unregister_packet_handler(processor_name)
131
+ raise
132
+ except Exception as err:
133
+ logger.exception("Error handling packet for %s", proc_name)
134
+ self._errors.append(err)
135
+ fut = self._final_result_fut
136
+ if fut and not fut.done():
137
+ fut.set_exception(err)
138
+ await self.shutdown()
139
+ raise
140
+
141
+ async def stream_events(self) -> AsyncIterator[Event[Any]]:
142
+ while True:
143
+ event = await self._event_queue.get()
144
+ if event is None:
145
+ break
146
+ yield event
147
+
148
+ async def shutdown(self) -> None:
149
+ if self._stopping:
150
+ await self._stopped_evt.wait()
151
+ return
152
+ self._stopping = True
153
+ try:
154
+ await self._event_queue.put(None)
155
+ for queue in self._packet_queues.values():
156
+ await queue.put(None)
157
+
158
+ finally:
159
+ self._stopped_evt.set()
grasp_agents/printer.py CHANGED
@@ -119,7 +119,7 @@ class Printer:
119
119
  # Thinking
120
120
  if isinstance(message, AssistantMessage) and message.reasoning_content:
121
121
  thinking = message.reasoning_content.strip(" \n")
122
- out += f"\n<thinking>\n{thinking}\n</thinking>\n"
122
+ out += f"<thinking>\n{thinking}\n</thinking>\n"
123
123
 
124
124
  # Content
125
125
  content = self.content_to_str(message.content or "", message.role)
@@ -219,12 +219,12 @@ async def print_event_stream(
219
219
  ) -> None:
220
220
  color = _get_color(event, Role.ASSISTANT)
221
221
 
222
- if isinstance(event, ProcPacketOutputEvent):
223
- src = "processor"
224
- elif isinstance(event, WorkflowResultEvent):
222
+ if isinstance(event, WorkflowResultEvent):
225
223
  src = "workflow"
226
- else:
224
+ elif isinstance(event, RunResultEvent):
227
225
  src = "run"
226
+ else:
227
+ src = "processor"
228
228
 
229
229
  text = f"\n<{event.proc_name}> [{event.call_id}]\n"
230
230
 
@@ -232,6 +232,10 @@ async def print_event_stream(
232
232
  text += f"<{src} output>\n"
233
233
  for p in event.data.payloads:
234
234
  if isinstance(p, BaseModel):
235
+ for field_info in type(p).model_fields.values():
236
+ if field_info.exclude:
237
+ field_info.exclude = False
238
+ type(p).model_rebuild(force=True)
235
239
  p_str = p.model_dump_json(indent=2)
236
240
  else:
237
241
  try: