grasp_agents 0.5.3__py3-none-any.whl → 0.5.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
grasp_agents/processor.py CHANGED
@@ -2,13 +2,18 @@ import asyncio
2
2
  import logging
3
3
  from abc import ABC
4
4
  from collections.abc import AsyncIterator, Sequence
5
- from typing import Any, ClassVar, Generic, cast, final
5
+ from typing import Any, ClassVar, Generic, Protocol, TypeVar, cast, final
6
6
  from uuid import uuid4
7
7
 
8
8
  from pydantic import BaseModel, TypeAdapter
9
9
  from pydantic import ValidationError as PydanticValidationError
10
10
 
11
- from .errors import ProcInputValidationError, ProcOutputValidationError
11
+ from .errors import (
12
+ PacketRoutingError,
13
+ ProcInputValidationError,
14
+ ProcOutputValidationError,
15
+ ProcRunError,
16
+ )
12
17
  from .generics_utils import AutoInstanceAttributesMixin
13
18
  from .memory import DummyMemory, MemT
14
19
  from .packet import Packet
@@ -20,22 +25,36 @@ from .typing.events import (
20
25
  ProcStreamingErrorData,
21
26
  ProcStreamingErrorEvent,
22
27
  )
23
- from .typing.io import InT, OutT_co, ProcName
28
+ from .typing.io import InT, OutT, ProcName
24
29
  from .typing.tool import BaseTool
25
30
  from .utils import stream_concurrent
26
31
 
27
32
  logger = logging.getLogger(__name__)
28
33
 
34
+ _OutT_contra = TypeVar("_OutT_contra", contravariant=True)
35
+
36
+
37
+ class RecipientSelector(Protocol[_OutT_contra, CtxT]):
38
+ def __call__(
39
+ self, output: _OutT_contra, ctx: RunContext[CtxT] | None
40
+ ) -> list[ProcName] | None: ...
41
+
29
42
 
30
- class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT_co, MemT, CtxT]):
43
+ class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT, MemT, CtxT]):
31
44
  _generic_arg_to_instance_attr_map: ClassVar[dict[int, str]] = {
32
45
  0: "_in_type",
33
46
  1: "_out_type",
34
47
  }
35
48
 
36
- def __init__(self, name: ProcName, max_retries: int = 0, **kwargs: Any) -> None:
49
+ def __init__(
50
+ self,
51
+ name: ProcName,
52
+ max_retries: int = 0,
53
+ recipients: list[ProcName] | None = None,
54
+ **kwargs: Any,
55
+ ) -> None:
37
56
  self._in_type: type[InT]
38
- self._out_type: type[OutT_co]
57
+ self._out_type: type[OutT]
39
58
 
40
59
  super().__init__()
41
60
 
@@ -43,12 +62,19 @@ class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT_co, MemT, Ct
43
62
  self._memory: MemT = cast("MemT", DummyMemory())
44
63
  self._max_retries: int = max_retries
45
64
 
65
+ self.recipients = recipients
66
+
67
+ self.recipient_selector: RecipientSelector[OutT, CtxT] | None
68
+ if not hasattr(type(self), "recipient_selector"):
69
+ # Set to None if not defined in the subclass
70
+ self.recipient_selector = None
71
+
46
72
  @property
47
73
  def in_type(self) -> type[InT]:
48
74
  return self._in_type
49
75
 
50
76
  @property
51
- def out_type(self) -> type[OutT_co]:
77
+ def out_type(self) -> type[OutT]:
52
78
  return self._out_type
53
79
 
54
80
  @property
@@ -68,65 +94,124 @@ class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT_co, MemT, Ct
68
94
  return str(uuid4())[:6] + "_" + self.name
69
95
  return call_id
70
96
 
