grasp_agents 0.5.2__py3-none-any.whl → 0.5.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- grasp_agents/__init__.py +3 -4
- grasp_agents/errors.py +80 -18
- grasp_agents/llm_agent.py +36 -58
- grasp_agents/llm_policy_executor.py +2 -2
- grasp_agents/packet.py +23 -4
- grasp_agents/packet_pool.py +117 -50
- grasp_agents/printer.py +8 -4
- grasp_agents/processor.py +208 -163
- grasp_agents/prompt_builder.py +80 -105
- grasp_agents/run_context.py +1 -9
- grasp_agents/runner.py +110 -21
- grasp_agents/typing/events.py +8 -4
- grasp_agents/typing/io.py +1 -1
- grasp_agents/workflow/looped_workflow.py +47 -52
- grasp_agents/workflow/sequential_workflow.py +6 -10
- grasp_agents/workflow/workflow_processor.py +7 -15
- {grasp_agents-0.5.2.dist-info → grasp_agents-0.5.4.dist-info}/METADATA +1 -1
- {grasp_agents-0.5.2.dist-info → grasp_agents-0.5.4.dist-info}/RECORD +20 -21
- grasp_agents/comm_processor.py +0 -214
- {grasp_agents-0.5.2.dist-info → grasp_agents-0.5.4.dist-info}/WHEEL +0 -0
- {grasp_agents-0.5.2.dist-info → grasp_agents-0.5.4.dist-info}/licenses/LICENSE.md +0 -0
grasp_agents/printer.py
CHANGED
@@ -219,12 +219,12 @@ async def print_event_stream(
|
|
219
219
|
) -> None:
|
220
220
|
color = _get_color(event, Role.ASSISTANT)
|
221
221
|
|
222
|
-
if isinstance(event,
|
223
|
-
src = "processor"
|
224
|
-
elif isinstance(event, WorkflowResultEvent):
|
222
|
+
if isinstance(event, WorkflowResultEvent):
|
225
223
|
src = "workflow"
|
226
|
-
|
224
|
+
elif isinstance(event, RunResultEvent):
|
227
225
|
src = "run"
|
226
|
+
else:
|
227
|
+
src = "processor"
|
228
228
|
|
229
229
|
text = f"\n<{event.proc_name}> [{event.call_id}]\n"
|
230
230
|
|
@@ -232,6 +232,10 @@ async def print_event_stream(
|
|
232
232
|
text += f"<{src} output>\n"
|
233
233
|
for p in event.data.payloads:
|
234
234
|
if isinstance(p, BaseModel):
|
235
|
+
for field_info in type(p).model_fields.values():
|
236
|
+
if field_info.exclude:
|
237
|
+
field_info.exclude = False
|
238
|
+
type(p).model_rebuild(force=True)
|
235
239
|
p_str = p.model_dump_json(indent=2)
|
236
240
|
else:
|
237
241
|
try:
|
grasp_agents/processor.py
CHANGED
@@ -2,13 +2,18 @@ import asyncio
|
|
2
2
|
import logging
|
3
3
|
from abc import ABC
|
4
4
|
from collections.abc import AsyncIterator, Sequence
|
5
|
-
from typing import Any, ClassVar, Generic, cast, final
|
5
|
+
from typing import Any, ClassVar, Generic, Protocol, TypeVar, cast, final
|
6
6
|
from uuid import uuid4
|
7
7
|
|
8
8
|
from pydantic import BaseModel, TypeAdapter
|
9
9
|
from pydantic import ValidationError as PydanticValidationError
|
10
10
|
|
11
|
-
from .errors import
|
11
|
+
from .errors import (
|
12
|
+
PacketRoutingError,
|
13
|
+
ProcInputValidationError,
|
14
|
+
ProcOutputValidationError,
|
15
|
+
ProcRunError,
|
16
|
+
)
|
12
17
|
from .generics_utils import AutoInstanceAttributesMixin
|
13
18
|
from .memory import DummyMemory, MemT
|
14
19
|
from .packet import Packet
|
@@ -20,35 +25,51 @@ from .typing.events import (
|
|
20
25
|
ProcStreamingErrorData,
|
21
26
|
ProcStreamingErrorEvent,
|
22
27
|
)
|
23
|
-
from .typing.io import InT, OutT_co, ProcName
|
28
|
+
from .typing.io import InT, OutT, ProcName
|
24
29
|
from .typing.tool import BaseTool
|
25
30
|
from .utils import stream_concurrent
|
26
31
|
|
27
32
|
logger = logging.getLogger(__name__)
|
28
33
|
|
34
|
+
_OutT_contra = TypeVar("_OutT_contra", contravariant=True)
|
35
|
+
|
36
|
+
|
37
|
+
class SelectRecipientsHandler(Protocol[_OutT_contra, CtxT]):
|
38
|
+
def __call__(
|
39
|
+
self, output: _OutT_contra, ctx: RunContext[CtxT] | None
|
40
|
+
) -> list[ProcName] | None: ...
|
41
|
+
|
29
42
|
|
30
|
-
class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT_co, MemT, CtxT]):
|
43
|
+
class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT, MemT, CtxT]):
|
31
44
|
_generic_arg_to_instance_attr_map: ClassVar[dict[int, str]] = {
|
32
45
|
0: "_in_type",
|
33
46
|
1: "_out_type",
|
34
47
|
}
|
35
48
|
|
36
|
-
def __init__(
|
49
|
+
def __init__(
|
50
|
+
self,
|
51
|
+
name: ProcName,
|
52
|
+
max_retries: int = 0,
|
53
|
+
recipients: list[ProcName] | None = None,
|
54
|
+
**kwargs: Any,
|
55
|
+
) -> None:
|
37
56
|
self._in_type: type[InT]
|
38
|
-
self._out_type: type[OutT_co]
|
57
|
+
self._out_type: type[OutT]
|
39
58
|
|
40
59
|
super().__init__()
|
41
60
|
|
42
61
|
self._name: ProcName = name
|
43
62
|
self._memory: MemT = cast("MemT", DummyMemory())
|
44
63
|
self._max_retries: int = max_retries
|
64
|
+
self.recipients = recipients
|
65
|
+
self.select_recipients_impl: SelectRecipientsHandler[OutT, CtxT] | None = None
|
45
66
|
|
46
67
|
@property
|
47
68
|
def in_type(self) -> type[InT]:
|
48
69
|
return self._in_type
|
49
70
|
|
50
71
|
@property
|
51
|
-
def out_type(self) -> type[OutT_co]:
|
72
|
+
def out_type(self) -> type[OutT]:
|
52
73
|
return self._out_type
|
53
74
|
|
54
75
|
@property
|
@@ -68,65 +89,108 @@ class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT_co, MemT, Ct
|
|
68
89
|
return str(uuid4())[:6] + "_" + self.name
|
69
90
|
return call_id
|
70
91
|
|
71
|
-
def
|
92
|
+
def _validate_inputs(
|
72
93
|
self,
|
94
|
+
call_id: str,
|
73
95
|
chat_inputs: Any | None = None,
|
74
96
|
in_packet: Packet[InT] | None = None,
|
75
|
-
in_args: InT | None = None,
|
76
|
-
) -> InT | None:
|
77
|
-
|
97
|
+
in_args: InT | Sequence[InT] | None = None,
|
98
|
+
) -> Sequence[InT] | None:
|
99
|
+
mult_inputs_err_message = (
|
78
100
|
"Only one of chat_inputs, in_args, or in_message must be provided."
|
79
101
|
)
|
102
|
+
err_kwargs = {"proc_name": self.name, "call_id": call_id}
|
103
|
+
|
80
104
|
if chat_inputs is not None and in_args is not None:
|
81
|
-
raise ProcInputValidationError(
|
105
|
+
raise ProcInputValidationError(
|
106
|
+
message=mult_inputs_err_message, **err_kwargs
|
107
|
+
)
|
82
108
|
if chat_inputs is not None and in_packet is not None:
|
83
|
-
raise ProcInputValidationError(
|
109
|
+
raise ProcInputValidationError(
|
110
|
+
message=mult_inputs_err_message, **err_kwargs
|
111
|
+
)
|
84
112
|
if in_args is not None and in_packet is not None:
|
85
|
-
raise ProcInputValidationError(
|
86
|
-
|
87
|
-
|
88
|
-
if len(in_packet.payloads) != 1:
|
89
|
-
raise ProcInputValidationError(
|
90
|
-
"Single input runs require exactly one payload in in_packet."
|
91
|
-
)
|
92
|
-
return in_packet.payloads[0]
|
93
|
-
|
94
|
-
return in_args
|
113
|
+
raise ProcInputValidationError(
|
114
|
+
message=mult_inputs_err_message, **err_kwargs
|
115
|
+
)
|
95
116
|
|
96
|
-
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
in_args
|
101
|
-
) -> Sequence[InT]:
|
102
|
-
if chat_inputs is not None:
|
117
|
+
if in_packet is not None and not in_packet.payloads:
|
118
|
+
raise ProcInputValidationError(
|
119
|
+
message="in_packet must contain at least one payload.", **err_kwargs
|
120
|
+
)
|
121
|
+
if in_args is not None and not in_args:
|
103
122
|
raise ProcInputValidationError(
|
104
|
-
"
|
105
|
-
"Use in_packet or in_args."
|
123
|
+
message="in_args must contain at least one argument.", **err_kwargs
|
106
124
|
)
|
107
|
-
|
108
|
-
|
125
|
+
|
126
|
+
if chat_inputs is not None:
|
127
|
+
return None
|
128
|
+
|
129
|
+
resolved_args: Sequence[InT]
|
130
|
+
|
131
|
+
if isinstance(in_args, Sequence):
|
132
|
+
_in_args = cast("Sequence[Any]", in_args)
|
133
|
+
if all(isinstance(x, self.in_type) for x in _in_args):
|
134
|
+
resolved_args = cast("Sequence[InT]", _in_args)
|
135
|
+
elif isinstance(_in_args, self.in_type):
|
136
|
+
resolved_args = cast("Sequence[InT]", [_in_args])
|
137
|
+
else:
|
109
138
|
raise ProcInputValidationError(
|
110
|
-
"
|
139
|
+
message=f"in_args are neither of type {self.in_type} "
|
140
|
+
f"nor a sequence of {self.in_type}.",
|
141
|
+
**err_kwargs,
|
111
142
|
)
|
112
|
-
return in_packet.payloads
|
113
|
-
if in_args is not None:
|
114
|
-
return in_args
|
115
|
-
raise ProcInputValidationError(
|
116
|
-
"Parallel runs require either in_packet or in_args to be provided."
|
117
|
-
)
|
118
143
|
|
119
|
-
|
144
|
+
elif in_args is not None:
|
145
|
+
resolved_args = cast("Sequence[InT]", [in_args])
|
146
|
+
|
147
|
+
else:
|
148
|
+
assert in_packet is not None
|
149
|
+
resolved_args = in_packet.payloads
|
150
|
+
|
120
151
|
try:
|
121
|
-
|
122
|
-
TypeAdapter(self.
|
123
|
-
|
124
|
-
|
152
|
+
for args in resolved_args:
|
153
|
+
TypeAdapter(self._in_type).validate_python(args)
|
154
|
+
except PydanticValidationError as err:
|
155
|
+
raise ProcInputValidationError(message=str(err), **err_kwargs) from err
|
156
|
+
|
157
|
+
return resolved_args
|
158
|
+
|
159
|
+
def _validate_output(self, out_payload: OutT, call_id: str) -> OutT:
|
160
|
+
if out_payload is None:
|
161
|
+
return out_payload
|
162
|
+
try:
|
163
|
+
return TypeAdapter(self._out_type).validate_python(out_payload)
|
125
164
|
except PydanticValidationError as err:
|
126
165
|
raise ProcOutputValidationError(
|
127
|
-
|
166
|
+
schema=self._out_type, proc_name=self.name, call_id=call_id
|
128
167
|
) from err
|
129
168
|
|
169
|
+
def _validate_recipients(
|
170
|
+
self, recipients: Sequence[ProcName] | None, call_id: str
|
171
|
+
) -> None:
|
172
|
+
for r in recipients or []:
|
173
|
+
if r not in (self.recipients or []):
|
174
|
+
raise PacketRoutingError(
|
175
|
+
proc_name=self.name,
|
176
|
+
call_id=call_id,
|
177
|
+
selected_recipient=r,
|
178
|
+
allowed_recipients=cast("list[str]", self.recipients),
|
179
|
+
)
|
180
|
+
|
181
|
+
def _validate_par_recipients(
|
182
|
+
self, out_packets: Sequence[Packet[OutT]], call_id: str
|
183
|
+
) -> None:
|
184
|
+
recipient_sets = [set(p.recipients or []) for p in out_packets]
|
185
|
+
same_recipients = all(rs == recipient_sets[0] for rs in recipient_sets)
|
186
|
+
if not same_recipients:
|
187
|
+
raise PacketRoutingError(
|
188
|
+
proc_name=self.name,
|
189
|
+
call_id=call_id,
|
190
|
+
message="Parallel runs must return the same recipients "
|
191
|
+
f"[proc_name={self.name}; call_id={call_id}]",
|
192
|
+
)
|
193
|
+
|
130
194
|
async def _process(
|
131
195
|
self,
|
132
196
|
chat_inputs: Any | None = None,
|
@@ -135,13 +199,8 @@ class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT_co, MemT, Ct
|
|
135
199
|
memory: MemT,
|
136
200
|
call_id: str,
|
137
201
|
ctx: RunContext[CtxT] | None = None,
|
138
|
-
) ->
|
139
|
-
|
140
|
-
raise ProcInputValidationError(
|
141
|
-
"Default implementation of _process requires in_args"
|
142
|
-
)
|
143
|
-
|
144
|
-
return cast("Sequence[OutT_co]", in_args)
|
202
|
+
) -> OutT:
|
203
|
+
return cast("OutT", in_args)
|
145
204
|
|
146
205
|
async def _process_stream(
|
147
206
|
self,
|
@@ -152,99 +211,86 @@ class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT_co, MemT, Ct
|
|
152
211
|
call_id: str,
|
153
212
|
ctx: RunContext[CtxT] | None = None,
|
154
213
|
) -> AsyncIterator[Event[Any]]:
|
155
|
-
|
156
|
-
|
157
|
-
"Default implementation of _process_stream requires in_args"
|
158
|
-
)
|
159
|
-
outputs = cast("Sequence[OutT_co]", in_args)
|
160
|
-
for out in outputs:
|
161
|
-
yield ProcPayloadOutputEvent(data=out, proc_name=self.name, call_id=call_id)
|
214
|
+
output = cast("OutT", in_args)
|
215
|
+
yield ProcPayloadOutputEvent(data=output, proc_name=self.name, call_id=call_id)
|
162
216
|
|
163
217
|
async def _run_single_once(
|
164
218
|
self,
|
165
219
|
chat_inputs: Any | None = None,
|
166
220
|
*,
|
167
|
-
in_packet: Packet[InT] | None = None,
|
168
221
|
in_args: InT | None = None,
|
169
222
|
forgetful: bool = False,
|
170
223
|
call_id: str,
|
171
224
|
ctx: RunContext[CtxT] | None = None,
|
172
|
-
) -> Packet[OutT_co]:
|
173
|
-
resolved_in_args = self._validate_and_resolve_single_input(
|
174
|
-
chat_inputs=chat_inputs, in_packet=in_packet, in_args=in_args
|
175
|
-
)
|
225
|
+
) -> Packet[OutT]:
|
176
226
|
_memory = self.memory.model_copy(deep=True) if forgetful else self.memory
|
177
227
|
|
178
|
-
|
228
|
+
output = await self._process(
|
179
229
|
chat_inputs=chat_inputs,
|
180
|
-
in_args=
|
230
|
+
in_args=in_args,
|
181
231
|
memory=_memory,
|
182
232
|
call_id=call_id,
|
183
233
|
ctx=ctx,
|
184
234
|
)
|
185
|
-
|
235
|
+
val_output = self._validate_output(output, call_id=call_id)
|
186
236
|
|
187
|
-
|
237
|
+
recipients = self._select_recipients(output=val_output, ctx=ctx)
|
238
|
+
self._validate_recipients(recipients, call_id=call_id)
|
239
|
+
|
240
|
+
return Packet(payloads=[val_output], sender=self.name, recipients=recipients)
|
188
241
|
|
189
242
|
async def _run_single(
|
190
243
|
self,
|
191
244
|
chat_inputs: Any | None = None,
|
192
245
|
*,
|
193
|
-
in_packet: Packet[InT] | None = None,
|
194
246
|
in_args: InT | None = None,
|
195
247
|
forgetful: bool = False,
|
196
248
|
call_id: str,
|
197
249
|
ctx: RunContext[CtxT] | None = None,
|
198
|
-
) -> Packet[OutT_co]:
|
250
|
+
) -> Packet[OutT]:
|
199
251
|
n_attempt = 0
|
200
252
|
while n_attempt <= self.max_retries:
|
201
253
|
try:
|
202
254
|
return await self._run_single_once(
|
203
255
|
chat_inputs=chat_inputs,
|
204
|
-
in_packet=in_packet,
|
205
256
|
in_args=in_args,
|
206
257
|
forgetful=forgetful,
|
207
258
|
call_id=call_id,
|
208
259
|
ctx=ctx,
|
209
260
|
)
|
210
261
|
except Exception as err:
|
262
|
+
err_message = (
|
263
|
+
f"\nProcessor run failed [proc_name={self.name}; call_id={call_id}]"
|
264
|
+
)
|
211
265
|
n_attempt += 1
|
212
266
|
if n_attempt > self.max_retries:
|
213
267
|
if n_attempt == 1:
|
214
|
-
logger.warning(f"
|
268
|
+
logger.warning(f"{err_message}:\n{err}")
|
215
269
|
if n_attempt > 1:
|
216
|
-
logger.warning(f"
|
217
|
-
|
218
|
-
|
219
|
-
|
220
|
-
|
270
|
+
logger.warning(f"{err_message} after retrying:\n{err}")
|
271
|
+
raise ProcRunError(proc_name=self.name, call_id=call_id) from err
|
272
|
+
|
273
|
+
logger.warning(f"{err_message} (retry attempt {n_attempt}):\n{err}")
|
274
|
+
|
275
|
+
raise ProcRunError(proc_name=self.name, call_id=call_id)
|
221
276
|
|
222
277
|
async def _run_par(
|
223
|
-
self,
|
224
|
-
|
225
|
-
*,
|
226
|
-
in_packet: Packet[InT] | None = None,
|
227
|
-
in_args: Sequence[InT] | None = None,
|
228
|
-
call_id: str,
|
229
|
-
ctx: RunContext[CtxT] | None = None,
|
230
|
-
) -> Packet[OutT_co]:
|
231
|
-
par_inputs = self._validate_and_resolve_parallel_inputs(
|
232
|
-
chat_inputs=chat_inputs, in_packet=in_packet, in_args=in_args
|
233
|
-
)
|
278
|
+
self, in_args: Sequence[InT], call_id: str, ctx: RunContext[CtxT] | None = None
|
279
|
+
) -> Packet[OutT]:
|
234
280
|
tasks = [
|
235
281
|
self._run_single(
|
236
282
|
in_args=inp, forgetful=True, call_id=f"{call_id}/{idx}", ctx=ctx
|
237
283
|
)
|
238
|
-
for idx, inp in enumerate(
|
284
|
+
for idx, inp in enumerate(in_args)
|
239
285
|
]
|
240
286
|
out_packets = await asyncio.gather(*tasks)
|
241
287
|
|
242
|
-
|
243
|
-
|
244
|
-
|
245
|
-
|
246
|
-
],
|
288
|
+
self._validate_par_recipients(out_packets, call_id=call_id)
|
289
|
+
|
290
|
+
return Packet(
|
291
|
+
payloads=[out_packet.payloads[0] for out_packet in out_packets],
|
247
292
|
sender=self.name,
|
293
|
+
recipients=out_packets[0].recipients,
|
248
294
|
)
|
249
295
|
|
250
296
|
async def run(
|
@@ -256,23 +302,21 @@ class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT_co, MemT, Ct
|
|
256
302
|
forgetful: bool = False,
|
257
303
|
call_id: str | None = None,
|
258
304
|
ctx: RunContext[CtxT] | None = None,
|
259
|
-
) -> Packet[OutT_co]:
|
305
|
+
) -> Packet[OutT]:
|
260
306
|
call_id = self._generate_call_id(call_id)
|
261
307
|
|
262
|
-
|
263
|
-
|
264
|
-
):
|
265
|
-
return await self._run_par(
|
266
|
-
chat_inputs=chat_inputs,
|
267
|
-
in_packet=in_packet,
|
268
|
-
in_args=cast("Sequence[InT] | None", in_args),
|
269
|
-
call_id=call_id,
|
270
|
-
ctx=ctx,
|
271
|
-
)
|
272
|
-
return await self._run_single( # type: ignore[return]
|
308
|
+
val_in_args = self._validate_inputs(
|
309
|
+
call_id=call_id,
|
273
310
|
chat_inputs=chat_inputs,
|
274
311
|
in_packet=in_packet,
|
275
|
-
in_args=
|
312
|
+
in_args=in_args,
|
313
|
+
)
|
314
|
+
|
315
|
+
if val_in_args and len(val_in_args) > 1:
|
316
|
+
return await self._run_par(in_args=val_in_args, call_id=call_id, ctx=ctx)
|
317
|
+
return await self._run_single(
|
318
|
+
chat_inputs=chat_inputs,
|
319
|
+
in_args=val_in_args[0] if val_in_args else None,
|
276
320
|
forgetful=forgetful,
|
277
321
|
call_id=call_id,
|
278
322
|
ctx=ctx,
|
@@ -282,31 +326,35 @@ class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT_co, MemT, Ct
|
|
282
326
|
self,
|
283
327
|
chat_inputs: Any | None = None,
|
284
328
|
*,
|
285
|
-
in_packet: Packet[InT] | None = None,
|
286
329
|
in_args: InT | None = None,
|
287
330
|
forgetful: bool = False,
|
288
331
|
call_id: str,
|
289
332
|
ctx: RunContext[CtxT] | None = None,
|
290
333
|
) -> AsyncIterator[Event[Any]]:
|
291
|
-
resolved_in_args = self._validate_and_resolve_single_input(
|
292
|
-
chat_inputs=chat_inputs, in_packet=in_packet, in_args=in_args
|
293
|
-
)
|
294
334
|
_memory = self.memory.model_copy(deep=True) if forgetful else self.memory
|
295
335
|
|
296
|
-
|
336
|
+
output: OutT | None = None
|
297
337
|
async for event in self._process_stream(
|
298
338
|
chat_inputs=chat_inputs,
|
299
|
-
in_args=
|
339
|
+
in_args=in_args,
|
300
340
|
memory=_memory,
|
301
341
|
call_id=call_id,
|
302
342
|
ctx=ctx,
|
303
343
|
):
|
304
344
|
if isinstance(event, ProcPayloadOutputEvent):
|
305
|
-
|
345
|
+
output = event.data
|
306
346
|
yield event
|
307
347
|
|
308
|
-
|
309
|
-
|
348
|
+
assert output is not None
|
349
|
+
|
350
|
+
val_output = self._validate_output(output, call_id=call_id)
|
351
|
+
|
352
|
+
recipients = self._select_recipients(output=val_output, ctx=ctx)
|
353
|
+
self._validate_recipients(recipients, call_id=call_id)
|
354
|
+
|
355
|
+
out_packet = Packet[OutT](
|
356
|
+
payloads=[val_output], sender=self.name, recipients=recipients
|
357
|
+
)
|
310
358
|
|
311
359
|
yield ProcPacketOutputEvent(
|
312
360
|
data=out_packet, proc_name=self.name, call_id=call_id
|
@@ -316,7 +364,6 @@ class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT_co, MemT, Ct
|
|
316
364
|
self,
|
317
365
|
chat_inputs: Any | None = None,
|
318
366
|
*,
|
319
|
-
in_packet: Packet[InT] | None = None,
|
320
367
|
in_args: InT | None = None,
|
321
368
|
forgetful: bool = False,
|
322
369
|
call_id: str,
|
@@ -327,7 +374,6 @@ class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT_co, MemT, Ct
|
|
327
374
|
try:
|
328
375
|
async for event in self._run_single_stream_once(
|
329
376
|
chat_inputs=chat_inputs,
|
330
|
-
in_packet=in_packet,
|
331
377
|
in_args=in_args,
|
332
378
|
forgetful=forgetful,
|
333
379
|
call_id=call_id,
|
@@ -343,56 +389,48 @@ class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT_co, MemT, Ct
|
|
343
389
|
data=err_data, proc_name=self.name, call_id=call_id
|
344
390
|
)
|
345
391
|
|
392
|
+
err_message = (
|
393
|
+
"\nStreaming processor run failed "
|
394
|
+
f"[proc_name={self.name}; call_id={call_id}]"
|
395
|
+
)
|
396
|
+
|
346
397
|
n_attempt += 1
|
347
398
|
if n_attempt > self.max_retries:
|
348
399
|
if n_attempt == 1:
|
349
|
-
logger.warning(f"
|
400
|
+
logger.warning(f"{err_message}:\n{err}")
|
350
401
|
if n_attempt > 1:
|
351
|
-
logger.warning(
|
352
|
-
|
353
|
-
)
|
354
|
-
return
|
402
|
+
logger.warning(f"{err_message} after retrying:\n{err}")
|
403
|
+
raise ProcRunError(proc_name=self.name, call_id=call_id) from err
|
355
404
|
|
356
|
-
logger.warning(
|
357
|
-
"\nStreaming processor run failed "
|
358
|
-
f"(retry attempt {n_attempt}):\n{err}"
|
359
|
-
)
|
405
|
+
logger.warning(f"{err_message} (retry attempt {n_attempt}):\n{err}")
|
360
406
|
|
361
407
|
async def _run_par_stream(
|
362
408
|
self,
|
363
|
-
|
364
|
-
*,
|
365
|
-
in_packet: Packet[InT] | None = None,
|
366
|
-
in_args: Sequence[InT] | None = None,
|
409
|
+
in_args: Sequence[InT],
|
367
410
|
call_id: str,
|
368
411
|
ctx: RunContext[CtxT] | None = None,
|
369
412
|
) -> AsyncIterator[Event[Any]]:
|
370
|
-
par_inputs = self._validate_and_resolve_parallel_inputs(
|
371
|
-
chat_inputs=chat_inputs, in_packet=in_packet, in_args=in_args
|
372
|
-
)
|
373
413
|
streams = [
|
374
414
|
self._run_single_stream(
|
375
415
|
in_args=inp, forgetful=True, call_id=f"{call_id}/{idx}", ctx=ctx
|
376
416
|
)
|
377
|
-
for idx, inp in enumerate(
|
417
|
+
for idx, inp in enumerate(in_args)
|
378
418
|
]
|
379
419
|
|
380
|
-
out_packets_map: dict[int, Packet[
|
381
|
-
range(len(streams)), None
|
382
|
-
)
|
383
|
-
|
420
|
+
out_packets_map: dict[int, Packet[OutT]] = {}
|
384
421
|
async for idx, event in stream_concurrent(streams):
|
385
422
|
if isinstance(event, ProcPacketOutputEvent):
|
386
423
|
out_packets_map[idx] = event.data
|
387
424
|
else:
|
388
425
|
yield event
|
389
426
|
|
390
|
-
out_packet = Packet(
|
427
|
+
out_packet = Packet(
|
391
428
|
payloads=[
|
392
|
-
|
393
|
-
for out_packet in out_packets_map.
|
429
|
+
out_packet.payloads[0]
|
430
|
+
for _, out_packet in sorted(out_packets_map.items())
|
394
431
|
],
|
395
432
|
sender=self.name,
|
433
|
+
recipients=out_packets_map[0].recipients,
|
396
434
|
)
|
397
435
|
|
398
436
|
yield ProcPacketOutputEvent(
|
@@ -411,23 +449,19 @@ class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT_co, MemT, Ct
|
|
411
449
|
) -> AsyncIterator[Event[Any]]:
|
412
450
|
call_id = self._generate_call_id(call_id)
|
413
451
|
|
414
|
-
|
452
|
+
val_in_args = self._validate_inputs(
|
453
|
+
call_id=call_id,
|
454
|
+
chat_inputs=chat_inputs,
|
455
|
+
in_packet=in_packet,
|
456
|
+
in_args=in_args,
|
457
|
+
)
|
415
458
|
|
416
|
-
if
|
417
|
-
|
418
|
-
):
|
419
|
-
stream = self._run_par_stream(
|
420
|
-
chat_inputs=chat_inputs,
|
421
|
-
in_packet=in_packet,
|
422
|
-
in_args=cast("Sequence[InT] | None", in_args),
|
423
|
-
call_id=call_id,
|
424
|
-
ctx=ctx,
|
425
|
-
)
|
459
|
+
if val_in_args and len(val_in_args) > 1:
|
460
|
+
stream = self._run_par_stream(in_args=val_in_args, call_id=call_id, ctx=ctx)
|
426
461
|
else:
|
427
462
|
stream = self._run_single_stream(
|
428
463
|
chat_inputs=chat_inputs,
|
429
|
-
|
430
|
-
in_args=cast("InT | None", in_args),
|
464
|
+
in_args=val_in_args[0] if val_in_args else None,
|
431
465
|
forgetful=forgetful,
|
432
466
|
call_id=call_id,
|
433
467
|
ctx=ctx,
|
@@ -435,12 +469,25 @@ class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT_co, MemT, Ct
|
|
435
469
|
async for event in stream:
|
436
470
|
yield event
|
437
471
|
|
438
|
-
|
472
|
+
def _select_recipients(
|
473
|
+
self, output: OutT, ctx: RunContext[CtxT] | None = None
|
474
|
+
) -> list[ProcName] | None:
|
475
|
+
if self.select_recipients_impl:
|
476
|
+
return self.select_recipients_impl(output=output, ctx=ctx)
|
477
|
+
|
478
|
+
return self.recipients
|
479
|
+
|
480
|
+
def select_recipients(
|
481
|
+
self, func: SelectRecipientsHandler[OutT, CtxT]
|
482
|
+
) -> SelectRecipientsHandler[OutT, CtxT]:
|
483
|
+
self.select_recipients_impl = func
|
484
|
+
|
485
|
+
return func
|
439
486
|
|
440
487
|
@final
|
441
488
|
def as_tool(
|
442
489
|
self, tool_name: str, tool_description: str
|
443
|
-
) -> BaseTool[InT,
|
490
|
+
) -> BaseTool[InT, OutT, Any]: # type: ignore[override]
|
444
491
|
# TODO: stream tools
|
445
492
|
processor_instance = self
|
446
493
|
in_type = processor_instance.in_type
|
@@ -455,13 +502,11 @@ class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT_co, MemT, Ct
|
|
455
502
|
name: str = tool_name
|
456
503
|
description: str = tool_description
|
457
504
|
|
458
|
-
async def run(
|
459
|
-
self, inp: InT, ctx: RunContext[CtxT] | None = None
|
460
|
-
) -> OutT_co:
|
505
|
+
async def run(self, inp: InT, ctx: RunContext[CtxT] | None = None) -> OutT:
|
461
506
|
result = await processor_instance.run(
|
462
507
|
in_args=in_type.model_validate(inp), forgetful=True, ctx=ctx
|
463
508
|
)
|
464
509
|
|
465
510
|
return result.payloads[0]
|
466
511
|
|
467
|
-
return ProcessorTool()
|
512
|
+
return ProcessorTool()
|