grasp_agents 0.5.5__py3-none-any.whl → 0.5.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,246 @@
1
+ import asyncio
2
+ import logging
3
+ from collections.abc import AsyncIterator, Sequence
4
+ from typing import Any, ClassVar, Generic, cast
5
+
6
+ from ..errors import PacketRoutingError
7
+ from ..memory import MemT
8
+ from ..packet import Packet
9
+ from ..run_context import CtxT, RunContext
10
+ from ..typing.events import Event, ProcPacketOutputEvent, ProcPayloadOutputEvent
11
+ from ..typing.io import InT, OutT
12
+ from ..utils import stream_concurrent
13
+ from .base_processor import BaseProcessor, with_retry, with_retry_stream
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
class ParallelProcessor(
    BaseProcessor[InT, OutT, MemT, CtxT], Generic[InT, OutT, MemT, CtxT]
):
    """
    Processor that runs one independent sub-run per input argument.

    When input validation yields more than one input argument, each argument
    is processed by its own ``forgetful`` run. Such a run works on a deep copy
    of ``self.memory``. The runs execute concurrently, and their payloads are
    collected into a single output packet in input order. With zero or one
    input argument, a single regular run is performed instead.

    All parallel runs must select the same set of recipients; otherwise
    ``PacketRoutingError`` is raised.
    """

    _generic_arg_to_instance_attr_map: ClassVar[dict[int, str]] = {
        0: "_in_type",
        1: "_out_type",
    }

    async def _process(
        self,
        chat_inputs: Any | None = None,
        *,
        in_args: InT | None = None,
        memory: MemT,
        call_id: str,
        ctx: RunContext[CtxT] | None = None,
    ) -> OutT:
        """
        Turn a single input into a single output.

        The default implementation passes ``in_args`` through unchanged.
        Subclasses override it with real processing logic.
        """
        return cast("OutT", in_args)

    async def _process_stream(
        self,
        chat_inputs: Any | None = None,
        *,
        in_args: InT | None = None,
        memory: MemT,
        call_id: str,
        ctx: RunContext[CtxT] | None = None,
    ) -> AsyncIterator[Event[Any]]:
        """
        Streaming counterpart of ``_process``.

        The default implementation emits ``in_args`` unchanged as a single
        ``ProcPayloadOutputEvent``. Overrides must emit at least one
        ``ProcPayloadOutputEvent``; the last one emitted is taken as the
        run's output by ``_run_single_stream``.
        """
        output = cast("OutT", in_args)
        yield ProcPayloadOutputEvent(data=output, proc_name=self.name, call_id=call_id)

    def _validate_parallel_recipients(
        self, out_packets: Sequence[Packet[OutT]], call_id: str
    ) -> None:
        """
        Ensure all parallel sub-run packets target the same recipients.

        Recipients are compared as sets; ``None`` is treated as "no recipients".

        Raises:
            PacketRoutingError: if any two packets select different recipients.
        """
        recipient_sets = [set(p.recipients or []) for p in out_packets]
        same_recipients = all(rs == recipient_sets[0] for rs in recipient_sets)
        if not same_recipients:
            raise PacketRoutingError(
                proc_name=self.name,
                call_id=call_id,
                message="Parallel runs must return the same recipients "
                f"[proc_name={self.name}; call_id={call_id}]",
            )

    @with_retry
    async def _run_single(
        self,
        chat_inputs: Any | None = None,
        *,
        in_args: InT | None = None,
        forgetful: bool = False,
        call_id: str,
        ctx: RunContext[CtxT] | None = None,
    ) -> Packet[OutT]:
        """
        Process one input and wrap the validated output in a packet.

        With ``forgetful=True`` the run operates on a deep copy of the memory,
        so ``self.memory`` is left untouched.
        """
        memory = self.memory.model_copy(deep=True) if forgetful else self.memory

        output = await self._process(
            chat_inputs=chat_inputs,
            in_args=in_args,
            memory=memory,
            call_id=call_id,
            ctx=ctx,
        )
        val_output = self._validate_output(output, call_id=call_id)

        recipients = self._select_recipients(output=val_output, ctx=ctx)
        self._validate_recipients(recipients, call_id=call_id)

        return Packet(payloads=[val_output], sender=self.name, recipients=recipients)

    async def _run_parallel(
        self, in_args: list[InT], call_id: str, ctx: RunContext[CtxT] | None = None
    ) -> Packet[OutT]:
        """
        Run one forgetful sub-run per input concurrently and merge the results.

        Sub-runs get call ids ``"{call_id}/{index}"``. Payloads are merged in
        input order, because ``asyncio.gather`` preserves argument order. The
        first exception raised by any sub-run propagates to the caller.
        """
        tasks = [
            self._run_single(
                in_args=inp, forgetful=True, call_id=f"{call_id}/{idx}", ctx=ctx
            )
            for idx, inp in enumerate(in_args)
        ]
        out_packets = await asyncio.gather(*tasks)
        self._validate_parallel_recipients(out_packets, call_id=call_id)

        return Packet(
            payloads=[out_packet.payloads[0] for out_packet in out_packets],
            sender=self.name,
            recipients=out_packets[0].recipients,
        )

    async def run(
        self,
        chat_inputs: Any | None = None,
        *,
        in_packet: Packet[InT] | None = None,
        in_args: InT | list[InT] | None = None,
        forgetful: bool = False,
        call_id: str | None = None,
        ctx: RunContext[CtxT] | None = None,
    ) -> Packet[OutT]:
        """
        Run the processor and return its output packet.

        If validation produces more than one input argument, the parallel path
        is taken. That path does not forward ``chat_inputs``, and it always
        runs forgetfully, ignoring ``forgetful``. Otherwise a single run
        processes the sole input argument, or ``None`` if there is none.
        """
        call_id = self._generate_call_id(call_id)

        val_in_args = self._validate_inputs(
            call_id=call_id,
            chat_inputs=chat_inputs,
            in_packet=in_packet,
            in_args=in_args,
        )

        if val_in_args and len(val_in_args) > 1:
            return await self._run_parallel(
                in_args=val_in_args, call_id=call_id, ctx=ctx
            )

        return await self._run_single(
            chat_inputs=chat_inputs,
            in_args=val_in_args[0] if val_in_args else None,
            forgetful=forgetful,
            call_id=call_id,
            ctx=ctx,
        )

    @with_retry_stream
    async def _run_single_stream(
        self,
        chat_inputs: Any | None = None,
        *,
        in_args: InT | None = None,
        forgetful: bool = False,
        call_id: str,
        ctx: RunContext[CtxT] | None = None,
    ) -> AsyncIterator[Event[Any]]:
        """
        Streaming counterpart of ``_run_single``.

        Re-emits every event from ``_process_stream``, then emits a final
        ``ProcPacketOutputEvent`` carrying the output packet.

        Raises:
            RuntimeError: if ``_process_stream`` emitted no
                ``ProcPayloadOutputEvent``.
        """
        memory = self.memory.model_copy(deep=True) if forgetful else self.memory

        output: OutT | None = None
        # Track "a payload was emitted" separately from the payload value:
        # ``None`` can be a legitimate payload. The default _process_stream
        # emits ``in_args`` as-is, which may be None, and the non-streaming
        # path accepts that.
        got_output = False
        async for event in self._process_stream(
            chat_inputs=chat_inputs,
            in_args=in_args,
            memory=memory,
            call_id=call_id,
            ctx=ctx,
        ):
            if isinstance(event, ProcPayloadOutputEvent):
                output = event.data
                got_output = True
            yield event

        # Explicit check instead of ``assert``, which is stripped under ``-O``.
        if not got_output:
            raise RuntimeError(
                f"Processor {self.name} emitted no payload output "
                f"[call_id={call_id}]"
            )

        val_output = self._validate_output(cast("OutT", output), call_id=call_id)

        recipients = self._select_recipients(output=val_output, ctx=ctx)
        self._validate_recipients(recipients, call_id=call_id)

        out_packet = Packet[OutT](
            payloads=[val_output], sender=self.name, recipients=recipients
        )

        yield ProcPacketOutputEvent(
            data=out_packet, proc_name=self.name, call_id=call_id
        )

    async def _run_parallel_stream(
        self,
        in_args: list[InT],
        call_id: str,
        ctx: RunContext[CtxT] | None = None,
    ) -> AsyncIterator[Event[Any]]:
        """
        Streaming counterpart of ``_run_parallel``.

        Intermediate events from all sub-runs are forwarded as they arrive.
        The sub-runs' own ``ProcPacketOutputEvent``s are withheld and merged,
        in input order, into one final ``ProcPacketOutputEvent``.
        """
        streams = [
            self._run_single_stream(
                in_args=inp, forgetful=True, call_id=f"{call_id}/{idx}", ctx=ctx
            )
            for idx, inp in enumerate(in_args)
        ]

        out_packets_map: dict[int, Packet[OutT]] = {}
        async for idx, event in stream_concurrent(streams):
            if isinstance(event, ProcPacketOutputEvent):
                out_packets_map[idx] = event.data
            else:
                yield event

        # Packets arrive in completion order; restore input order once and use
        # that single ordering for validation, payloads and recipients.
        ordered_packets = [packet for _, packet in sorted(out_packets_map.items())]

        self._validate_parallel_recipients(ordered_packets, call_id=call_id)

        out_packet = Packet(
            payloads=[packet.payloads[0] for packet in ordered_packets],
            sender=self.name,
            recipients=ordered_packets[0].recipients,
        )

        yield ProcPacketOutputEvent(
            data=out_packet, proc_name=self.name, call_id=call_id
        )

    async def run_stream(
        self,
        chat_inputs: Any | None = None,
        *,
        in_packet: Packet[InT] | None = None,
        in_args: InT | list[InT] | None = None,
        forgetful: bool = False,
        call_id: str | None = None,
        ctx: RunContext[CtxT] | None = None,
    ) -> AsyncIterator[Event[Any]]:
        """
        Streaming counterpart of ``run``; the same path selection applies.

        The last event emitted is a ``ProcPacketOutputEvent`` with the
        processor's output packet.
        """
        call_id = self._generate_call_id(call_id)

        val_in_args = self._validate_inputs(
            call_id=call_id,
            chat_inputs=chat_inputs,
            in_packet=in_packet,
            in_args=in_args,
        )

        if val_in_args and len(val_in_args) > 1:
            stream = self._run_parallel_stream(
                in_args=val_in_args, call_id=call_id, ctx=ctx
            )
        else:
            stream = self._run_single_stream(
                chat_inputs=chat_inputs,
                in_args=val_in_args[0] if val_in_args else None,
                forgetful=forgetful,
                call_id=call_id,
                ctx=ctx,
            )
        async for event in stream:
            yield event