71
- def _validate_and_resolve_single_input(
97
+ def _validate_inputs(
72
98
  self,
99
+ call_id: str,
73
100
  chat_inputs: Any | None = None,
74
101
  in_packet: Packet[InT] | None = None,
75
- in_args: InT | None = None,
76
- ) -> InT | None:
77
- multiple_inputs_err_message = (
102
+ in_args: InT | list[InT] | None = None,
103
+ ) -> list[InT] | None:
104
+ mult_inputs_err_message = (
78
105
  "Only one of chat_inputs, in_args, or in_message must be provided."
79
106
  )
107
+ err_kwargs = {"proc_name": self.name, "call_id": call_id}
108
+
80
109
  if chat_inputs is not None and in_args is not None:
81
- raise ProcInputValidationError(multiple_inputs_err_message)
110
+ raise ProcInputValidationError(
111
+ message=mult_inputs_err_message, **err_kwargs
112
+ )
82
113
  if chat_inputs is not None and in_packet is not None:
83
- raise ProcInputValidationError(multiple_inputs_err_message)
114
+ raise ProcInputValidationError(
115
+ message=mult_inputs_err_message, **err_kwargs
116
+ )
84
117
  if in_args is not None and in_packet is not None:
85
- raise ProcInputValidationError(multiple_inputs_err_message)
86
-
87
- if in_packet is not None:
88
- if len(in_packet.payloads) != 1:
89
- raise ProcInputValidationError(
90
- "Single input runs require exactly one payload in in_packet."
91
- )
92
- return in_packet.payloads[0]
93
-
94
- return in_args
118
+ raise ProcInputValidationError(
119
+ message=mult_inputs_err_message, **err_kwargs
120
+ )
95
121
 
96
- def _validate_and_resolve_parallel_inputs(
97
- self,
98
- chat_inputs: Any | None,
99
- in_packet: Packet[InT] | None,
100
- in_args: Sequence[InT] | None,
101
- ) -> Sequence[InT]:
102
- if chat_inputs is not None:
122
+ if in_packet is not None and not in_packet.payloads:
123
+ raise ProcInputValidationError(
124
+ message="in_packet must contain at least one payload.", **err_kwargs
125
+ )
126
+ if in_args is not None and not in_args:
103
127
  raise ProcInputValidationError(
104
- "chat_inputs are not supported in parallel runs. "
105
- "Use in_packet or in_args."
128
+ message="in_args must contain at least one argument.", **err_kwargs
106
129
  )
107
- if in_packet is not None:
108
- if not in_packet.payloads:
130
+
131
+ if chat_inputs is not None:
132
+ return None
133
+
134
+ resolved_args: list[InT]
135
+
136
+ if isinstance(in_args, list):
137
+ _in_args = cast("list[Any]", in_args)
138
+ if all(isinstance(x, self.in_type) for x in _in_args):
139
+ resolved_args = cast("list[InT]", _in_args)
140
+ elif isinstance(_in_args, self.in_type):
141
+ resolved_args = cast("list[InT]", [_in_args])
142
+ else:
109
143
  raise ProcInputValidationError(
110
- "Parallel runs require at least one input payload in in_packet."
144
+ message=f"in_args are neither of type {self.in_type} "
145
+ f"nor a sequence of {self.in_type}.",
146
+ **err_kwargs,
111
147
  )
112
- return in_packet.payloads
113
- if in_args is not None:
114
- return in_args
115
- raise ProcInputValidationError(
116
- "Parallel runs require either in_packet or in_args to be provided."
117
- )
118
148
 
119
- def _validate_outputs(self, out_payloads: Sequence[OutT_co]) -> Sequence[OutT_co]:
149
+ elif in_args is not None:
150
+ resolved_args = cast("list[InT]", [in_args])
151
+
152
+ else:
153
+ assert in_packet is not None
154
+ resolved_args = cast("list[InT]", in_packet.payloads)
155
+
156
+ try:
157
+ for args in resolved_args:
158
+ TypeAdapter(self._in_type).validate_python(args)
159
+ except PydanticValidationError as err:
160
+ raise ProcInputValidationError(message=str(err), **err_kwargs) from err
161
+
162
+ return resolved_args
163
+
164
+ def _validate_output(self, out_payload: OutT, call_id: str) -> OutT:
165
+ if out_payload is None:
166
+ return out_payload
120
167
  try:
121
- return [
122
- TypeAdapter(self._out_type).validate_python(payload)
123
- for payload in out_payloads
124
- ]
168
+ return TypeAdapter(self._out_type).validate_python(out_payload)
125
169
  except PydanticValidationError as err:
126
170
  raise ProcOutputValidationError(
127
- f"Output validation failed for processor {self.name}:\n{err}"
171
+ schema=self._out_type, proc_name=self.name, call_id=call_id
128
172
  ) from err
129
173
 
174
+ def _validate_recipients(
175
+ self, recipients: list[ProcName] | None, call_id: str
176
+ ) -> None:
177
+ for r in recipients or []:
178
+ if r not in (self.recipients or []):
179
+ raise PacketRoutingError(
180
+ proc_name=self.name,
181
+ call_id=call_id,
182
+ selected_recipient=r,
183
+ allowed_recipients=cast("list[str]", self.recipients),
184
+ )
185
+
186
+ def _validate_par_recipients(
187
+ self, out_packets: Sequence[Packet[OutT]], call_id: str
188
+ ) -> None:
189
+ recipient_sets = [set(p.recipients or []) for p in out_packets]
190
+ same_recipients = all(rs == recipient_sets[0] for rs in recipient_sets)
191
+ if not same_recipients:
192
+ raise PacketRoutingError(
193
+ proc_name=self.name,
194
+ call_id=call_id,
195
+ message="Parallel runs must return the same recipients "
196
+ f"[proc_name={self.name}; call_id={call_id}]",
197
+ )
198
+
199
+ @final
200
+ def _select_recipients(
201
+ self, output: OutT, ctx: RunContext[CtxT] | None = None
202
+ ) -> list[ProcName] | None:
203
+ if self.recipient_selector:
204
+ return self.recipient_selector(output=output, ctx=ctx)
205
+
206
+ return self.recipients
207
+
208
+ def add_recipient_selector(
209
+ self, func: RecipientSelector[OutT, CtxT]
210
+ ) -> RecipientSelector[OutT, CtxT]:
211
+ self.recipient_selector = func
212
+
213
+ return func
214
+
130
215
  async def _process(
131
216
  self,
132
217
  chat_inputs: Any | None = None,
@@ -135,13 +220,8 @@ class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT_co, MemT, Ct
135
220
  memory: MemT,
136
221
  call_id: str,
137
222
  ctx: RunContext[CtxT] | None = None,
138
- ) -> Sequence[OutT_co]:
139
- if in_args is None:
140
- raise ProcInputValidationError(
141
- "Default implementation of _process requires in_args"
142
- )
143
-
144
- return cast("Sequence[OutT_co]", in_args)
223
+ ) -> OutT:
224
+ return cast("OutT", in_args)
145
225
 
