grasp_agents 0.5.2__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
grasp_agents/printer.py CHANGED
@@ -219,12 +219,12 @@ async def print_event_stream(
219
219
  ) -> None:
220
220
  color = _get_color(event, Role.ASSISTANT)
221
221
 
222
- if isinstance(event, ProcPacketOutputEvent):
223
- src = "processor"
224
- elif isinstance(event, WorkflowResultEvent):
222
+ if isinstance(event, WorkflowResultEvent):
225
223
  src = "workflow"
226
- else:
224
+ elif isinstance(event, RunResultEvent):
227
225
  src = "run"
226
+ else:
227
+ src = "processor"
228
228
 
229
229
  text = f"\n<{event.proc_name}> [{event.call_id}]\n"
230
230
 
@@ -232,6 +232,10 @@ async def print_event_stream(
232
232
  text += f"<{src} output>\n"
233
233
  for p in event.data.payloads:
234
234
  if isinstance(p, BaseModel):
235
+ for field_info in type(p).model_fields.values():
236
+ if field_info.exclude:
237
+ field_info.exclude = False
238
+ type(p).model_rebuild(force=True)
235
239
  p_str = p.model_dump_json(indent=2)
236
240
  else:
237
241
  try:
grasp_agents/processor.py CHANGED
@@ -2,13 +2,18 @@ import asyncio
2
2
  import logging
3
3
  from abc import ABC
4
4
  from collections.abc import AsyncIterator, Sequence
5
- from typing import Any, ClassVar, Generic, cast, final
5
+ from typing import Any, ClassVar, Generic, Protocol, TypeVar, cast, final
6
6
  from uuid import uuid4
7
7
 
8
8
  from pydantic import BaseModel, TypeAdapter
9
9
  from pydantic import ValidationError as PydanticValidationError
10
10
 
11
- from .errors import ProcInputValidationError, ProcOutputValidationError
11
+ from .errors import (
12
+ PacketRoutingError,
13
+ ProcInputValidationError,
14
+ ProcOutputValidationError,
15
+ ProcRunError,
16
+ )
12
17
  from .generics_utils import AutoInstanceAttributesMixin
13
18
  from .memory import DummyMemory, MemT
14
19
  from .packet import Packet
@@ -20,35 +25,51 @@ from .typing.events import (
20
25
  ProcStreamingErrorData,
21
26
  ProcStreamingErrorEvent,
22
27
  )
23
- from .typing.io import InT, OutT_co, ProcName
28
+ from .typing.io import InT, OutT, ProcName
24
29
  from .typing.tool import BaseTool
25
30
  from .utils import stream_concurrent
26
31
 
27
32
  logger = logging.getLogger(__name__)
28
33
 
34
+ _OutT_contra = TypeVar("_OutT_contra", contravariant=True)
35
+
36
+
37
+ class SelectRecipientsHandler(Protocol[_OutT_contra, CtxT]):
38
+ def __call__(
39
+ self, output: _OutT_contra, ctx: RunContext[CtxT] | None
40
+ ) -> list[ProcName] | None: ...
41
+
29
42
 
30
- class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT_co, MemT, CtxT]):
43
+ class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT, MemT, CtxT]):
31
44
  _generic_arg_to_instance_attr_map: ClassVar[dict[int, str]] = {
32
45
  0: "_in_type",
33
46
  1: "_out_type",
34
47
  }
35
48
 
36
- def __init__(self, name: ProcName, max_retries: int = 0, **kwargs: Any) -> None:
49
+ def __init__(
50
+ self,
51
+ name: ProcName,
52
+ max_retries: int = 0,
53
+ recipients: list[ProcName] | None = None,
54
+ **kwargs: Any,
55
+ ) -> None:
37
56
  self._in_type: type[InT]
38
- self._out_type: type[OutT_co]
57
+ self._out_type: type[OutT]
39
58
 
40
59
  super().__init__()
41
60
 
42
61
  self._name: ProcName = name