@@ -0,0 +1,161 @@
1
+ import logging
2
+ from collections.abc import AsyncIterator
3
+ from typing import Any, ClassVar, Generic, cast
4
+
5
+ from ..memory import MemT
6
+ from ..packet import Packet
7
+ from ..run_context import CtxT, RunContext
8
+ from ..typing.events import Event, ProcPacketOutputEvent, ProcPayloadOutputEvent
9
+ from ..typing.io import InT, OutT, ProcName
10
+ from .base_processor import BaseProcessor, with_retry, with_retry_stream
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
class Processor(BaseProcessor[InT, OutT, MemT, CtxT], Generic[InT, OutT, MemT, CtxT]):
    """
    Processor that maps a list of inputs to a list of outputs in one run.

    Every output is validated and routed individually. If all outputs select
    the same recipient set, the resulting packet carries that single recipient
    list; otherwise it carries a per-output ``{index: recipients}`` mapping.
    """

    _generic_arg_to_instance_attr_map: ClassVar[dict[int, str]] = {
        0: "_in_type",
        1: "_out_type",
    }

    async def _process(
        self,
        chat_inputs: Any | None = None,
        *,
        in_args: list[InT] | None = None,
        memory: MemT,
        call_id: str,
        ctx: RunContext[CtxT] | None = None,
    ) -> list[OutT]:
        """
        Turn the input list into a list of outputs.

        The default implementation passes ``in_args`` through unchanged.
        Missing input (``None``) becomes an empty list, so callers can always
        iterate the result. Subclasses override this with real logic.
        """
        return cast("list[OutT]", in_args if in_args is not None else [])

    async def _process_stream(
        self,
        chat_inputs: Any | None = None,
        *,
        in_args: list[InT] | None = None,
        memory: MemT,
        call_id: str,
        ctx: RunContext[CtxT] | None = None,
    ) -> AsyncIterator[Event[Any]]:
        """
        Streaming counterpart of ``_process``.

        The default implementation awaits ``_process`` and emits one
        ``ProcPayloadOutputEvent`` per output.
        """
        outputs = await self._process(
            chat_inputs=chat_inputs,
            in_args=in_args,
            memory=memory,
            call_id=call_id,
            ctx=ctx,
        )
        for output in outputs:
            yield ProcPayloadOutputEvent(
                data=output, proc_name=self.name, call_id=call_id
            )

    def _preprocess(
        self,
        chat_inputs: Any | None = None,
        *,
        in_packet: Packet[InT] | None = None,
        in_args: InT | list[InT] | None = None,
        forgetful: bool = False,
        call_id: str | None = None,
        ctx: RunContext[CtxT] | None = None,
    ) -> tuple[list[InT] | None, MemT, str]:
        """
        Resolve the call id, validate inputs and pick the memory to use.

        Returns:
            ``(validated_in_args, memory, call_id)``. ``memory`` is a deep copy
            of ``self.memory`` when ``forgetful`` is true.
        """
        call_id = self._generate_call_id(call_id)

        val_in_args = self._validate_inputs(
            call_id=call_id,
            chat_inputs=chat_inputs,
            in_packet=in_packet,
            in_args=in_args,
        )

        memory = self.memory.model_copy(deep=True) if forgetful else self.memory

        return val_in_args, memory, call_id

    def _postprocess(
        self, outputs: list[OutT], call_id: str, ctx: RunContext[CtxT] | None = None
    ) -> Packet[OutT]:
        """
        Validate and route each output, then wrap all of them in one packet.

        Recipients are compared as sets, with ``None`` treated as empty. If
        every output agrees, the first output's recipients are used for the
        whole packet. Otherwise a mapping from output index to recipients is
        attached. An empty ``outputs`` list yields a packet with no payloads
        and ``recipients=None``.
        """
        payloads: list[OutT] = []
        routing: dict[int, list[ProcName] | None] = {}
        for idx, output in enumerate(outputs):
            val_output = self._validate_output(output, call_id=call_id)
            output_recipients = self._select_recipients(output=val_output, ctx=ctx)
            self._validate_recipients(output_recipients, call_id=call_id)

            payloads.append(val_output)
            routing[idx] = output_recipients

        recipients: list[ProcName] | dict[int, list[ProcName] | None] | None
        if not routing:
            # Nothing was produced, so there is nothing to route; without this
            # guard ``routing[0]`` below would raise KeyError.
            # NOTE(review): confirm that the packet pool treats
            # ``recipients=None`` as "no recipients".
            recipients = None
        else:
            recipient_sets = [set(r or []) for r in routing.values()]
            if all(rs == recipient_sets[0] for rs in recipient_sets):
                recipients = routing[0]
            else:
                recipients = routing

        return Packet(payloads=payloads, sender=self.name, recipients=recipients)  # type: ignore[return-value]

    @with_retry
    async def run(
        self,
        chat_inputs: Any | None = None,
        *,
        in_packet: Packet[InT] | None = None,
        in_args: InT | list[InT] | None = None,
        forgetful: bool = False,
        call_id: str | None = None,
        ctx: RunContext[CtxT] | None = None,
    ) -> Packet[OutT]:
        """Validate inputs, process them in one call and return the output packet."""
        val_in_args, memory, call_id = self._preprocess(
            chat_inputs=chat_inputs,
            in_packet=in_packet,
            in_args=in_args,
            forgetful=forgetful,
            call_id=call_id,
            ctx=ctx,
        )
        outputs = await self._process(
            chat_inputs=chat_inputs,
            in_args=val_in_args,
            memory=memory,
            call_id=call_id,
            ctx=ctx,
        )

        return self._postprocess(outputs=outputs, call_id=call_id, ctx=ctx)

    @with_retry_stream
    async def run_stream(
        self,
        chat_inputs: Any | None = None,
        *,
        in_packet: Packet[InT] | None = None,
        in_args: InT | list[InT] | None = None,
        forgetful: bool = False,
        call_id: str | None = None,
        ctx: RunContext[CtxT] | None = None,
    ) -> AsyncIterator[Event[Any]]:
        """
        Streaming counterpart of ``run``.

        Re-emits every event from ``_process_stream``, collects the payloads
        of ``ProcPayloadOutputEvent``s, and finally emits a
        ``ProcPacketOutputEvent`` with the assembled packet.
        """
        val_in_args, memory, call_id = self._preprocess(
            chat_inputs=chat_inputs,
            in_packet=in_packet,
            in_args=in_args,
            forgetful=forgetful,
            call_id=call_id,
            ctx=ctx,
        )
        outputs: list[OutT] = []
        async for event in self._process_stream(
            chat_inputs=chat_inputs,
            in_args=val_in_args,
            memory=memory,
            call_id=call_id,
            ctx=ctx,
        ):
            if isinstance(event, ProcPayloadOutputEvent):
                outputs.append(event.data)
            yield event

        out_packet = self._postprocess(outputs=outputs, call_id=call_id, ctx=ctx)

        yield ProcPacketOutputEvent(
            data=out_packet, proc_name=self.name, call_id=call_id
        )