146
226
  async def _process_stream(
147
227
  self,
@@ -152,99 +232,86 @@ class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT_co, MemT, Ct
152
232
  call_id: str,
153
233
  ctx: RunContext[CtxT] | None = None,
154
234
  ) -> AsyncIterator[Event[Any]]:
155
- if in_args is None:
156
- raise ProcInputValidationError(
157
- "Default implementation of _process_stream requires in_args"
158
- )
159
- outputs = cast("Sequence[OutT_co]", in_args)
160
- for out in outputs:
161
- yield ProcPayloadOutputEvent(data=out, proc_name=self.name, call_id=call_id)
235
+ output = cast("OutT", in_args)
236
+ yield ProcPayloadOutputEvent(data=output, proc_name=self.name, call_id=call_id)
162
237
 
163
238
  async def _run_single_once(
164
239
  self,
165
240
  chat_inputs: Any | None = None,
166
241
  *,
167
- in_packet: Packet[InT] | None = None,
168
242
  in_args: InT | None = None,
169
243
  forgetful: bool = False,
170
244
  call_id: str,
171
245
  ctx: RunContext[CtxT] | None = None,
172
- ) -> Packet[OutT_co]:
173
- resolved_in_args = self._validate_and_resolve_single_input(
174
- chat_inputs=chat_inputs, in_packet=in_packet, in_args=in_args
175
- )
246
+ ) -> Packet[OutT]:
176
247
  _memory = self.memory.model_copy(deep=True) if forgetful else self.memory