43
62
  self._memory: MemT = cast("MemT", DummyMemory())
44
63
  self._max_retries: int = max_retries
64
+ self.recipients = recipients
65
+ self.select_recipients_impl: SelectRecipientsHandler[OutT, CtxT] | None = None
45
66
 
46
67
  @property
47
68
  def in_type(self) -> type[InT]:
48
69
  return self._in_type
49
70
 
50
71
  @property
51
- def out_type(self) -> type[OutT_co]:
72
+ def out_type(self) -> type[OutT]:
52
73
  return self._out_type
53
74
 
54
75
  @property
@@ -68,65 +89,108 @@ class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT_co, MemT, Ct
68
89
  return str(uuid4())[:6] + "_" + self.name
69
90
  return call_id
70
91
 
71
- def _validate_and_resolve_single_input(
92
+ def _validate_inputs(
72
93
  self,
94
+ call_id: str,
73
95
  chat_inputs: Any | None = None,
74
96
  in_packet: Packet[InT] | None = None,
75
- in_args: InT | None = None,
76
- ) -> InT | None:
77
- multiple_inputs_err_message = (
97
+ in_args: InT | Sequence[InT] | None = None,
98
+ ) -> Sequence[InT] | None:
99
+ mult_inputs_err_message = (
78
100
  "Only one of chat_inputs, in_args, or in_message must be provided."
79
101
  )
102
+ err_kwargs = {"proc_name": self.name, "call_id": call_id}
103
+
80
104
  if chat_inputs is not None and in_args is not None:
81
- raise ProcInputValidationError(multiple_inputs_err_message)
105
+ raise ProcInputValidationError(
106
+ message=mult_inputs_err_message, **err_kwargs
107
+ )
82
108
  if chat_inputs is not None and in_packet is not None:
83
- raise ProcInputValidationError(multiple_inputs_err_message)
109
+ raise ProcInputValidationError(
110
+ message=mult_inputs_err_message, **err_kwargs
111
+ )
84
112
  if in_args is not None and in_packet is not None:
85
- raise ProcInputValidationError(multiple_inputs_err_message)
86
-
87
- if in_packet is not None:
88
- if len(in_packet.payloads) != 1:
89
- raise ProcInputValidationError(
90
- "Single input runs require exactly one payload in in_packet."
91
- )
92
- return in_packet.payloads[0]
93
-
94
- return in_args
113
+ raise ProcInputValidationError(
114
+ message=mult_inputs_err_message, **err_kwargs
115
+ )
95
116
 
96
- def _validate_and_resolve_parallel_inputs(
97
- self,
98
- chat_inputs: Any | None,
99
- in_packet: Packet[InT] | None,
100
- in_args: Sequence[InT] | None,
101
- ) -> Sequence[InT]:
102
- if chat_inputs is not None:
117
+ if in_packet is not None and not in_packet.payloads:
118
+ raise ProcInputValidationError(
119
+ message="in_packet must contain at least one payload.", **err_kwargs
120
+ )
121
+ if in_args is not None and not in_args:
103
122
  raise ProcInputValidationError(
104
- "chat_inputs are not supported in parallel runs. "
105
- "Use in_packet or in_args."
123
+ message="in_args must contain at least one argument.", **err_kwargs
106
124
  )
107
- if in_packet is not None:
108
- if not in_packet.payloads:
125
+
126
+ if chat_inputs is not None:
127
+ return None
128
+
129
+ resolved_args: Sequence[InT]
130
+
131
+ if isinstance(in_args, Sequence):
132
+ _in_args = cast("Sequence[Any]", in_args)
133
+ if all(isinstance(x, self.in_type) for x in _in_args):
134
+ resolved_args = cast("Sequence[InT]", _in_args)
135
+ elif isinstance(_in_args, self.in_type):
136
+ resolved_args = cast("Sequence[InT]", [_in_args])
137
+ else:
109
138
  raise ProcInputValidationError(
110
- "Parallel runs require at least one input payload in in_packet."
139
+ message=f"in_args are neither of type {self.in_type} "
140
+ f"nor a sequence of {self.in_type}.",
141
+ **err_kwargs,
111
142
  )
112
- return in_packet.payloads
113
- if in_args is not None:
114
- return in_args
115
- raise ProcInputValidationError(
116
- "Parallel runs require either in_packet or in_args to be provided."
117
- )
118
143
 
119
- def _validate_outputs(self, out_payloads: Sequence[OutT_co]) -> Sequence[OutT_co]:
144
+ elif in_args is not None:
145
+ resolved_args = cast("Sequence[InT]", [in_args])
146
+
147
+ else:
148
+ assert in_packet is not None
149
+ resolved_args = in_packet.payloads
150
+
120
151
  try:
121
- return [
122
- TypeAdapter(self._out_type).validate_python(payload)
123
- for payload in out_payloads
124
- ]
152
+ for args in resolved_args:
153
+ TypeAdapter(self._in_type).validate_python(args)
154
+ except PydanticValidationError as err:
155
+ raise ProcInputValidationError(message=str(err), **err_kwargs) from err
156
+
157
+ return resolved_args
158
+
159
+ def _validate_output(self, out_payload: OutT, call_id: str) -> OutT:
160
+ if out_payload is None:
161
+ return out_payload
162
+ try:
163
+ return TypeAdapter(self._out_type).validate_python(out_payload)
125
164
  except PydanticValidationError as err:
126
165
  raise ProcOutputValidationError(
127
- f"Output validation failed for processor {self.name}:\n{err}"
166
+ schema=self._out_type, proc_name=self.name, call_id=call_id
128
167
  ) from err
129
168
 
169
+ def _validate_recipients(
170
+ self, recipients: Sequence[ProcName] | None, call_id: str
171
+ ) -> None:
172
+ for r in recipients or []:
173
+ if r not in (self.recipients or []):
174
+ raise PacketRoutingError(
175
+ proc_name=self.name,
176
+ call_id=call_id,
177
+ selected_recipient=r,
178
+ allowed_recipients=cast("list[str]", self.recipients),
179
+ )
180
+
181
+ def _validate_par_recipients(
182
+ self, out_packets: Sequence[Packet[OutT]], call_id: str
183
+ ) -> None:
184
+ recipient_sets = [set(p.recipients or []) for p in out_packets]
185
+ same_recipients = all(rs == recipient_sets[0] for rs in recipient_sets)
186
+ if not same_recipients:
187
+ raise PacketRoutingError(
188
+ proc_name=self.name,
189
+ call_id=call_id,
190
+ message="Parallel runs must return the same recipients "
191
+ f"[proc_name={self.name}; call_id={call_id}]",
192
+ )
193
+
130
194
  async def _process(
131
195
  self,
132
196
  chat_inputs: Any | None = None,
@@ -135,13 +199,8 @@ class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT_co, MemT, Ct
135
199
  memory: MemT,
136
200
  call_id: str,
137
201
  ctx: RunContext[CtxT] | None = None,
138
- ) -> Sequence[OutT_co]:
139
- if in_args is None:
140
- raise ProcInputValidationError(
141
- "Default implementation of _process requires in_args"
142
- )
143
-
144
- return cast("Sequence[OutT_co]", in_args)
202
+ ) -> OutT:
203
+ return cast("OutT", in_args)
145
204
 