grasp_agents/runner.py CHANGED
@@ -1,3 +1,4 @@
1
+ import logging
1
2
  from collections.abc import AsyncIterator, Sequence
2
3
  from functools import partial
3
4
  from typing import Any, Generic
@@ -5,17 +6,19 @@ from typing import Any, Generic
5
6
  from .errors import RunnerError
6
7
  from .packet import Packet, StartPacket
7
8
  from .packet_pool import END_PROC_NAME, PacketPool
8
- from .processor import Processor
9
+ from .processors.base_processor import BaseProcessor
9
10
  from .run_context import CtxT, RunContext
10
11
  from .typing.events import Event, ProcPacketOutputEvent, RunResultEvent
11
12
  from .typing.io import OutT
12
13
 
14
+ logger = logging.getLogger(__name__)
15
+
13
16
 
14
17
  class Runner(Generic[OutT, CtxT]):
15
18
  def __init__(
16
19
  self,
17
- entry_proc: Processor[Any, Any, Any, CtxT],
18
- procs: Sequence[Processor[Any, Any, Any, CtxT]],
20
+ entry_proc: BaseProcessor[Any, Any, Any, CtxT],
21
+ procs: Sequence[BaseProcessor[Any, Any, Any, CtxT]],
19
22
  ctx: RunContext[CtxT] | None = None,
20
23
  ) -> None:
