grasp_agents 0.5.5__py3-none-any.whl → 0.5.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- grasp_agents/__init__.py +5 -1
- grasp_agents/llm.py +5 -1
- grasp_agents/llm_agent.py +18 -7
- grasp_agents/packet_pool.py +6 -1
- grasp_agents/printer.py +7 -4
- grasp_agents/{processor.py → processors/base_processor.py} +89 -287
- grasp_agents/processors/parallel_processor.py +244 -0
- grasp_agents/processors/processor.py +161 -0
- grasp_agents/runner.py +20 -1
- grasp_agents/typing/events.py +4 -0
- grasp_agents/workflow/looped_workflow.py +35 -27
- grasp_agents/workflow/sequential_workflow.py +14 -3
- grasp_agents/workflow/workflow_processor.py +21 -15
- {grasp_agents-0.5.5.dist-info → grasp_agents-0.5.6.dist-info}/METADATA +1 -1
- {grasp_agents-0.5.5.dist-info → grasp_agents-0.5.6.dist-info}/RECORD +17 -15
- {grasp_agents-0.5.5.dist-info → grasp_agents-0.5.6.dist-info}/WHEEL +0 -0
- {grasp_agents-0.5.5.dist-info → grasp_agents-0.5.6.dist-info}/licenses/LICENSE.md +0 -0
grasp_agents/__init__.py
CHANGED
@@ -6,7 +6,9 @@ from .llm_agent import LLMAgent
|
|
6
6
|
from .llm_agent_memory import LLMAgentMemory
|
7
7
|
from .memory import Memory
|
8
8
|
from .packet import Packet
|
9
|
-
from .
|
9
|
+
from .processors.base_processor import BaseProcessor
|
10
|
+
from .processors.parallel_processor import ParallelProcessor
|
11
|
+
from .processors.processor import Processor
|
10
12
|
from .run_context import RunContext
|
11
13
|
from .typing.completion import Completion
|
12
14
|
from .typing.content import Content, ImageData
|
@@ -17,6 +19,7 @@ from .typing.tool import BaseTool
|
|
17
19
|
__all__ = [
|
18
20
|
"LLM",
|
19
21
|
"AssistantMessage",
|
22
|
+
"BaseProcessor",
|
20
23
|
"BaseTool",
|
21
24
|
"Completion",
|
22
25
|
"Content",
|
@@ -29,6 +32,7 @@ __all__ = [
|
|
29
32
|
"Messages",
|
30
33
|
"Packet",
|
31
34
|
"Packet",
|
35
|
+
"ParallelProcessor",
|
32
36
|
"ProcName",
|
33
37
|
"Processor",
|
34
38
|
"RunContext",
|
grasp_agents/llm.py
CHANGED
@@ -19,7 +19,11 @@ from .errors import (
|
|
19
19
|
)
|
20
20
|
from .typing.completion import Completion
|
21
21
|
from .typing.converters import Converters
|
22
|
-
from .typing.events import
|
22
|
+
from .typing.events import (
|
23
|
+
CompletionChunkEvent,
|
24
|
+
CompletionEvent,
|
25
|
+
LLMStreamingErrorEvent,
|
26
|
+
)
|
23
27
|
from .typing.message import Messages
|
24
28
|
from .typing.tool import BaseTool, ToolChoice
|
25
29
|
|
grasp_agents/llm_agent.py
CHANGED
@@ -11,7 +11,7 @@ from .llm_policy_executor import (
|
|
11
11
|
MemoryManager,
|
12
12
|
ToolCallLoopTerminator,
|
13
13
|
)
|
14
|
-
from .
|
14
|
+
from .processors.parallel_processor import ParallelProcessor
|
15
15
|
from .prompt_builder import (
|
16
16
|
InputContentBuilder,
|
17
17
|
PromptBuilder,
|
@@ -46,7 +46,7 @@ class OutputParser(Protocol[_InT_contra, _OutT_co, CtxT]):
|
|
46
46
|
|
47
47
|
|
48
48
|
class LLMAgent(
|
49
|
-
|
49
|
+
ParallelProcessor[InT, OutT, LLMAgentMemory, CtxT],
|
50
50
|
Generic[InT, OutT, CtxT],
|
51
51
|
):
|
52
52
|
_generic_arg_to_instance_attr_map: ClassVar[dict[int, str]] = {
|
@@ -196,6 +196,20 @@ class LLMAgent(
|
|
196
196
|
|
197
197
|
return system_message, input_message
|
198
198
|
|
199
|
+
def _parse_output_default(
|
200
|
+
self,
|
201
|
+
conversation: Messages,
|
202
|
+
*,
|
203
|
+
in_args: InT | None = None,
|
204
|
+
ctx: RunContext[CtxT] | None = None,
|
205
|
+
) -> OutT:
|
206
|
+
return validate_obj_from_json_or_py_string(
|
207
|
+
str(conversation[-1].content or ""),
|
208
|
+
schema=self._out_type,
|
209
|
+
from_substring=False,
|
210
|
+
strip_language_markdown=True,
|
211
|
+
)
|
212
|
+
|
199
213
|
def _parse_output(
|
200
214
|
self,
|
201
215
|
conversation: Messages,
|
@@ -208,11 +222,8 @@ class LLMAgent(
|
|
208
222
|
conversation=conversation, in_args=in_args, ctx=ctx
|
209
223
|
)
|
210
224
|
|
211
|
-
return
|
212
|
-
|
213
|
-
schema=self._out_type,
|
214
|
-
from_substring=False,
|
215
|
-
strip_language_markdown=True,
|
225
|
+
return self._parse_output_default(
|
226
|
+
conversation=conversation, in_args=in_args, ctx=ctx
|
216
227
|
)
|
217
228
|
|
218
229
|
async def _process(
|
grasp_agents/packet_pool.py
CHANGED
@@ -68,6 +68,11 @@ class PacketPool(Generic[CtxT]):
|
|
68
68
|
finally:
|
69
69
|
await self.shutdown()
|
70
70
|
|
71
|
+
@property
|
72
|
+
def final_result_ready(self) -> bool:
|
73
|
+
fut = self._final_result_fut
|
74
|
+
return fut is not None and fut.done()
|
75
|
+
|
71
76
|
def register_packet_handler(
|
72
77
|
self,
|
73
78
|
proc_name: ProcName,
|
@@ -121,7 +126,7 @@ class PacketPool(Generic[CtxT]):
|
|
121
126
|
queue = self._packet_queues[proc_name]
|
122
127
|
handler = self._packet_handlers[proc_name]
|
123
128
|
|
124
|
-
while
|
129
|
+
while not self.final_result_ready:
|
125
130
|
packet = await queue.get()
|
126
131
|
if packet is None:
|
127
132
|
break
|
grasp_agents/printer.py
CHANGED
@@ -232,11 +232,14 @@ async def print_event_stream(
|
|
232
232
|
text += f"<{src} output>\n"
|
233
233
|
for p in event.data.payloads:
|
234
234
|
if isinstance(p, BaseModel):
|
235
|
-
for field_info in type(p).model_fields.values():
|
236
|
-
|
237
|
-
|
238
|
-
|
235
|
+
# for field_info in type(p).model_fields.values():
|
236
|
+
# if field_info.exclude:
|
237
|
+
# field_info.exclude = False
|
238
|
+
# break
|
239
|
+
# type(p).model_rebuild(force=True)
|
239
240
|
p_str = p.model_dump_json(indent=2)
|
241
|
+
# field_info.exclude = True # type: ignore
|
242
|
+
# type(p).model_rebuild(force=True)
|
240
243
|
else:
|
241
244
|
try:
|
242
245
|
p_str = json.dumps(p, indent=2)
|
@@ -1,38 +1,106 @@
|
|
1
|
-
import asyncio
|
2
1
|
import logging
|
3
|
-
from abc import ABC
|
4
|
-
from collections.abc import AsyncIterator,
|
5
|
-
from
|
2
|
+
from abc import ABC, abstractmethod
|
3
|
+
from collections.abc import AsyncIterator, Callable, Coroutine
|
4
|
+
from functools import wraps
|
5
|
+
from typing import (
|
6
|
+
Any,
|
7
|
+
ClassVar,
|
8
|
+
Generic,
|
9
|
+
Protocol,
|
10
|
+
TypeVar,
|
11
|
+
cast,
|
12
|
+
final,
|
13
|
+
)
|
6
14
|
from uuid import uuid4
|
7
15
|
|
8
16
|
from pydantic import BaseModel, TypeAdapter
|
9
17
|
from pydantic import ValidationError as PydanticValidationError
|
10
18
|
|
11
|
-
from
|
19
|
+
from ..errors import (
|
12
20
|
PacketRoutingError,
|
13
21
|
ProcInputValidationError,
|
14
22
|
ProcOutputValidationError,
|
15
23
|
ProcRunError,
|
16
24
|
)
|
17
|
-
from
|
18
|
-
from
|
19
|
-
from
|
20
|
-
from
|
21
|
-
from
|
25
|
+
from ..generics_utils import AutoInstanceAttributesMixin
|
26
|
+
from ..memory import DummyMemory, MemT
|
27
|
+
from ..packet import Packet
|
28
|
+
from ..run_context import CtxT, RunContext
|
29
|
+
from ..typing.events import (
|
30
|
+
DummyEvent,
|
22
31
|
Event,
|
23
|
-
ProcPacketOutputEvent,
|
24
|
-
ProcPayloadOutputEvent,
|
25
32
|
ProcStreamingErrorData,
|
26
33
|
ProcStreamingErrorEvent,
|
27
34
|
)
|
28
|
-
from
|
29
|
-
from
|
30
|
-
from .utils import stream_concurrent
|
35
|
+
from ..typing.io import InT, OutT, ProcName
|
36
|
+
from ..typing.tool import BaseTool
|
31
37
|
|
32
38
|
logger = logging.getLogger(__name__)
|
33
39
|
|
34
40
|
_OutT_contra = TypeVar("_OutT_contra", contravariant=True)
|
35
41
|
|
42
|
+
F = TypeVar("F", bound=Callable[..., Coroutine[Any, Any, Packet[Any]]])
|
43
|
+
F_stream = TypeVar("F_stream", bound=Callable[..., AsyncIterator[Event[Any]]])
|
44
|
+
|
45
|
+
|
46
|
+
def with_retry(func: F) -> F:
|
47
|
+
@wraps(func)
|
48
|
+
async def wrapper(
|
49
|
+
self: "BaseProcessor[Any, Any, Any, Any]", *args: Any, **kwargs: Any
|
50
|
+
) -> Packet[Any]:
|
51
|
+
call_id = kwargs.get("call_id", "unknown")
|
52
|
+
for n_attempt in range(self.max_retries + 1):
|
53
|
+
try:
|
54
|
+
return await func(self, *args, **kwargs)
|
55
|
+
except Exception as err:
|
56
|
+
err_message = (
|
57
|
+
f"\nProcessor run failed [proc_name={self.name}; call_id={call_id}]"
|
58
|
+
)
|
59
|
+
if n_attempt == self.max_retries:
|
60
|
+
if self.max_retries == 0:
|
61
|
+
logger.warning(f"{err_message}:\n{err}")
|
62
|
+
else:
|
63
|
+
logger.warning(f"{err_message} after retrying:\n{err}")
|
64
|
+
raise ProcRunError(proc_name=self.name, call_id=call_id) from err
|
65
|
+
|
66
|
+
logger.warning(f"{err_message} (retry attempt {n_attempt + 1}):\n{err}")
|
67
|
+
# This part should not be reachable due to the raise in the loop
|
68
|
+
raise ProcRunError(proc_name=self.name, call_id=call_id)
|
69
|
+
|
70
|
+
return cast("F", wrapper)
|
71
|
+
|
72
|
+
|
73
|
+
def with_retry_stream(func: F_stream) -> F_stream:
|
74
|
+
@wraps(func)
|
75
|
+
async def wrapper(
|
76
|
+
self: "BaseProcessor[Any, Any, Any, Any]", *args: Any, **kwargs: Any
|
77
|
+
) -> AsyncIterator[Event[Any]]:
|
78
|
+
call_id = kwargs.get("call_id", "unknown")
|
79
|
+
for n_attempt in range(self.max_retries + 1):
|
80
|
+
try:
|
81
|
+
async for event in func(self, *args, **kwargs):
|
82
|
+
yield event
|
83
|
+
return
|
84
|
+
except Exception as err:
|
85
|
+
err_data = ProcStreamingErrorData(error=err, call_id=call_id)
|
86
|
+
yield ProcStreamingErrorEvent(
|
87
|
+
data=err_data, proc_name=self.name, call_id=call_id
|
88
|
+
)
|
89
|
+
err_message = (
|
90
|
+
"\nStreaming processor run failed "
|
91
|
+
f"[proc_name={self.name}; call_id={call_id}]"
|
92
|
+
)
|
93
|
+
if n_attempt == self.max_retries:
|
94
|
+
if self.max_retries == 0:
|
95
|
+
logger.warning(f"{err_message}:\n{err}")
|
96
|
+
else:
|
97
|
+
logger.warning(f"{err_message} after retrying:\n{err}")
|
98
|
+
raise ProcRunError(proc_name=self.name, call_id=call_id) from err
|
99
|
+
|
100
|
+
logger.warning(f"{err_message} (retry attempt {n_attempt}):\n{err}")
|
101
|
+
|
102
|
+
return cast("F_stream", wrapper)
|
103
|
+
|
36
104
|
|
37
105
|
class RecipientSelector(Protocol[_OutT_contra, CtxT]):
|
38
106
|
def __call__(
|
@@ -40,7 +108,7 @@ class RecipientSelector(Protocol[_OutT_contra, CtxT]):
|
|
40
108
|
) -> list[ProcName] | None: ...
|
41
109
|
|
42
110
|
|
43
|
-
class
|
111
|
+
class BaseProcessor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT, MemT, CtxT]):
|
44
112
|
_generic_arg_to_instance_attr_map: ClassVar[dict[int, str]] = {
|
45
113
|
0: "_in_type",
|
46
114
|
1: "_out_type",
|
@@ -66,7 +134,6 @@ class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT, MemT, CtxT]
|
|
66
134
|
|
67
135
|
self.recipient_selector: RecipientSelector[OutT, CtxT] | None
|
68
136
|
if not hasattr(type(self), "recipient_selector"):
|
69
|
-
# Set to None if not defined in the subclass
|
70
137
|
self.recipient_selector = None
|
71
138
|
|
72
139
|
@property
|
@@ -183,19 +250,6 @@ class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT, MemT, CtxT]
|
|
183
250
|
allowed_recipients=cast("list[str]", self.recipients),
|
184
251
|
)
|
185
252
|
|
186
|
-
def _validate_par_recipients(
|
187
|
-
self, out_packets: Sequence[Packet[OutT]], call_id: str
|
188
|
-
) -> None:
|
189
|
-
recipient_sets = [set(p.recipients or []) for p in out_packets]
|
190
|
-
same_recipients = all(rs == recipient_sets[0] for rs in recipient_sets)
|
191
|
-
if not same_recipients:
|
192
|
-
raise PacketRoutingError(
|
193
|
-
proc_name=self.name,
|
194
|
-
call_id=call_id,
|
195
|
-
message="Parallel runs must return the same recipients "
|
196
|
-
f"[proc_name={self.name}; call_id={call_id}]",
|
197
|
-
)
|
198
|
-
|
199
253
|
@final
|
200
254
|
def _select_recipients(
|
201
255
|
self, output: OutT, ctx: RunContext[CtxT] | None = None
|
@@ -212,108 +266,7 @@ class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT, MemT, CtxT]
|
|
212
266
|
|
213
267
|
return func
|
214
268
|
|
215
|
-
|
216
|
-
self,
|
217
|
-
chat_inputs: Any | None = None,
|
218
|
-
*,
|
219
|
-
in_args: InT | None = None,
|
220
|
-
memory: MemT,
|
221
|
-
call_id: str,
|
222
|
-
ctx: RunContext[CtxT] | None = None,
|
223
|
-
) -> OutT:
|
224
|
-
return cast("OutT", in_args)
|
225
|
-
|
226
|
-
async def _process_stream(
|
227
|
-
self,
|
228
|
-
chat_inputs: Any | None = None,
|
229
|
-
*,
|
230
|
-
in_args: InT | None = None,
|
231
|
-
memory: MemT,
|
232
|
-
call_id: str,
|
233
|
-
ctx: RunContext[CtxT] | None = None,
|
234
|
-
) -> AsyncIterator[Event[Any]]:
|
235
|
-
output = cast("OutT", in_args)
|
236
|
-
yield ProcPayloadOutputEvent(data=output, proc_name=self.name, call_id=call_id)
|
237
|
-
|
238
|
-
async def _run_single_once(
|
239
|
-
self,
|
240
|
-
chat_inputs: Any | None = None,
|
241
|
-
*,
|
242
|
-
in_args: InT | None = None,
|
243
|
-
forgetful: bool = False,
|
244
|
-
call_id: str,
|
245
|
-
ctx: RunContext[CtxT] | None = None,
|
246
|
-
) -> Packet[OutT]:
|
247
|
-
_memory = self.memory.model_copy(deep=True) if forgetful else self.memory
|
248
|
-
|
249
|
-
output = await self._process(
|
250
|
-
chat_inputs=chat_inputs,
|
251
|
-
in_args=in_args,
|
252
|
-
memory=_memory,
|
253
|
-
call_id=call_id,
|
254
|
-
ctx=ctx,
|
255
|
-
)
|
256
|
-
val_output = self._validate_output(output, call_id=call_id)
|
257
|
-
|
258
|
-
recipients = self._select_recipients(output=val_output, ctx=ctx)
|
259
|
-
self._validate_recipients(recipients, call_id=call_id)
|
260
|
-
|
261
|
-
return Packet(payloads=[val_output], sender=self.name, recipients=recipients)
|
262
|
-
|
263
|
-
async def _run_single(
|
264
|
-
self,
|
265
|
-
chat_inputs: Any | None = None,
|
266
|
-
*,
|
267
|
-
in_args: InT | None = None,
|
268
|
-
forgetful: bool = False,
|
269
|
-
call_id: str,
|
270
|
-
ctx: RunContext[CtxT] | None = None,
|
271
|
-
) -> Packet[OutT]:
|
272
|
-
n_attempt = 0
|
273
|
-
while n_attempt <= self.max_retries:
|
274
|
-
try:
|
275
|
-
return await self._run_single_once(
|
276
|
-
chat_inputs=chat_inputs,
|
277
|
-
in_args=in_args,
|
278
|
-
forgetful=forgetful,
|
279
|
-
call_id=call_id,
|
280
|
-
ctx=ctx,
|
281
|
-
)
|
282
|
-
except Exception as err:
|
283
|
-
err_message = (
|
284
|
-
f"\nProcessor run failed [proc_name={self.name}; call_id={call_id}]"
|
285
|
-
)
|
286
|
-
n_attempt += 1
|
287
|
-
if n_attempt > self.max_retries:
|
288
|
-
if n_attempt == 1:
|
289
|
-
logger.warning(f"{err_message}:\n{err}")
|
290
|
-
if n_attempt > 1:
|
291
|
-
logger.warning(f"{err_message} after retrying:\n{err}")
|
292
|
-
raise ProcRunError(proc_name=self.name, call_id=call_id) from err
|
293
|
-
|
294
|
-
logger.warning(f"{err_message} (retry attempt {n_attempt}):\n{err}")
|
295
|
-
|
296
|
-
raise ProcRunError(proc_name=self.name, call_id=call_id)
|
297
|
-
|
298
|
-
async def _run_par(
|
299
|
-
self, in_args: list[InT], call_id: str, ctx: RunContext[CtxT] | None = None
|
300
|
-
) -> Packet[OutT]:
|
301
|
-
tasks = [
|
302
|
-
self._run_single(
|
303
|
-
in_args=inp, forgetful=True, call_id=f"{call_id}/{idx}", ctx=ctx
|
304
|
-
)
|
305
|
-
for idx, inp in enumerate(in_args)
|
306
|
-
]
|
307
|
-
out_packets = await asyncio.gather(*tasks)
|
308
|
-
|
309
|
-
self._validate_par_recipients(out_packets, call_id=call_id)
|
310
|
-
|
311
|
-
return Packet(
|
312
|
-
payloads=[out_packet.payloads[0] for out_packet in out_packets],
|
313
|
-
sender=self.name,
|
314
|
-
recipients=out_packets[0].recipients,
|
315
|
-
)
|
316
|
-
|
269
|
+
@abstractmethod
|
317
270
|
async def run(
|
318
271
|
self,
|
319
272
|
chat_inputs: Any | None = None,
|
@@ -324,140 +277,9 @@ class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT, MemT, CtxT]
|
|
324
277
|
call_id: str | None = None,
|
325
278
|
ctx: RunContext[CtxT] | None = None,
|
326
279
|
) -> Packet[OutT]:
|
327
|
-
|
328
|
-
|
329
|
-
val_in_args = self._validate_inputs(
|
330
|
-
call_id=call_id,
|
331
|
-
chat_inputs=chat_inputs,
|
332
|
-
in_packet=in_packet,
|
333
|
-
in_args=in_args,
|
334
|
-
)
|
335
|
-
|
336
|
-
if val_in_args and len(val_in_args) > 1:
|
337
|
-
return await self._run_par(in_args=val_in_args, call_id=call_id, ctx=ctx)
|
338
|
-
return await self._run_single(
|
339
|
-
chat_inputs=chat_inputs,
|
340
|
-
in_args=val_in_args[0] if val_in_args else None,
|
341
|
-
forgetful=forgetful,
|
342
|
-
call_id=call_id,
|
343
|
-
ctx=ctx,
|
344
|
-
)
|
345
|
-
|
346
|
-
async def _run_single_stream_once(
|
347
|
-
self,
|
348
|
-
chat_inputs: Any | None = None,
|
349
|
-
*,
|
350
|
-
in_args: InT | None = None,
|
351
|
-
forgetful: bool = False,
|
352
|
-
call_id: str,
|
353
|
-
ctx: RunContext[CtxT] | None = None,
|
354
|
-
) -> AsyncIterator[Event[Any]]:
|
355
|
-
_memory = self.memory.model_copy(deep=True) if forgetful else self.memory
|
356
|
-
|
357
|
-
output: OutT | None = None
|
358
|
-
async for event in self._process_stream(
|
359
|
-
chat_inputs=chat_inputs,
|
360
|
-
in_args=in_args,
|
361
|
-
memory=_memory,
|
362
|
-
call_id=call_id,
|
363
|
-
ctx=ctx,
|
364
|
-
):
|
365
|
-
if isinstance(event, ProcPayloadOutputEvent):
|
366
|
-
output = event.data
|
367
|
-
yield event
|
368
|
-
|
369
|
-
assert output is not None
|
370
|
-
|
371
|
-
val_output = self._validate_output(output, call_id=call_id)
|
372
|
-
|
373
|
-
recipients = self._select_recipients(output=val_output, ctx=ctx)
|
374
|
-
self._validate_recipients(recipients, call_id=call_id)
|
375
|
-
|
376
|
-
out_packet = Packet[OutT](
|
377
|
-
payloads=[val_output], sender=self.name, recipients=recipients
|
378
|
-
)
|
379
|
-
|
380
|
-
yield ProcPacketOutputEvent(
|
381
|
-
data=out_packet, proc_name=self.name, call_id=call_id
|
382
|
-
)
|
383
|
-
|
384
|
-
async def _run_single_stream(
|
385
|
-
self,
|
386
|
-
chat_inputs: Any | None = None,
|
387
|
-
*,
|
388
|
-
in_args: InT | None = None,
|
389
|
-
forgetful: bool = False,
|
390
|
-
call_id: str,
|
391
|
-
ctx: RunContext[CtxT] | None = None,
|
392
|
-
) -> AsyncIterator[Event[Any]]:
|
393
|
-
n_attempt = 0
|
394
|
-
while n_attempt <= self.max_retries:
|
395
|
-
try:
|
396
|
-
async for event in self._run_single_stream_once(
|
397
|
-
chat_inputs=chat_inputs,
|
398
|
-
in_args=in_args,
|
399
|
-
forgetful=forgetful,
|
400
|
-
call_id=call_id,
|
401
|
-
ctx=ctx,
|
402
|
-
):
|
403
|
-
yield event
|
404
|
-
|
405
|
-
return
|
406
|
-
|
407
|
-
except Exception as err:
|
408
|
-
err_data = ProcStreamingErrorData(error=err, call_id=call_id)
|
409
|
-
yield ProcStreamingErrorEvent(
|
410
|
-
data=err_data, proc_name=self.name, call_id=call_id
|
411
|
-
)
|
412
|
-
|
413
|
-
err_message = (
|
414
|
-
"\nStreaming processor run failed "
|
415
|
-
f"[proc_name={self.name}; call_id={call_id}]"
|
416
|
-
)
|
417
|
-
|
418
|
-
n_attempt += 1
|
419
|
-
if n_attempt > self.max_retries:
|
420
|
-
if n_attempt == 1:
|
421
|
-
logger.warning(f"{err_message}:\n{err}")
|
422
|
-
if n_attempt > 1:
|
423
|
-
logger.warning(f"{err_message} after retrying:\n{err}")
|
424
|
-
raise ProcRunError(proc_name=self.name, call_id=call_id) from err
|
425
|
-
|
426
|
-
logger.warning(f"{err_message} (retry attempt {n_attempt}):\n{err}")
|
427
|
-
|
428
|
-
async def _run_par_stream(
|
429
|
-
self,
|
430
|
-
in_args: list[InT],
|
431
|
-
call_id: str,
|
432
|
-
ctx: RunContext[CtxT] | None = None,
|
433
|
-
) -> AsyncIterator[Event[Any]]:
|
434
|
-
streams = [
|
435
|
-
self._run_single_stream(
|
436
|
-
in_args=inp, forgetful=True, call_id=f"{call_id}/{idx}", ctx=ctx
|
437
|
-
)
|
438
|
-
for idx, inp in enumerate(in_args)
|
439
|
-
]
|
440
|
-
|
441
|
-
out_packets_map: dict[int, Packet[OutT]] = {}
|
442
|
-
async for idx, event in stream_concurrent(streams):
|
443
|
-
if isinstance(event, ProcPacketOutputEvent):
|
444
|
-
out_packets_map[idx] = event.data
|
445
|
-
else:
|
446
|
-
yield event
|
447
|
-
|
448
|
-
out_packet = Packet(
|
449
|
-
payloads=[
|
450
|
-
out_packet.payloads[0]
|
451
|
-
for _, out_packet in sorted(out_packets_map.items())
|
452
|
-
],
|
453
|
-
sender=self.name,
|
454
|
-
recipients=out_packets_map[0].recipients,
|
455
|
-
)
|
456
|
-
|
457
|
-
yield ProcPacketOutputEvent(
|
458
|
-
data=out_packet, proc_name=self.name, call_id=call_id
|
459
|
-
)
|
280
|
+
pass
|
460
281
|
|
282
|
+
@abstractmethod
|
461
283
|
async def run_stream(
|
462
284
|
self,
|
463
285
|
chat_inputs: Any | None = None,
|
@@ -468,27 +290,7 @@ class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT, MemT, CtxT]
|
|
468
290
|
call_id: str | None = None,
|
469
291
|
ctx: RunContext[CtxT] | None = None,
|
470
292
|
) -> AsyncIterator[Event[Any]]:
|
471
|
-
|
472
|
-
|
473
|
-
val_in_args = self._validate_inputs(
|
474
|
-
call_id=call_id,
|
475
|
-
chat_inputs=chat_inputs,
|
476
|
-
in_packet=in_packet,
|
477
|
-
in_args=in_args,
|
478
|
-
)
|
479
|
-
|
480
|
-
if val_in_args and len(val_in_args) > 1:
|
481
|
-
stream = self._run_par_stream(in_args=val_in_args, call_id=call_id, ctx=ctx)
|
482
|
-
else:
|
483
|
-
stream = self._run_single_stream(
|
484
|
-
chat_inputs=chat_inputs,
|
485
|
-
in_args=val_in_args[0] if val_in_args else None,
|
486
|
-
forgetful=forgetful,
|
487
|
-
call_id=call_id,
|
488
|
-
ctx=ctx,
|
489
|
-
)
|
490
|
-
async for event in stream:
|
491
|
-
yield event
|
293
|
+
yield DummyEvent()
|
492
294
|
|
493
295
|
@final
|
494
296
|
def as_tool(
|
@@ -510,7 +312,7 @@ class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT, MemT, CtxT]
|
|
510
312
|
|
511
313
|
async def run(self, inp: InT, ctx: RunContext[CtxT] | None = None) -> OutT:
|
512
314
|
result = await processor_instance.run(
|
513
|
-
in_args=
|
315
|
+
in_args=inp, forgetful=True, ctx=ctx
|
514
316
|
)
|
515
317
|
|
516
318
|
return result.payloads[0]
|