146
205
  async def _process_stream(
147
206
  self,
@@ -152,99 +211,86 @@ class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT_co, MemT, Ct
152
211
  call_id: str,
153
212
  ctx: RunContext[CtxT] | None = None,
154
213
  ) -> AsyncIterator[Event[Any]]:
155
- if in_args is None:
156
- raise ProcInputValidationError(
157
- "Default implementation of _process_stream requires in_args"
158
- )
159
- outputs = cast("Sequence[OutT_co]", in_args)
160
- for out in outputs:
161
- yield ProcPayloadOutputEvent(data=out, proc_name=self.name, call_id=call_id)
214
+ output = cast("OutT", in_args)
215
+ yield ProcPayloadOutputEvent(data=output, proc_name=self.name, call_id=call_id)
162
216
 
163
217
  async def _run_single_once(
164
218
  self,
165
219
  chat_inputs: Any | None = None,
166
220
  *,
167
- in_packet: Packet[InT] | None = None,
168
221
  in_args: InT | None = None,
169
222
  forgetful: bool = False,
170
223
  call_id: str,
171
224
  ctx: RunContext[CtxT] | None = None,
172
- ) -> Packet[OutT_co]:
173
- resolved_in_args = self._validate_and_resolve_single_input(
174
- chat_inputs=chat_inputs, in_packet=in_packet, in_args=in_args
175
- )
225
+ ) -> Packet[OutT]:
176
226
  _memory = self.memory.model_copy(deep=True) if forgetful else self.memory