21
24
  if entry_proc not in procs:
@@ -31,7 +34,6 @@ class Runner(Generic[OutT, CtxT]):
31
34
  self._entry_proc = entry_proc
32
35
  self._procs = procs
33
36
  self._ctx = ctx or RunContext[CtxT]()
34
- self._packet_pool: PacketPool[CtxT] = PacketPool()
35
37
 
36
38
  @property
37
39
  def ctx(self) -> RunContext[CtxT]:
@@ -46,28 +48,41 @@ class Runner(Generic[OutT, CtxT]):
46
48
 
47
49
  async def _packet_handler(
48
50
  self,
49
- proc: Processor[Any, Any, Any, CtxT],
50
- pool: PacketPool[CtxT],
51
51
  packet: Packet[Any],
52
+ *,
53
+ proc: BaseProcessor[Any, Any, Any, CtxT],
54
+ pool: PacketPool,
52
55
  ctx: RunContext[CtxT],
53
56
  **run_kwargs: Any,
54
57
  ) -> None:
55
58
  _in_packet, _chat_inputs = self._unpack_packet(packet)
59
+
60
+ logger.info(f"\n[Running processor {proc.name}]\n")
61
+
56
62
  out_packet = await proc.run(
57
63
  chat_inputs=_chat_inputs, in_packet=_in_packet, ctx=ctx, **run_kwargs
58
64
  )