177
248
 
178
- outputs = await self._process(
249
+ output = await self._process(
179
250
  chat_inputs=chat_inputs,
180
- in_args=resolved_in_args,
251
+ in_args=in_args,
181
252
  memory=_memory,
182
253
  call_id=call_id,
183
254
  ctx=ctx,
184
255
  )
185
- val_outputs = self._validate_outputs(outputs)
256
+ val_output = self._validate_output(output, call_id=call_id)
257
+
258
+ recipients = self._select_recipients(output=val_output, ctx=ctx)
259
+ self._validate_recipients(recipients, call_id=call_id)
186
260
 
187
- return Packet(payloads=val_outputs, sender=self.name)
261
+ return Packet(payloads=[val_output], sender=self.name, recipients=recipients)
188
262
 
189
263
  async def _run_single(
190
264
  self,
191
265
  chat_inputs: Any | None = None,
192
266
  *,
193
- in_packet: Packet[InT] | None = None,
194
267
  in_args: InT | None = None,
195
268
  forgetful: bool = False,
196
269
  call_id: str,
197
270
  ctx: RunContext[CtxT] | None = None,
198
- ) -> Packet[OutT_co] | None:
271
+ ) -> Packet[OutT]:
199
272
  n_attempt = 0
200
273
  while n_attempt <= self.max_retries:
201
274
  try:
202
275
  return await self._run_single_once(
203
276
  chat_inputs=chat_inputs,
204
- in_packet=in_packet,
205
277
  in_args=in_args,
206
278
  forgetful=forgetful,
207
279
  call_id=call_id,
208
280
  ctx=ctx,
209
281
  )
210
282
  except Exception as err:
283
+ err_message = (
284
+ f"\nProcessor run failed [proc_name={self.name}; call_id={call_id}]"
285
+ )
211
286
  n_attempt += 1
212
287
  if n_attempt > self.max_retries:
213
288
  if n_attempt == 1:
214
- logger.warning(f"\nProcessor run failed:\n{err}")
289
+ logger.warning(f"{err_message}:\n{err}")
215
290
  if n_attempt > 1:
216
- logger.warning(f"\nProcessor run failed after retrying:\n{err}")
217
- return None
218
- logger.warning(
219
- f"\nProcessor run failed (retry attempt {n_attempt}):\n{err}"
220
- )
291
+ logger.warning(f"{err_message} after retrying:\n{err}")
292
+ raise ProcRunError(proc_name=self.name, call_id=call_id) from err
293
+
294
+ logger.warning(f"{err_message} (retry attempt {n_attempt}):\n{err}")
295
+
296
+ raise ProcRunError(proc_name=self.name, call_id=call_id)
221
297
 
222
298
  async def _run_par(
223
- self,
224
- chat_inputs: Any | None = None,
225
- *,
226
- in_packet: Packet[InT] | None = None,
227
- in_args: Sequence[InT] | None = None,
228
- call_id: str,
229
- ctx: RunContext[CtxT] | None = None,
230
- ) -> Packet[OutT_co]:
231
- par_inputs = self._validate_and_resolve_parallel_inputs(
232
- chat_inputs=chat_inputs, in_packet=in_packet, in_args=in_args
233
- )
299
+ self, in_args: list[InT], call_id: str, ctx: RunContext[CtxT] | None = None
300
+ ) -> Packet[OutT]:
234
301
  tasks = [
235
302
  self._run_single(
236
303
  in_args=inp, forgetful=True, call_id=f"{call_id}/{idx}", ctx=ctx
237
304
  )
238
- for idx, inp in enumerate(par_inputs)
305
+ for idx, inp in enumerate(in_args)
239
306
  ]