177
227
 
178
- outputs = await self._process(
228
+ output = await self._process(
179
229
  chat_inputs=chat_inputs,
180
- in_args=resolved_in_args,
230
+ in_args=in_args,
181
231
  memory=_memory,
182
232
  call_id=call_id,
183
233
  ctx=ctx,
184
234
  )
185
- val_outputs = self._validate_outputs(outputs)
235
+ val_output = self._validate_output(output, call_id=call_id)
186
236
 
187
- return Packet(payloads=val_outputs, sender=self.name)
237
+ recipients = self._select_recipients(output=val_output, ctx=ctx)
238
+ self._validate_recipients(recipients, call_id=call_id)
239
+
240
+ return Packet(payloads=[val_output], sender=self.name, recipients=recipients)
188
241
 
189
242
  async def _run_single(
190
243
  self,
191
244
  chat_inputs: Any | None = None,
192
245
  *,
193
- in_packet: Packet[InT] | None = None,
194
246
  in_args: InT | None = None,
195
247
  forgetful: bool = False,
196
248
  call_id: str,
197
249
  ctx: RunContext[CtxT] | None = None,
198
- ) -> Packet[OutT_co] | None:
250
+ ) -> Packet[OutT]:
199
251
  n_attempt = 0
200
252
  while n_attempt <= self.max_retries:
201
253
  try:
202
254
  return await self._run_single_once(
203
255
  chat_inputs=chat_inputs,
204
- in_packet=in_packet,
205
256
  in_args=in_args,
206
257
  forgetful=forgetful,
207
258
  call_id=call_id,
208
259
  ctx=ctx,
209
260
  )
210
261
  except Exception as err:
262
+ err_message = (
263
+ f"\nProcessor run failed [proc_name={self.name}; call_id={call_id}]"
264
+ )
211
265
  n_attempt += 1
212
266
  if n_attempt > self.max_retries:
213
267
  if n_attempt == 1:
214
- logger.warning(f"\nProcessor run failed:\n{err}")
268
+ logger.warning(f"{err_message}:\n{err}")
215
269
  if n_attempt > 1:
216
- logger.warning(f"\nProcessor run failed after retrying:\n{err}")
217
- return None
218
- logger.warning(
219
- f"\nProcessor run failed (retry attempt {n_attempt}):\n{err}"
220
- )
270
+ logger.warning(f"{err_message} after retrying:\n{err}")
271
+ raise ProcRunError(proc_name=self.name, call_id=call_id) from err
272
+
273
+ logger.warning(f"{err_message} (retry attempt {n_attempt}):\n{err}")
274
+
275
+ raise ProcRunError(proc_name=self.name, call_id=call_id)
221
276
 