65
+
66
+ logger.info(
67
+ f"\n[Finished running processor {proc.name}]\n"
68
+ f"Posting output packet to recipients {out_packet.recipients}\n"
69
+ )
70
+
59
71
  await pool.post(out_packet)
60
72
 
61
73
  async def _packet_handler_stream(
62
74
  self,
63
- proc: Processor[Any, Any, Any, CtxT],
64
- pool: PacketPool[CtxT],
65
75
  packet: Packet[Any],
76
+ *,
77
+ proc: BaseProcessor[Any, Any, Any, CtxT],
78
+ pool: PacketPool,
66
79
  ctx: RunContext[CtxT],
67
80
  **run_kwargs: Any,
68
81
  ) -> None:
69
82
  _in_packet, _chat_inputs = self._unpack_packet(packet)
70
83
 
84
+ logger.info(f"\n[Running processor {proc.name}]\n")
85
+
71
86
  out_packet: Packet[Any] | None = None
72
87
  async for event in proc.run_stream(
73
88
  chat_inputs=_chat_inputs, in_packet=_in_packet, ctx=ctx, **run_kwargs
@@ -78,20 +93,25 @@ class Runner(Generic[OutT, CtxT]):
78
93
 
79
94
  assert out_packet is not None
80
95
 
96
+ logger.info(
97
+ f"\n[Finished running processor {proc.name}]\n"
98
+ f"Posting output packet to recipients {out_packet.recipients}\n"
99
+ )
100
+
81
101
  await pool.post(out_packet)
82
102
 
83
- async def run(
84
- self,
85
- chat_input: Any = "start",
86
- **run_args: Any,
87
- ) -> Packet[OutT]:
88
- async with PacketPool[CtxT]() as pool:
103
+ async def run(self, chat_input: Any = "start", **run_args: Any) -> Packet[OutT]:
104
+ async with PacketPool() as pool:
89
105
  for proc in self._procs:
90
106
  pool.register_packet_handler(
91
107
  proc_name=proc.name,
92
- handler=partial(self._packet_handler, proc, pool),
93
- ctx=self._ctx,
94
- **run_args,
108
+ handler=partial(
109
+ self._packet_handler,
110
+ proc=proc,
111
+ pool=pool,
112
+ ctx=self._ctx,
113
+ **run_args,
114
+ ),
95
115
  )
96
116
  await pool.post(
97
117
  StartPacket[Any](
@@ -101,17 +121,19 @@ class Runner(Generic[OutT, CtxT]):
101
121
  return await pool.final_result()
102
122
 
103
123
  async def run_stream(
104
- self,
105
- chat_input: Any = "start",
106
- **run_args: Any,
124
+ self, chat_input: Any = "start", **run_args: Any
107
125
  ) -> AsyncIterator[Event[Any]]:
108
- async with PacketPool[CtxT]() as pool:
126
+ async with PacketPool() as pool:
109
127
  for proc in self._procs:
110
128
  pool.register_packet_handler(
111
129
  proc_name=proc.name,
112
- handler=partial(self._packet_handler_stream, proc, pool),
113
- ctx=self._ctx,
114
- **run_args,
130
+ handler=partial(
131
+ self._packet_handler_stream,
132
+ proc=proc,
133
+ pool=pool,
134
+ ctx=self._ctx,
135
+ **run_args,
136
+ ),
115
137
  )
116
138
  await pool.post(
117
139
  StartPacket[Any](