grasp_agents 0.5.3__py3-none-any.whl → 0.5.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- grasp_agents/__init__.py +4 -6
- grasp_agents/errors.py +80 -18
- grasp_agents/llm_agent.py +106 -146
- grasp_agents/llm_agent_memory.py +1 -1
- grasp_agents/llm_policy_executor.py +17 -15
- grasp_agents/packet.py +23 -4
- grasp_agents/packet_pool.py +117 -50
- grasp_agents/printer.py +9 -5
- grasp_agents/processor.py +217 -166
- grasp_agents/prompt_builder.py +75 -138
- grasp_agents/run_context.py +3 -16
- grasp_agents/runner.py +110 -21
- grasp_agents/typing/events.py +8 -4
- grasp_agents/typing/io.py +1 -8
- grasp_agents/workflow/looped_workflow.py +13 -19
- grasp_agents/workflow/sequential_workflow.py +6 -10
- grasp_agents/workflow/workflow_processor.py +23 -16
- {grasp_agents-0.5.3.dist-info → grasp_agents-0.5.5.dist-info}/METADATA +1 -1
- {grasp_agents-0.5.3.dist-info → grasp_agents-0.5.5.dist-info}/RECORD +21 -22
- grasp_agents/comm_processor.py +0 -214
- {grasp_agents-0.5.3.dist-info → grasp_agents-0.5.5.dist-info}/WHEEL +0 -0
- {grasp_agents-0.5.3.dist-info → grasp_agents-0.5.5.dist-info}/licenses/LICENSE.md +0 -0
grasp_agents/processor.py
CHANGED
@@ -2,13 +2,18 @@ import asyncio
|
|
2
2
|
import logging
|
3
3
|
from abc import ABC
|
4
4
|
from collections.abc import AsyncIterator, Sequence
|
5
|
-
from typing import Any, ClassVar, Generic, cast, final
|
5
|
+
from typing import Any, ClassVar, Generic, Protocol, TypeVar, cast, final
|
6
6
|
from uuid import uuid4
|
7
7
|
|
8
8
|
from pydantic import BaseModel, TypeAdapter
|
9
9
|
from pydantic import ValidationError as PydanticValidationError
|
10
10
|
|
11
|
-
from .errors import
|
11
|
+
from .errors import (
|
12
|
+
PacketRoutingError,
|
13
|
+
ProcInputValidationError,
|
14
|
+
ProcOutputValidationError,
|
15
|
+
ProcRunError,
|
16
|
+
)
|
12
17
|
from .generics_utils import AutoInstanceAttributesMixin
|
13
18
|
from .memory import DummyMemory, MemT
|
14
19
|
from .packet import Packet
|
@@ -20,22 +25,36 @@ from .typing.events import (
|
|
20
25
|
ProcStreamingErrorData,
|
21
26
|
ProcStreamingErrorEvent,
|
22
27
|
)
|
23
|
-
from .typing.io import InT,
|
28
|
+
from .typing.io import InT, OutT, ProcName
|
24
29
|
from .typing.tool import BaseTool
|
25
30
|
from .utils import stream_concurrent
|
26
31
|
|
27
32
|
logger = logging.getLogger(__name__)
|
28
33
|
|
34
|
+
_OutT_contra = TypeVar("_OutT_contra", contravariant=True)
|
35
|
+
|
36
|
+
|
37
|
+
class RecipientSelector(Protocol[_OutT_contra, CtxT]):
|
38
|
+
def __call__(
|
39
|
+
self, output: _OutT_contra, ctx: RunContext[CtxT] | None
|
40
|
+
) -> list[ProcName] | None: ...
|
41
|
+
|
29
42
|
|
30
|
-
class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT,
|
43
|
+
class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT, MemT, CtxT]):
|
31
44
|
_generic_arg_to_instance_attr_map: ClassVar[dict[int, str]] = {
|
32
45
|
0: "_in_type",
|
33
46
|
1: "_out_type",
|
34
47
|
}
|
35
48
|
|
36
|
-
def __init__(
|
49
|
+
def __init__(
|
50
|
+
self,
|
51
|
+
name: ProcName,
|
52
|
+
max_retries: int = 0,
|
53
|
+
recipients: list[ProcName] | None = None,
|
54
|
+
**kwargs: Any,
|
55
|
+
) -> None:
|
37
56
|
self._in_type: type[InT]
|
38
|
-
self._out_type: type[
|
57
|
+
self._out_type: type[OutT]
|
39
58
|
|
40
59
|
super().__init__()
|
41
60
|
|
@@ -43,12 +62,19 @@ class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT_co, MemT, Ct
|
|
43
62
|
self._memory: MemT = cast("MemT", DummyMemory())
|
44
63
|
self._max_retries: int = max_retries
|
45
64
|
|
65
|
+
self.recipients = recipients
|
66
|
+
|
67
|
+
self.recipient_selector: RecipientSelector[OutT, CtxT] | None
|
68
|
+
if not hasattr(type(self), "recipient_selector"):
|
69
|
+
# Set to None if not defined in the subclass
|
70
|
+
self.recipient_selector = None
|
71
|
+
|
46
72
|
@property
|
47
73
|
def in_type(self) -> type[InT]:
|
48
74
|
return self._in_type
|
49
75
|
|
50
76
|
@property
|
51
|
-
def out_type(self) -> type[
|
77
|
+
def out_type(self) -> type[OutT]:
|
52
78
|
return self._out_type
|
53
79
|
|
54
80
|
@property
|
@@ -68,65 +94,124 @@ class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT_co, MemT, Ct
|
|
68
94
|
return str(uuid4())[:6] + "_" + self.name
|
69
95
|
return call_id
|
70
96
|
|
71
|
-
def
|
97
|
+
def _validate_inputs(
|
72
98
|
self,
|
99
|
+
call_id: str,
|
73
100
|
chat_inputs: Any | None = None,
|
74
101
|
in_packet: Packet[InT] | None = None,
|
75
|
-
in_args: InT | None = None,
|
76
|
-
) -> InT | None:
|
77
|
-
|
102
|
+
in_args: InT | list[InT] | None = None,
|
103
|
+
) -> list[InT] | None:
|
104
|
+
mult_inputs_err_message = (
|
78
105
|
"Only one of chat_inputs, in_args, or in_message must be provided."
|
79
106
|
)
|
107
|
+
err_kwargs = {"proc_name": self.name, "call_id": call_id}
|
108
|
+
|
80
109
|
if chat_inputs is not None and in_args is not None:
|
81
|
-
raise ProcInputValidationError(
|
110
|
+
raise ProcInputValidationError(
|
111
|
+
message=mult_inputs_err_message, **err_kwargs
|
112
|
+
)
|
82
113
|
if chat_inputs is not None and in_packet is not None:
|
83
|
-
raise ProcInputValidationError(
|
114
|
+
raise ProcInputValidationError(
|
115
|
+
message=mult_inputs_err_message, **err_kwargs
|
116
|
+
)
|
84
117
|
if in_args is not None and in_packet is not None:
|
85
|
-
raise ProcInputValidationError(
|
86
|
-
|
87
|
-
|
88
|
-
if len(in_packet.payloads) != 1:
|
89
|
-
raise ProcInputValidationError(
|
90
|
-
"Single input runs require exactly one payload in in_packet."
|
91
|
-
)
|
92
|
-
return in_packet.payloads[0]
|
93
|
-
|
94
|
-
return in_args
|
118
|
+
raise ProcInputValidationError(
|
119
|
+
message=mult_inputs_err_message, **err_kwargs
|
120
|
+
)
|
95
121
|
|
96
|
-
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
in_args
|
101
|
-
) -> Sequence[InT]:
|
102
|
-
if chat_inputs is not None:
|
122
|
+
if in_packet is not None and not in_packet.payloads:
|
123
|
+
raise ProcInputValidationError(
|
124
|
+
message="in_packet must contain at least one payload.", **err_kwargs
|
125
|
+
)
|
126
|
+
if in_args is not None and not in_args:
|
103
127
|
raise ProcInputValidationError(
|
104
|
-
"
|
105
|
-
"Use in_packet or in_args."
|
128
|
+
message="in_args must contain at least one argument.", **err_kwargs
|
106
129
|
)
|
107
|
-
|
108
|
-
|
130
|
+
|
131
|
+
if chat_inputs is not None:
|
132
|
+
return None
|
133
|
+
|
134
|
+
resolved_args: list[InT]
|
135
|
+
|
136
|
+
if isinstance(in_args, list):
|
137
|
+
_in_args = cast("list[Any]", in_args)
|
138
|
+
if all(isinstance(x, self.in_type) for x in _in_args):
|
139
|
+
resolved_args = cast("list[InT]", _in_args)
|
140
|
+
elif isinstance(_in_args, self.in_type):
|
141
|
+
resolved_args = cast("list[InT]", [_in_args])
|
142
|
+
else:
|
109
143
|
raise ProcInputValidationError(
|
110
|
-
"
|
144
|
+
message=f"in_args are neither of type {self.in_type} "
|
145
|
+
f"nor a sequence of {self.in_type}.",
|
146
|
+
**err_kwargs,
|
111
147
|
)
|
112
|
-
return in_packet.payloads
|
113
|
-
if in_args is not None:
|
114
|
-
return in_args
|
115
|
-
raise ProcInputValidationError(
|
116
|
-
"Parallel runs require either in_packet or in_args to be provided."
|
117
|
-
)
|
118
148
|
|
119
|
-
|
149
|
+
elif in_args is not None:
|
150
|
+
resolved_args = cast("list[InT]", [in_args])
|
151
|
+
|
152
|
+
else:
|
153
|
+
assert in_packet is not None
|
154
|
+
resolved_args = cast("list[InT]", in_packet.payloads)
|
155
|
+
|
156
|
+
try:
|
157
|
+
for args in resolved_args:
|
158
|
+
TypeAdapter(self._in_type).validate_python(args)
|
159
|
+
except PydanticValidationError as err:
|
160
|
+
raise ProcInputValidationError(message=str(err), **err_kwargs) from err
|
161
|
+
|
162
|
+
return resolved_args
|
163
|
+
|
164
|
+
def _validate_output(self, out_payload: OutT, call_id: str) -> OutT:
|
165
|
+
if out_payload is None:
|
166
|
+
return out_payload
|
120
167
|
try:
|
121
|
-
return
|
122
|
-
TypeAdapter(self._out_type).validate_python(payload)
|
123
|
-
for payload in out_payloads
|
124
|
-
]
|
168
|
+
return TypeAdapter(self._out_type).validate_python(out_payload)
|
125
169
|
except PydanticValidationError as err:
|
126
170
|
raise ProcOutputValidationError(
|
127
|
-
|
171
|
+
schema=self._out_type, proc_name=self.name, call_id=call_id
|
128
172
|
) from err
|
129
173
|
|
174
|
+
def _validate_recipients(
|
175
|
+
self, recipients: list[ProcName] | None, call_id: str
|
176
|
+
) -> None:
|
177
|
+
for r in recipients or []:
|
178
|
+
if r not in (self.recipients or []):
|
179
|
+
raise PacketRoutingError(
|
180
|
+
proc_name=self.name,
|
181
|
+
call_id=call_id,
|
182
|
+
selected_recipient=r,
|
183
|
+
allowed_recipients=cast("list[str]", self.recipients),
|
184
|
+
)
|
185
|
+
|
186
|
+
def _validate_par_recipients(
|
187
|
+
self, out_packets: Sequence[Packet[OutT]], call_id: str
|
188
|
+
) -> None:
|
189
|
+
recipient_sets = [set(p.recipients or []) for p in out_packets]
|
190
|
+
same_recipients = all(rs == recipient_sets[0] for rs in recipient_sets)
|
191
|
+
if not same_recipients:
|
192
|
+
raise PacketRoutingError(
|
193
|
+
proc_name=self.name,
|
194
|
+
call_id=call_id,
|
195
|
+
message="Parallel runs must return the same recipients "
|
196
|
+
f"[proc_name={self.name}; call_id={call_id}]",
|
197
|
+
)
|
198
|
+
|
199
|
+
@final
|
200
|
+
def _select_recipients(
|
201
|
+
self, output: OutT, ctx: RunContext[CtxT] | None = None
|
202
|
+
) -> list[ProcName] | None:
|
203
|
+
if self.recipient_selector:
|
204
|
+
return self.recipient_selector(output=output, ctx=ctx)
|
205
|
+
|
206
|
+
return self.recipients
|
207
|
+
|
208
|
+
def add_recipient_selector(
|
209
|
+
self, func: RecipientSelector[OutT, CtxT]
|
210
|
+
) -> RecipientSelector[OutT, CtxT]:
|
211
|
+
self.recipient_selector = func
|
212
|
+
|
213
|
+
return func
|
214
|
+
|
130
215
|
async def _process(
|
131
216
|
self,
|
132
217
|
chat_inputs: Any | None = None,
|
@@ -135,13 +220,8 @@ class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT_co, MemT, Ct
|
|
135
220
|
memory: MemT,
|
136
221
|
call_id: str,
|
137
222
|
ctx: RunContext[CtxT] | None = None,
|
138
|
-
) ->
|
139
|
-
|
140
|
-
raise ProcInputValidationError(
|
141
|
-
"Default implementation of _process requires in_args"
|
142
|
-
)
|
143
|
-
|
144
|
-
return cast("Sequence[OutT_co]", in_args)
|
223
|
+
) -> OutT:
|
224
|
+
return cast("OutT", in_args)
|
145
225
|
|
146
226
|
async def _process_stream(
|
147
227
|
self,
|
@@ -152,99 +232,86 @@ class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT_co, MemT, Ct
|
|
152
232
|
call_id: str,
|
153
233
|
ctx: RunContext[CtxT] | None = None,
|
154
234
|
) -> AsyncIterator[Event[Any]]:
|
155
|
-
|
156
|
-
|
157
|
-
"Default implementation of _process_stream requires in_args"
|
158
|
-
)
|
159
|
-
outputs = cast("Sequence[OutT_co]", in_args)
|
160
|
-
for out in outputs:
|
161
|
-
yield ProcPayloadOutputEvent(data=out, proc_name=self.name, call_id=call_id)
|
235
|
+
output = cast("OutT", in_args)
|
236
|
+
yield ProcPayloadOutputEvent(data=output, proc_name=self.name, call_id=call_id)
|
162
237
|
|
163
238
|
async def _run_single_once(
|
164
239
|
self,
|
165
240
|
chat_inputs: Any | None = None,
|
166
241
|
*,
|
167
|
-
in_packet: Packet[InT] | None = None,
|
168
242
|
in_args: InT | None = None,
|
169
243
|
forgetful: bool = False,
|
170
244
|
call_id: str,
|
171
245
|
ctx: RunContext[CtxT] | None = None,
|
172
|
-
) -> Packet[
|
173
|
-
resolved_in_args = self._validate_and_resolve_single_input(
|
174
|
-
chat_inputs=chat_inputs, in_packet=in_packet, in_args=in_args
|
175
|
-
)
|
246
|
+
) -> Packet[OutT]:
|
176
247
|
_memory = self.memory.model_copy(deep=True) if forgetful else self.memory
|
177
248
|
|
178
|
-
|
249
|
+
output = await self._process(
|
179
250
|
chat_inputs=chat_inputs,
|
180
|
-
in_args=
|
251
|
+
in_args=in_args,
|
181
252
|
memory=_memory,
|
182
253
|
call_id=call_id,
|
183
254
|
ctx=ctx,
|
184
255
|
)
|
185
|
-
|
256
|
+
val_output = self._validate_output(output, call_id=call_id)
|
257
|
+
|
258
|
+
recipients = self._select_recipients(output=val_output, ctx=ctx)
|
259
|
+
self._validate_recipients(recipients, call_id=call_id)
|
186
260
|
|
187
|
-
return Packet(payloads=
|
261
|
+
return Packet(payloads=[val_output], sender=self.name, recipients=recipients)
|
188
262
|
|
189
263
|
async def _run_single(
|
190
264
|
self,
|
191
265
|
chat_inputs: Any | None = None,
|
192
266
|
*,
|
193
|
-
in_packet: Packet[InT] | None = None,
|
194
267
|
in_args: InT | None = None,
|
195
268
|
forgetful: bool = False,
|
196
269
|
call_id: str,
|
197
270
|
ctx: RunContext[CtxT] | None = None,
|
198
|
-
) -> Packet[
|
271
|
+
) -> Packet[OutT]:
|
199
272
|
n_attempt = 0
|
200
273
|
while n_attempt <= self.max_retries:
|
201
274
|
try:
|
202
275
|
return await self._run_single_once(
|
203
276
|
chat_inputs=chat_inputs,
|
204
|
-
in_packet=in_packet,
|
205
277
|
in_args=in_args,
|
206
278
|
forgetful=forgetful,
|
207
279
|
call_id=call_id,
|
208
280
|
ctx=ctx,
|
209
281
|
)
|
210
282
|
except Exception as err:
|
283
|
+
err_message = (
|
284
|
+
f"\nProcessor run failed [proc_name={self.name}; call_id={call_id}]"
|
285
|
+
)
|
211
286
|
n_attempt += 1
|
212
287
|
if n_attempt > self.max_retries:
|
213
288
|
if n_attempt == 1:
|
214
|
-
logger.warning(f"
|
289
|
+
logger.warning(f"{err_message}:\n{err}")
|
215
290
|
if n_attempt > 1:
|
216
|
-
logger.warning(f"
|
217
|
-
|
218
|
-
|
219
|
-
|
220
|
-
|
291
|
+
logger.warning(f"{err_message} after retrying:\n{err}")
|
292
|
+
raise ProcRunError(proc_name=self.name, call_id=call_id) from err
|
293
|
+
|
294
|
+
logger.warning(f"{err_message} (retry attempt {n_attempt}):\n{err}")
|
295
|
+
|
296
|
+
raise ProcRunError(proc_name=self.name, call_id=call_id)
|
221
297
|
|
222
298
|
async def _run_par(
|
223
|
-
self,
|
224
|
-
|
225
|
-
*,
|
226
|
-
in_packet: Packet[InT] | None = None,
|
227
|
-
in_args: Sequence[InT] | None = None,
|
228
|
-
call_id: str,
|
229
|
-
ctx: RunContext[CtxT] | None = None,
|
230
|
-
) -> Packet[OutT_co]:
|
231
|
-
par_inputs = self._validate_and_resolve_parallel_inputs(
|
232
|
-
chat_inputs=chat_inputs, in_packet=in_packet, in_args=in_args
|
233
|
-
)
|
299
|
+
self, in_args: list[InT], call_id: str, ctx: RunContext[CtxT] | None = None
|
300
|
+
) -> Packet[OutT]:
|
234
301
|
tasks = [
|
235
302
|
self._run_single(
|
236
303
|
in_args=inp, forgetful=True, call_id=f"{call_id}/{idx}", ctx=ctx
|
237
304
|
)
|
238
|
-
for idx, inp in enumerate(
|
305
|
+
for idx, inp in enumerate(in_args)
|
239
306
|
]
|
240
307
|
out_packets = await asyncio.gather(*tasks)
|
241
308
|
|
242
|
-
|
243
|
-
|
244
|
-
|
245
|
-
|
246
|
-
],
|
309
|
+
self._validate_par_recipients(out_packets, call_id=call_id)
|
310
|
+
|
311
|
+
return Packet(
|
312
|
+
payloads=[out_packet.payloads[0] for out_packet in out_packets],
|
247
313
|
sender=self.name,
|
314
|
+
recipients=out_packets[0].recipients,
|
248
315
|
)
|
249
316
|
|
250
317
|
async def run(
|
@@ -252,27 +319,25 @@ class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT_co, MemT, Ct
|
|
252
319
|
chat_inputs: Any | None = None,
|
253
320
|
*,
|
254
321
|
in_packet: Packet[InT] | None = None,
|
255
|
-
in_args: InT |
|
322
|
+
in_args: InT | list[InT] | None = None,
|
256
323
|
forgetful: bool = False,
|
257
324
|
call_id: str | None = None,
|
258
325
|
ctx: RunContext[CtxT] | None = None,
|
259
|
-
) -> Packet[
|
326
|
+
) -> Packet[OutT]:
|
260
327
|
call_id = self._generate_call_id(call_id)
|
261
328
|
|
262
|
-
|
263
|
-
|
264
|
-
):
|
265
|
-
return await self._run_par(
|
266
|
-
chat_inputs=chat_inputs,
|
267
|
-
in_packet=in_packet,
|
268
|
-
in_args=cast("Sequence[InT] | None", in_args),
|
269
|
-
call_id=call_id,
|
270
|
-
ctx=ctx,
|
271
|
-
)
|
272
|
-
return await self._run_single( # type: ignore[return]
|
329
|
+
val_in_args = self._validate_inputs(
|
330
|
+
call_id=call_id,
|
273
331
|
chat_inputs=chat_inputs,
|
274
332
|
in_packet=in_packet,
|
275
|
-
in_args=
|
333
|
+
in_args=in_args,
|
334
|
+
)
|
335
|
+
|
336
|
+
if val_in_args and len(val_in_args) > 1:
|
337
|
+
return await self._run_par(in_args=val_in_args, call_id=call_id, ctx=ctx)
|
338
|
+
return await self._run_single(
|
339
|
+
chat_inputs=chat_inputs,
|
340
|
+
in_args=val_in_args[0] if val_in_args else None,
|
276
341
|
forgetful=forgetful,
|
277
342
|
call_id=call_id,
|
278
343
|
ctx=ctx,
|
@@ -282,31 +347,35 @@ class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT_co, MemT, Ct
|
|
282
347
|
self,
|
283
348
|
chat_inputs: Any | None = None,
|
284
349
|
*,
|
285
|
-
in_packet: Packet[InT] | None = None,
|
286
350
|
in_args: InT | None = None,
|
287
351
|
forgetful: bool = False,
|
288
352
|
call_id: str,
|
289
353
|
ctx: RunContext[CtxT] | None = None,
|
290
354
|
) -> AsyncIterator[Event[Any]]:
|
291
|
-
resolved_in_args = self._validate_and_resolve_single_input(
|
292
|
-
chat_inputs=chat_inputs, in_packet=in_packet, in_args=in_args
|
293
|
-
)
|
294
355
|
_memory = self.memory.model_copy(deep=True) if forgetful else self.memory
|
295
356
|
|
296
|
-
|
357
|
+
output: OutT | None = None
|
297
358
|
async for event in self._process_stream(
|
298
359
|
chat_inputs=chat_inputs,
|
299
|
-
in_args=
|
360
|
+
in_args=in_args,
|
300
361
|
memory=_memory,
|
301
362
|
call_id=call_id,
|
302
363
|
ctx=ctx,
|
303
364
|
):
|
304
365
|
if isinstance(event, ProcPayloadOutputEvent):
|
305
|
-
|
366
|
+
output = event.data
|
306
367
|
yield event
|
307
368
|
|
308
|
-
|
309
|
-
|
369
|
+
assert output is not None
|
370
|
+
|
371
|
+
val_output = self._validate_output(output, call_id=call_id)
|
372
|
+
|
373
|
+
recipients = self._select_recipients(output=val_output, ctx=ctx)
|
374
|
+
self._validate_recipients(recipients, call_id=call_id)
|
375
|
+
|
376
|
+
out_packet = Packet[OutT](
|
377
|
+
payloads=[val_output], sender=self.name, recipients=recipients
|
378
|
+
)
|
310
379
|
|
311
380
|
yield ProcPacketOutputEvent(
|
312
381
|
data=out_packet, proc_name=self.name, call_id=call_id
|
@@ -316,7 +385,6 @@ class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT_co, MemT, Ct
|
|
316
385
|
self,
|
317
386
|
chat_inputs: Any | None = None,
|
318
387
|
*,
|
319
|
-
in_packet: Packet[InT] | None = None,
|
320
388
|
in_args: InT | None = None,
|
321
389
|
forgetful: bool = False,
|
322
390
|
call_id: str,
|
@@ -327,7 +395,6 @@ class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT_co, MemT, Ct
|
|
327
395
|
try:
|
328
396
|
async for event in self._run_single_stream_once(
|
329
397
|
chat_inputs=chat_inputs,
|
330
|
-
in_packet=in_packet,
|
331
398
|
in_args=in_args,
|
332
399
|
forgetful=forgetful,
|
333
400
|
call_id=call_id,
|
@@ -343,56 +410,48 @@ class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT_co, MemT, Ct
|
|
343
410
|
data=err_data, proc_name=self.name, call_id=call_id
|
344
411
|
)
|
345
412
|
|
413
|
+
err_message = (
|
414
|
+
"\nStreaming processor run failed "
|
415
|
+
f"[proc_name={self.name}; call_id={call_id}]"
|
416
|
+
)
|
417
|
+
|
346
418
|
n_attempt += 1
|
347
419
|
if n_attempt > self.max_retries:
|
348
420
|
if n_attempt == 1:
|
349
|
-
logger.warning(f"
|
421
|
+
logger.warning(f"{err_message}:\n{err}")
|
350
422
|
if n_attempt > 1:
|
351
|
-
logger.warning(
|
352
|
-
|
353
|
-
)
|
354
|
-
return
|
423
|
+
logger.warning(f"{err_message} after retrying:\n{err}")
|
424
|
+
raise ProcRunError(proc_name=self.name, call_id=call_id) from err
|
355
425
|
|
356
|
-
logger.warning(
|
357
|
-
"\nStreaming processor run failed "
|
358
|
-
f"(retry attempt {n_attempt}):\n{err}"
|
359
|
-
)
|
426
|
+
logger.warning(f"{err_message} (retry attempt {n_attempt}):\n{err}")
|
360
427
|
|
361
428
|
async def _run_par_stream(
|
362
429
|
self,
|
363
|
-
|
364
|
-
*,
|
365
|
-
in_packet: Packet[InT] | None = None,
|
366
|
-
in_args: Sequence[InT] | None = None,
|
430
|
+
in_args: list[InT],
|
367
431
|
call_id: str,
|
368
432
|
ctx: RunContext[CtxT] | None = None,
|
369
433
|
) -> AsyncIterator[Event[Any]]:
|
370
|
-
par_inputs = self._validate_and_resolve_parallel_inputs(
|
371
|
-
chat_inputs=chat_inputs, in_packet=in_packet, in_args=in_args
|
372
|
-
)
|
373
434
|
streams = [
|
374
435
|
self._run_single_stream(
|
375
436
|
in_args=inp, forgetful=True, call_id=f"{call_id}/{idx}", ctx=ctx
|
376
437
|
)
|
377
|
-
for idx, inp in enumerate(
|
438
|
+
for idx, inp in enumerate(in_args)
|
378
439
|
]
|
379
440
|
|
380
|
-
out_packets_map: dict[int, Packet[
|
381
|
-
range(len(streams)), None
|
382
|
-
)
|
383
|
-
|
441
|
+
out_packets_map: dict[int, Packet[OutT]] = {}
|
384
442
|
async for idx, event in stream_concurrent(streams):
|
385
443
|
if isinstance(event, ProcPacketOutputEvent):
|
386
444
|
out_packets_map[idx] = event.data
|
387
445
|
else:
|
388
446
|
yield event
|
389
447
|
|
390
|
-
out_packet = Packet(
|
448
|
+
out_packet = Packet(
|
391
449
|
payloads=[
|
392
|
-
|
393
|
-
for out_packet in out_packets_map.
|
450
|
+
out_packet.payloads[0]
|
451
|
+
for _, out_packet in sorted(out_packets_map.items())
|
394
452
|
],
|
395
453
|
sender=self.name,
|
454
|
+
recipients=out_packets_map[0].recipients,
|
396
455
|
)
|
397
456
|
|
398
457
|
yield ProcPacketOutputEvent(
|
@@ -404,30 +463,26 @@ class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT_co, MemT, Ct
|
|
404
463
|
chat_inputs: Any | None = None,
|
405
464
|
*,
|
406
465
|
in_packet: Packet[InT] | None = None,
|
407
|
-
in_args: InT |
|
466
|
+
in_args: InT | list[InT] | None = None,
|
408
467
|
forgetful: bool = False,
|
409
468
|
call_id: str | None = None,
|
410
469
|
ctx: RunContext[CtxT] | None = None,
|
411
470
|
) -> AsyncIterator[Event[Any]]:
|
412
471
|
call_id = self._generate_call_id(call_id)
|
413
472
|
|
414
|
-
|
473
|
+
val_in_args = self._validate_inputs(
|
474
|
+
call_id=call_id,
|
475
|
+
chat_inputs=chat_inputs,
|
476
|
+
in_packet=in_packet,
|
477
|
+
in_args=in_args,
|
478
|
+
)
|
415
479
|
|
416
|
-
if
|
417
|
-
|
418
|
-
):
|
419
|
-
stream = self._run_par_stream(
|
420
|
-
chat_inputs=chat_inputs,
|
421
|
-
in_packet=in_packet,
|
422
|
-
in_args=cast("Sequence[InT] | None", in_args),
|
423
|
-
call_id=call_id,
|
424
|
-
ctx=ctx,
|
425
|
-
)
|
480
|
+
if val_in_args and len(val_in_args) > 1:
|
481
|
+
stream = self._run_par_stream(in_args=val_in_args, call_id=call_id, ctx=ctx)
|
426
482
|
else:
|
427
483
|
stream = self._run_single_stream(
|
428
484
|
chat_inputs=chat_inputs,
|
429
|
-
|
430
|
-
in_args=cast("InT | None", in_args),
|
485
|
+
in_args=val_in_args[0] if val_in_args else None,
|
431
486
|
forgetful=forgetful,
|
432
487
|
call_id=call_id,
|
433
488
|
ctx=ctx,
|
@@ -435,12 +490,10 @@ class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT_co, MemT, Ct
|
|
435
490
|
async for event in stream:
|
436
491
|
yield event
|
437
492
|
|
438
|
-
# yield ProcFinishEvent(proc_name=self.name, call_id=call_id, data=None)
|
439
|
-
|
440
493
|
@final
|
441
494
|
def as_tool(
|
442
495
|
self, tool_name: str, tool_description: str
|
443
|
-
) -> BaseTool[InT,
|
496
|
+
) -> BaseTool[InT, OutT, Any]: # type: ignore[override]
|
444
497
|
# TODO: stream tools
|
445
498
|
processor_instance = self
|
446
499
|
in_type = processor_instance.in_type
|
@@ -455,13 +508,11 @@ class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT_co, MemT, Ct
|
|
455
508
|
name: str = tool_name
|
456
509
|
description: str = tool_description
|
457
510
|
|
458
|
-
async def run(
|
459
|
-
self, inp: InT, ctx: RunContext[CtxT] | None = None
|
460
|
-
) -> OutT_co:
|
511
|
+
async def run(self, inp: InT, ctx: RunContext[CtxT] | None = None) -> OutT:
|
461
512
|
result = await processor_instance.run(
|
462
513
|
in_args=in_type.model_validate(inp), forgetful=True, ctx=ctx
|
463
514
|
)
|
464
515
|
|
465
516
|
return result.payloads[0]
|
466
517
|
|
467
|
-
return ProcessorTool()
|
518
|
+
return ProcessorTool()
|