222
277
  async def _run_par(
223
- self,
224
- chat_inputs: Any | None = None,
225
- *,
226
- in_packet: Packet[InT] | None = None,
227
- in_args: Sequence[InT] | None = None,
228
- call_id: str,
229
- ctx: RunContext[CtxT] | None = None,
230
- ) -> Packet[OutT_co]:
231
- par_inputs = self._validate_and_resolve_parallel_inputs(
232
- chat_inputs=chat_inputs, in_packet=in_packet, in_args=in_args
233
- )
278
+ self, in_args: Sequence[InT], call_id: str, ctx: RunContext[CtxT] | None = None
279
+ ) -> Packet[OutT]:
234
280
  tasks = [
235
281
  self._run_single(
236
282
  in_args=inp, forgetful=True, call_id=f"{call_id}/{idx}", ctx=ctx
237
283
  )
238
- for idx, inp in enumerate(par_inputs)
284
+ for idx, inp in enumerate(in_args)
239
285
  ]
240
286
  out_packets = await asyncio.gather(*tasks)
241
287
 
242
- return Packet( # type: ignore[return]
243
- payloads=[
244
- (out_packet.payloads[0] if out_packet else None)
245
- for out_packet in out_packets
246
- ],
288
+ self._validate_par_recipients(out_packets, call_id=call_id)
289
+
290
+ return Packet(
291
+ payloads=[out_packet.payloads[0] for out_packet in out_packets],
247
292
  sender=self.name,
293
+ recipients=out_packets[0].recipients,
248
294
  )
249
295
 
250
296
  async def run(
@@ -256,23 +302,21 @@ class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT_co, MemT, Ct
256
302
  forgetful: bool = False,
257
303
  call_id: str | None = None,
258
304
  ctx: RunContext[CtxT] | None = None,
259
- ) -> Packet[OutT_co]:
305
+ ) -> Packet[OutT]:
260
306
  call_id = self._generate_call_id(call_id)
261
307
 