240
307
  out_packets = await asyncio.gather(*tasks)
241
308
 
242
- return Packet( # type: ignore[return]
243
- payloads=[
244
- (out_packet.payloads[0] if out_packet else None)
245
- for out_packet in out_packets
246
- ],
309
+ self._validate_par_recipients(out_packets, call_id=call_id)
310
+
311
+ return Packet(
312
+ payloads=[out_packet.payloads[0] for out_packet in out_packets],
247
313
  sender=self.name,
314
+ recipients=out_packets[0].recipients,
248
315
  )
249
316
 
250
317
  async def run(
@@ -252,27 +319,25 @@ class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT_co, MemT, Ct
252
319
  chat_inputs: Any | None = None,
253
320
  *,
254
321
  in_packet: Packet[InT] | None = None,
255
- in_args: InT | Sequence[InT] | None = None,
322
+ in_args: InT | list[InT] | None = None,
256
323
  forgetful: bool = False,
257
324
  call_id: str | None = None,
258
325
  ctx: RunContext[CtxT] | None = None,
259
- ) -> Packet[OutT_co]:
326
+ ) -> Packet[OutT]:
260
327
  call_id = self._generate_call_id(call_id)
261
328
 
262
- if (in_args is not None and isinstance(in_args, Sequence)) or (
263
- in_packet is not None and len(in_packet.payloads) > 1
264
- ):
265
- return await self._run_par(
266
- chat_inputs=chat_inputs,
267
- in_packet=in_packet,
268
- in_args=cast("Sequence[InT] | None", in_args),
269
- call_id=call_id,
270
- ctx=ctx,
271
- )
272
- return await self._run_single( # type: ignore[return]
329
+ val_in_args = self._validate_inputs(
330
+ call_id=call_id,
273
331
  chat_inputs=chat_inputs,
274
332
  in_packet=in_packet,
275
- in_args=cast("InT | None", in_args),
333
+ in_args=in_args,
334
+ )
335
+
336
+ if val_in_args and len(val_in_args) > 1:
337
+ return await self._run_par(in_args=val_in_args, call_id=call_id, ctx=ctx)
338
+ return await self._run_single(
339
+ chat_inputs=chat_inputs,
340
+ in_args=val_in_args[0] if val_in_args else None,
276
341
  forgetful=forgetful,
277
342
  call_id=call_id,
278
343
  ctx=ctx,
@@ -282,31 +347,35 @@ class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT_co, MemT, Ct
282
347
  self,
283
348
  chat_inputs: Any | None = None,
284
349
  *,
285
- in_packet: Packet[InT] | None = None,
286
350
  in_args: InT | None = None,
287
351
  forgetful: bool = False,
288
352
  call_id: str,
289
353
  ctx: RunContext[CtxT] | None = None,
290
354
  ) -> AsyncIterator[Event[Any]]:
291
- resolved_in_args = self._validate_and_resolve_single_input(
292
- chat_inputs=chat_inputs, in_packet=in_packet, in_args=in_args
293
- )
294
355
  _memory = self.memory.model_copy(deep=True) if forgetful else self.memory
295
356
 
296
- outputs: list[OutT_co] = []
357
+ output: OutT | None = None
297
358
  async for event in self._process_stream(
298
359
  chat_inputs=chat_inputs,
299
- in_args=resolved_in_args,
360
+ in_args=in_args,
300
361
  memory=_memory,
301
362
  call_id=call_id,
302
363
  ctx=ctx,
303
364
  ):
304
365
  if isinstance(event, ProcPayloadOutputEvent):
305
- outputs.append(event.data)
366
+ output = event.data
306
367
  yield event
307
368
 
308
- val_outputs = self._validate_outputs(outputs)
309
- out_packet = Packet[OutT_co](payloads=val_outputs, sender=self.name)
369
+ assert output is not None
370
+
371
+ val_output = self._validate_output(output, call_id=call_id)
372
+
373
+ recipients = self._select_recipients(output=val_output, ctx=ctx)
374
+ self._validate_recipients(recipients, call_id=call_id)
375
+
376
+ out_packet = Packet[OutT](
377
+ payloads=[val_output], sender=self.name, recipients=recipients
378
+ )
310
379
 
311
380
  yield ProcPacketOutputEvent(
312
381
  data=out_packet, proc_name=self.name, call_id=call_id
@@ -316,7 +385,6 @@ class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT_co, MemT, Ct
316
385
  self,
317
386
  chat_inputs: Any | None = None,
318
387
  *,
319
- in_packet: Packet[InT] | None = None,
320
388
  in_args: InT | None = None,
321
389
  forgetful: bool = False,
322
390
  call_id: str,
@@ -327,7 +395,6 @@ class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT_co, MemT, Ct
327
395
  try:
328
396
  async for event in self._run_single_stream_once(
329
397
  chat_inputs=chat_inputs,
330
- in_packet=in_packet,
331
398
  in_args=in_args,
332
399
  forgetful=forgetful,
333
400
  call_id=call_id,
@@ -343,56 +410,48 @@ class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT_co, MemT, Ct
343
410
  data=err_data, proc_name=self.name, call_id=call_id
344
411
  )
345
412
 
413
+ err_message = (
414
+ "\nStreaming processor run failed "
415
+ f"[proc_name={self.name}; call_id={call_id}]"
416
+ )
417
+
346
418
  n_attempt += 1
347
419
  if n_attempt > self.max_retries:
348
420
  if n_attempt == 1:
349
- logger.warning(f"\nStreaming processor run failed:\n{err}")
421
+ logger.warning(f"{err_message}:\n{err}")
350
422
  if n_attempt > 1:
351
- logger.warning(
352
- f"\nStreaming processor run failed after retrying:\n{err}"
353
- )
354
- return
423
+ logger.warning(f"{err_message} after retrying:\n{err}")
424
+ raise ProcRunError(proc_name=self.name, call_id=call_id) from err
355
425
 
356
- logger.warning(
357
- "\nStreaming processor run failed "
358
- f"(retry attempt {n_attempt}):\n{err}"
359
- )
426
+ logger.warning(f"{err_message} (retry attempt {n_attempt}):\n{err}")
360
427
 
361
428
  async def _run_par_stream(
362
429
  self,
363
- chat_inputs: Any | None = None,
364
- *,
365
- in_packet: Packet[InT] | None = None,
366
- in_args: Sequence[InT] | None = None,
430
+ in_args: list[InT],
367
431
  call_id: str,
368
432
  ctx: RunContext[CtxT] | None = None,
369
433
  ) -> AsyncIterator[Event[Any]]:
370
- par_inputs = self._validate_and_resolve_parallel_inputs(
371
- chat_inputs=chat_inputs, in_packet=in_packet, in_args=in_args
372
- )
373
434
  streams = [
374
435
  self._run_single_stream(
375
436
  in_args=inp, forgetful=True, call_id=f"{call_id}/{idx}", ctx=ctx
376
437
  )
377
- for idx, inp in enumerate(par_inputs)
438
+ for idx, inp in enumerate(in_args)
378
439
  ]
379
440
 
380
- out_packets_map: dict[int, Packet[OutT_co] | None] = dict.fromkeys(
381
- range(len(streams)), None
382
- )
383
-
441
+ out_packets_map: dict[int, Packet[OutT]] = {}
384
442
  async for idx, event in stream_concurrent(streams):
385
443
  if isinstance(event, ProcPacketOutputEvent):
386
444
  out_packets_map[idx] = event.data
387
445
  else:
388
446
  yield event
389
447
 
390
- out_packet = Packet( # type: ignore[return]
448
+ out_packet = Packet(
391
449
  payloads=[
392
- (out_packet.payloads[0] if out_packet else None)
393
- for out_packet in out_packets_map.values()
450
+ out_packet.payloads[0]
451
+ for _, out_packet in sorted(out_packets_map.items())
394
452
  ],
395
453
  sender=self.name,
454
+ recipients=out_packets_map[0].recipients,
396
455
  )
397
456
 
398
457
  yield ProcPacketOutputEvent(
@@ -404,30 +463,26 @@ class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT_co, MemT, Ct
404
463
  chat_inputs: Any | None = None,
405
464
  *,
406
465
  in_packet: Packet[InT] | None = None,
407
- in_args: InT | Sequence[InT] | None = None,
466
+ in_args: InT | list[InT] | None = None,
408
467
  forgetful: bool = False,
409
468
  call_id: str | None = None,
410
469
  ctx: RunContext[CtxT] | None = None,
411
470
  ) -> AsyncIterator[Event[Any]]:
412
471
  call_id = self._generate_call_id(call_id)
413
472
 
414
- # yield ProcStartEvent(proc_name=self.name, call_id=call_id, data=None)
473
+ val_in_args = self._validate_inputs(
474
+ call_id=call_id,
475
+ chat_inputs=chat_inputs,
476
+ in_packet=in_packet,
477
+ in_args=in_args,
478
+ )
415
479
 
416
- if (in_args is not None and isinstance(in_args, Sequence)) or (
417
- in_packet is not None and len(in_packet.payloads) > 1
418
- ):
419
- stream = self._run_par_stream(
420
- chat_inputs=chat_inputs,
421
- in_packet=in_packet,
422
- in_args=cast("Sequence[InT] | None", in_args),
423
- call_id=call_id,
424
- ctx=ctx,
425
- )
480
+ if val_in_args and len(val_in_args) > 1:
481
+ stream = self._run_par_stream(in_args=val_in_args, call_id=call_id, ctx=ctx)
426
482
  else:
427
483
  stream = self._run_single_stream(
428
484
  chat_inputs=chat_inputs,
429
- in_packet=in_packet,
430
- in_args=cast("InT | None", in_args),
485
+ in_args=val_in_args[0] if val_in_args else None,
431
486
  forgetful=forgetful,
432
487
  call_id=call_id,
433
488
  ctx=ctx,
@@ -435,12 +490,10 @@ class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT_co, MemT, Ct
435
490
  async for event in stream:
436
491
  yield event
437
492
 
438
- # yield ProcFinishEvent(proc_name=self.name, call_id=call_id, data=None)
439
-
440
493
  @final
441
494
  def as_tool(
442
495
  self, tool_name: str, tool_description: str
443
- ) -> BaseTool[InT, OutT_co, Any]: # type: ignore[override]
496
+ ) -> BaseTool[InT, OutT, Any]: # type: ignore[override]
444
497
  # TODO: stream tools
445
498
  processor_instance = self
446
499
  in_type = processor_instance.in_type
@@ -455,13 +508,11 @@ class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT_co, MemT, Ct
455
508
  name: str = tool_name
456
509
  description: str = tool_description
457
510
 
458
- async def run(
459
- self, inp: InT, ctx: RunContext[CtxT] | None = None
460
- ) -> OutT_co:
511
+ async def run(self, inp: InT, ctx: RunContext[CtxT] | None = None) -> OutT:
461
512
  result = await processor_instance.run(
462
513
  in_args=in_type.model_validate(inp), forgetful=True, ctx=ctx
463
514
  )
464
515
 
465
516
  return result.payloads[0]
466
517
 
467
- return ProcessorTool() # type: ignore[return-value]
518
+ return ProcessorTool()