262
- if (in_args is not None and isinstance(in_args, Sequence)) or (
263
- in_packet is not None and len(in_packet.payloads) > 1
264
- ):
265
- return await self._run_par(
266
- chat_inputs=chat_inputs,
267
- in_packet=in_packet,
268
- in_args=cast("Sequence[InT] | None", in_args),
269
- call_id=call_id,
270
- ctx=ctx,
271
- )
272
- return await self._run_single( # type: ignore[return]
308
+ val_in_args = self._validate_inputs(
309
+ call_id=call_id,
273
310
  chat_inputs=chat_inputs,
274
311
  in_packet=in_packet,
275
- in_args=cast("InT | None", in_args),
312
+ in_args=in_args,
313
+ )
314
+
315
+ if val_in_args and len(val_in_args) > 1:
316
+ return await self._run_par(in_args=val_in_args, call_id=call_id, ctx=ctx)
317
+ return await self._run_single(
318
+ chat_inputs=chat_inputs,
319
+ in_args=val_in_args[0] if val_in_args else None,
276
320
  forgetful=forgetful,
277
321
  call_id=call_id,
278
322
  ctx=ctx,
@@ -282,31 +326,35 @@ class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT_co, MemT, Ct
282
326
  self,
283
327
  chat_inputs: Any | None = None,
284
328
  *,
285
- in_packet: Packet[InT] | None = None,
286
329
  in_args: InT | None = None,
287
330
  forgetful: bool = False,
288
331
  call_id: str,
289
332
  ctx: RunContext[CtxT] | None = None,
290
333
  ) -> AsyncIterator[Event[Any]]:
291
- resolved_in_args = self._validate_and_resolve_single_input(
292
- chat_inputs=chat_inputs, in_packet=in_packet, in_args=in_args
293
- )
294
334
  _memory = self.memory.model_copy(deep=True) if forgetful else self.memory
295
335
 
296
- outputs: list[OutT_co] = []
336
+ output: OutT | None = None
297
337
  async for event in self._process_stream(
298
338
  chat_inputs=chat_inputs,
299
- in_args=resolved_in_args,
339
+ in_args=in_args,
300
340
  memory=_memory,
301
341
  call_id=call_id,
302
342
  ctx=ctx,
303
343
  ):
304
344
  if isinstance(event, ProcPayloadOutputEvent):
305
- outputs.append(event.data)
345
+ output = event.data
306
346
  yield event
307
347
 
308
- val_outputs = self._validate_outputs(outputs)
309
- out_packet = Packet[OutT_co](payloads=val_outputs, sender=self.name)
348
+ assert output is not None
349
+
350
+ val_output = self._validate_output(output, call_id=call_id)
351
+
352
+ recipients = self._select_recipients(output=val_output, ctx=ctx)
353
+ self._validate_recipients(recipients, call_id=call_id)
354
+
355
+ out_packet = Packet[OutT](
356
+ payloads=[val_output], sender=self.name, recipients=recipients
357
+ )
310
358
 
311
359
  yield ProcPacketOutputEvent(
312
360
  data=out_packet, proc_name=self.name, call_id=call_id
@@ -316,7 +364,6 @@ class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT_co, MemT, Ct
316
364
  self,
317
365
  chat_inputs: Any | None = None,
318
366
  *,
319
- in_packet: Packet[InT] | None = None,
320
367
  in_args: InT | None = None,
321
368
  forgetful: bool = False,
322
369
  call_id: str,
@@ -327,7 +374,6 @@ class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT_co, MemT, Ct
327
374
  try:
328
375
  async for event in self._run_single_stream_once(
329
376
  chat_inputs=chat_inputs,
330
- in_packet=in_packet,
331
377
  in_args=in_args,
332
378
  forgetful=forgetful,
333
379
  call_id=call_id,
@@ -343,56 +389,48 @@ class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT_co, MemT, Ct
343
389
  data=err_data, proc_name=self.name, call_id=call_id
344
390
  )
345
391
 
392
+ err_message = (
393
+ "\nStreaming processor run failed "
394
+ f"[proc_name={self.name}; call_id={call_id}]"
395
+ )
396
+
346
397
  n_attempt += 1
347
398
  if n_attempt > self.max_retries:
348
399
  if n_attempt == 1:
349
- logger.warning(f"\nStreaming processor run failed:\n{err}")
400
+ logger.warning(f"{err_message}:\n{err}")
350
401
  if n_attempt > 1:
351
- logger.warning(
352
- f"\nStreaming processor run failed after retrying:\n{err}"
353
- )
354
- return
402
+ logger.warning(f"{err_message} after retrying:\n{err}")
403
+ raise ProcRunError(proc_name=self.name, call_id=call_id) from err
355
404
 
356
- logger.warning(
357
- "\nStreaming processor run failed "
358
- f"(retry attempt {n_attempt}):\n{err}"
359
- )
405
+ logger.warning(f"{err_message} (retry attempt {n_attempt}):\n{err}")
360
406
 
361
407
  async def _run_par_stream(
362
408
  self,
363
- chat_inputs: Any | None = None,
364
- *,
365
- in_packet: Packet[InT] | None = None,
366
- in_args: Sequence[InT] | None = None,
409
+ in_args: Sequence[InT],
367
410
  call_id: str,
368
411
  ctx: RunContext[CtxT] | None = None,
369
412
  ) -> AsyncIterator[Event[Any]]:
370
- par_inputs = self._validate_and_resolve_parallel_inputs(
371
- chat_inputs=chat_inputs, in_packet=in_packet, in_args=in_args
372
- )
373
413
  streams = [
374
414
  self._run_single_stream(
375
415
  in_args=inp, forgetful=True, call_id=f"{call_id}/{idx}", ctx=ctx
376
416
  )
377
- for idx, inp in enumerate(par_inputs)
417
+ for idx, inp in enumerate(in_args)
378
418
  ]
379
419
 
380
- out_packets_map: dict[int, Packet[OutT_co] | None] = dict.fromkeys(
381
- range(len(streams)), None
382
- )
383
-
420
+ out_packets_map: dict[int, Packet[OutT]] = {}
384
421
  async for idx, event in stream_concurrent(streams):
385
422
  if isinstance(event, ProcPacketOutputEvent):
386
423
  out_packets_map[idx] = event.data
387
424
  else:
388
425
  yield event
389
426
 
390
- out_packet = Packet( # type: ignore[return]
427
+ out_packet = Packet(
391
428
  payloads=[
392
- (out_packet.payloads[0] if out_packet else None)
393
- for out_packet in out_packets_map.values()
429
+ out_packet.payloads[0]
430
+ for _, out_packet in sorted(out_packets_map.items())
394
431
  ],
395
432
  sender=self.name,
433
+ recipients=out_packets_map[0].recipients,
396
434
  )
397
435
 
398
436
  yield ProcPacketOutputEvent(
@@ -411,23 +449,19 @@ class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT_co, MemT, Ct
411
449
  ) -> AsyncIterator[Event[Any]]:
412
450
  call_id = self._generate_call_id(call_id)
413
451
 
414
- # yield ProcStartEvent(proc_name=self.name, call_id=call_id, data=None)
452
+ val_in_args = self._validate_inputs(
453
+ call_id=call_id,
454
+ chat_inputs=chat_inputs,
455
+ in_packet=in_packet,
456
+ in_args=in_args,
457
+ )
415
458
 
416
- if (in_args is not None and isinstance(in_args, Sequence)) or (
417
- in_packet is not None and len(in_packet.payloads) > 1
418
- ):
419
- stream = self._run_par_stream(
420
- chat_inputs=chat_inputs,
421
- in_packet=in_packet,
422
- in_args=cast("Sequence[InT] | None", in_args),
423
- call_id=call_id,
424
- ctx=ctx,
425
- )
459
+ if val_in_args and len(val_in_args) > 1:
460
+ stream = self._run_par_stream(in_args=val_in_args, call_id=call_id, ctx=ctx)
426
461
  else:
427
462
  stream = self._run_single_stream(
428
463
  chat_inputs=chat_inputs,
429
- in_packet=in_packet,
430
- in_args=cast("InT | None", in_args),
464
+ in_args=val_in_args[0] if val_in_args else None,
431
465
  forgetful=forgetful,
432
466
  call_id=call_id,
433
467
  ctx=ctx,
@@ -435,12 +469,25 @@ class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT_co, MemT, Ct
435
469
  async for event in stream:
436
470
  yield event
437
471
 
438
- # yield ProcFinishEvent(proc_name=self.name, call_id=call_id, data=None)
472
+ def _select_recipients(
473
+ self, output: OutT, ctx: RunContext[CtxT] | None = None
474
+ ) -> list[ProcName] | None:
475
+ if self.select_recipients_impl:
476
+ return self.select_recipients_impl(output=output, ctx=ctx)
477
+
478
+ return self.recipients
479
+
480
+ def select_recipients(
481
+ self, func: SelectRecipientsHandler[OutT, CtxT]
482
+ ) -> SelectRecipientsHandler[OutT, CtxT]:
483
+ self.select_recipients_impl = func
484
+
485
+ return func
439
486
 
440
487
  @final
441
488
  def as_tool(
442
489
  self, tool_name: str, tool_description: str
443
- ) -> BaseTool[InT, OutT_co, Any]: # type: ignore[override]
490
+ ) -> BaseTool[InT, OutT, Any]: # type: ignore[override]
444
491
  # TODO: stream tools
445
492
  processor_instance = self
446
493
  in_type = processor_instance.in_type
@@ -455,13 +502,11 @@ class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT_co, MemT, Ct
455
502
  name: str = tool_name
456
503
  description: str = tool_description
457
504
 
458
- async def run(
459
- self, inp: InT, ctx: RunContext[CtxT] | None = None
460
- ) -> OutT_co:
505
+ async def run(self, inp: InT, ctx: RunContext[CtxT] | None = None) -> OutT:
461
506
  result = await processor_instance.run(
462
507
  in_args=in_type.model_validate(inp), forgetful=True, ctx=ctx
463
508
  )
464
509
 
465
510
  return result.payloads[0]
466
511
 
467
- return ProcessorTool() # type: ignore[return-value]
512
+ return ProcessorTool()