grasp_agents 0.5.5__py3-none-any.whl → 0.5.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- grasp_agents/__init__.py +5 -1
- grasp_agents/cloud_llm.py +11 -5
- grasp_agents/llm.py +151 -2
- grasp_agents/llm_agent.py +18 -7
- grasp_agents/llm_policy_executor.py +9 -3
- grasp_agents/packet_pool.py +22 -37
- grasp_agents/printer.py +75 -74
- grasp_agents/{processor.py → processors/base_processor.py} +89 -287
- grasp_agents/processors/parallel_processor.py +246 -0
- grasp_agents/processors/processor.py +161 -0
- grasp_agents/runner.py +46 -24
- grasp_agents/typing/completion_chunk.py +302 -3
- grasp_agents/typing/events.py +259 -49
- grasp_agents/workflow/looped_workflow.py +35 -27
- grasp_agents/workflow/sequential_workflow.py +14 -3
- grasp_agents/workflow/workflow_processor.py +21 -15
- {grasp_agents-0.5.5.dist-info → grasp_agents-0.5.8.dist-info}/METADATA +1 -1
- {grasp_agents-0.5.5.dist-info → grasp_agents-0.5.8.dist-info}/RECORD +20 -18
- {grasp_agents-0.5.5.dist-info → grasp_agents-0.5.8.dist-info}/WHEEL +0 -0
- {grasp_agents-0.5.5.dist-info → grasp_agents-0.5.8.dist-info}/licenses/LICENSE.md +0 -0
grasp_agents/printer.py
CHANGED
@@ -9,14 +9,29 @@ from pydantic import BaseModel
|
|
9
9
|
from termcolor import colored
|
10
10
|
from termcolor._types import Color
|
11
11
|
|
12
|
+
from grasp_agents.typing.completion_chunk import CompletionChunk
|
12
13
|
from grasp_agents.typing.events import (
|
14
|
+
AnnotationsChunkEvent,
|
15
|
+
AnnotationsEndEvent,
|
16
|
+
AnnotationsStartEvent,
|
13
17
|
CompletionChunkEvent,
|
18
|
+
# CompletionEndEvent,
|
19
|
+
CompletionStartEvent,
|
14
20
|
Event,
|
15
21
|
GenMessageEvent,
|
16
22
|
MessageEvent,
|
17
23
|
ProcPacketOutputEvent,
|
24
|
+
ResponseChunkEvent,
|
25
|
+
ResponseEndEvent,
|
26
|
+
ResponseStartEvent,
|
18
27
|
RunResultEvent,
|
19
28
|
SystemMessageEvent,
|
29
|
+
ThinkingChunkEvent,
|
30
|
+
ThinkingEndEvent,
|
31
|
+
ThinkingStartEvent,
|
32
|
+
ToolCallChunkEvent,
|
33
|
+
ToolCallEndEvent,
|
34
|
+
ToolCallStartEvent,
|
20
35
|
ToolMessageEvent,
|
21
36
|
UserMessageEvent,
|
22
37
|
WorkflowResultEvent,
|
@@ -24,7 +39,14 @@ from grasp_agents.typing.events import (
|
|
24
39
|
|
25
40
|
from .typing.completion import Usage
|
26
41
|
from .typing.content import Content, ContentPartText
|
27
|
-
from .typing.message import
|
42
|
+
from .typing.message import (
|
43
|
+
AssistantMessage,
|
44
|
+
Message,
|
45
|
+
Role,
|
46
|
+
SystemMessage,
|
47
|
+
ToolMessage,
|
48
|
+
UserMessage,
|
49
|
+
)
|
28
50
|
|
29
51
|
logger = logging.getLogger(__name__)
|
30
52
|
|
@@ -72,7 +94,7 @@ class Printer:
|
|
72
94
|
return AVAILABLE_COLORS[idx]
|
73
95
|
|
74
96
|
@staticmethod
|
75
|
-
def content_to_str(content: Content | str, role: Role) -> str:
|
97
|
+
def content_to_str(content: Content | str | None, role: Role) -> str:
|
76
98
|
if role == Role.USER and isinstance(content, Content):
|
77
99
|
content_str_parts: list[str] = []
|
78
100
|
for content_part in content.parts:
|
@@ -84,9 +106,9 @@ class Printer:
|
|
84
106
|
content_str_parts.append("<ENCODED_IMAGE>")
|
85
107
|
return "\n".join(content_str_parts)
|
86
108
|
|
87
|
-
assert isinstance(content, str)
|
109
|
+
assert isinstance(content, str | None)
|
88
110
|
|
89
|
-
return content.strip(" \n")
|
111
|
+
return (content or "").strip(" \n")
|
90
112
|
|
91
113
|
@staticmethod
|
92
114
|
def truncate_content_str(content_str: str, trunc_len: int = 2000) -> str:
|
@@ -179,36 +201,10 @@ def stream_text(new_text: str, color: Color) -> None:
|
|
179
201
|
async def print_event_stream(
|
180
202
|
event_generator: AsyncIterator[Event[Any]],
|
181
203
|
color_by: ColoringMode = "role",
|
182
|
-
trunc_len: int =
|
204
|
+
trunc_len: int = 10000,
|
183
205
|
) -> AsyncIterator[Event[Any]]:
|
184
|
-
prev_chunk_id: str | None = None
|
185
|
-
thinking_open = False
|
186
|
-
response_open = False
|
187
|
-
open_tool_calls: set[str] = set()
|
188
|
-
|
189
206
|
color = Printer.get_role_color(Role.ASSISTANT)
|
190
207
|
|
191
|
-
def _close_blocks(
|
192
|
-
_thinking_open: bool, _response_open: bool, color: Color
|
193
|
-
) -> tuple[bool, bool]:
|
194
|
-
closing_text = ""
|
195
|
-
while open_tool_calls:
|
196
|
-
open_tool_calls.pop()
|
197
|
-
closing_text += "\n</tool call>\n"
|
198
|
-
|
199
|
-
if _thinking_open:
|
200
|
-
closing_text += "\n</thinking>\n"
|
201
|
-
_thinking_open = False
|
202
|
-
|
203
|
-
if _response_open:
|
204
|
-
closing_text += "\n</response>\n"
|
205
|
-
_response_open = False
|
206
|
-
|
207
|
-
if closing_text:
|
208
|
-
stream_text(closing_text, color)
|
209
|
-
|
210
|
-
return _thinking_open, _response_open
|
211
|
-
|
212
208
|
def _get_color(event: Event[Any], role: Role = Role.ASSISTANT) -> Color:
|
213
209
|
if color_by == "agent":
|
214
210
|
return Printer.get_agent_color(event.proc_name or "")
|
@@ -232,10 +228,6 @@ async def print_event_stream(
|
|
232
228
|
text += f"<{src} output>\n"
|
233
229
|
for p in event.data.payloads:
|
234
230
|
if isinstance(p, BaseModel):
|
235
|
-
for field_info in type(p).model_fields.values():
|
236
|
-
if field_info.exclude:
|
237
|
-
field_info.exclude = False
|
238
|
-
type(p).model_rebuild(force=True)
|
239
231
|
p_str = p.model_dump_json(indent=2)
|
240
232
|
else:
|
241
233
|
try:
|
@@ -250,56 +242,64 @@ async def print_event_stream(
|
|
250
242
|
async for event in event_generator:
|
251
243
|
yield event
|
252
244
|
|
253
|
-
if isinstance(event, CompletionChunkEvent)
|
254
|
-
|
255
|
-
|
256
|
-
new_completion = chunk_id != prev_chunk_id
|
245
|
+
if isinstance(event, CompletionChunkEvent) and isinstance(
|
246
|
+
event.data, CompletionChunk
|
247
|
+
):
|
257
248
|
color = _get_color(event, Role.ASSISTANT)
|
258
249
|
|
259
250
|
text = ""
|
260
251
|
|
261
|
-
if
|
262
|
-
thinking_open, response_open = _close_blocks(
|
263
|
-
thinking_open, response_open, color
|
264
|
-
)
|
252
|
+
if isinstance(event, CompletionStartEvent):
|
265
253
|
text += f"\n<{event.proc_name}> [{event.call_id}]\n"
|
266
|
-
|
267
|
-
|
268
|
-
|
269
|
-
|
270
|
-
|
271
|
-
|
272
|
-
|
254
|
+
elif isinstance(event, ThinkingStartEvent):
|
255
|
+
text += "<thinking>\n"
|
256
|
+
elif isinstance(event, ResponseStartEvent):
|
257
|
+
text += "<response>\n"
|
258
|
+
elif isinstance(event, ToolCallStartEvent):
|
259
|
+
tc = event.data.tool_call
|
260
|
+
text += f"<tool call> {tc.tool_name} [{tc.id}]\n"
|
261
|
+
elif isinstance(event, AnnotationsStartEvent):
|
262
|
+
text += "<annotations>\n"
|
263
|
+
|
264
|
+
# if isinstance(event, CompletionEndEvent):
|
265
|
+
# text += f"\n</{event.proc_name}>\n"
|
266
|
+
if isinstance(event, ThinkingEndEvent):
|
273
267
|
text += "\n</thinking>\n"
|
274
|
-
|
275
|
-
|
276
|
-
if delta.content:
|
277
|
-
if not response_open:
|
278
|
-
text += "<response>\n"
|
279
|
-
response_open = True
|
280
|
-
text += delta.content
|
281
|
-
elif response_open:
|
268
|
+
elif isinstance(event, ResponseEndEvent):
|
282
269
|
text += "\n</response>\n"
|
283
|
-
|
284
|
-
|
285
|
-
|
286
|
-
|
287
|
-
|
288
|
-
|
289
|
-
|
290
|
-
|
291
|
-
|
292
|
-
|
270
|
+
elif isinstance(event, ToolCallEndEvent):
|
271
|
+
text += "\n</tool call>\n"
|
272
|
+
elif isinstance(event, AnnotationsEndEvent):
|
273
|
+
text += "\n</annotations>\n"
|
274
|
+
|
275
|
+
if isinstance(event, ThinkingChunkEvent):
|
276
|
+
thinking = event.data.thinking
|
277
|
+
if isinstance(thinking, str):
|
278
|
+
text += thinking
|
279
|
+
else:
|
280
|
+
text = "\n".join(
|
281
|
+
[block.get("thinking", "[redacted]") for block in thinking]
|
282
|
+
)
|
283
|
+
|
284
|
+
if isinstance(event, ResponseChunkEvent):
|
285
|
+
text += event.data.response
|
286
|
+
|
287
|
+
if isinstance(event, ToolCallChunkEvent):
|
288
|
+
text += event.data.tool_call.tool_arguments or ""
|
289
|
+
|
290
|
+
if isinstance(event, AnnotationsChunkEvent):
|
291
|
+
text += "\n".join(
|
292
|
+
[
|
293
|
+
json.dumps(annotation, indent=2)
|
294
|
+
for annotation in event.data.annotations
|
295
|
+
]
|
296
|
+
)
|
293
297
|
|
294
298
|
stream_text(text, color)
|
295
|
-
prev_chunk_id = chunk_id
|
296
|
-
|
297
|
-
else:
|
298
|
-
thinking_open, response_open = _close_blocks(
|
299
|
-
thinking_open, response_open, color
|
300
|
-
)
|
301
299
|
|
302
300
|
if isinstance(event, MessageEvent) and not isinstance(event, GenMessageEvent):
|
301
|
+
assert isinstance(event.data, (SystemMessage | UserMessage | ToolMessage))
|
302
|
+
|
303
303
|
message = event.data
|
304
304
|
role = message.role
|
305
305
|
content = Printer.content_to_str(message.content, role=role)
|
@@ -317,6 +317,7 @@ async def print_event_stream(
|
|
317
317
|
text += f"<input>\n{content}\n</input>\n"
|
318
318
|
|
319
319
|
elif isinstance(event, ToolMessageEvent):
|
320
|
+
message = event.data
|
320
321
|
try:
|
321
322
|
content = json.dumps(json.loads(content), indent=2)
|
322
323
|
except Exception:
|
@@ -1,38 +1,106 @@
|
|
1
|
-
import asyncio
|
2
1
|
import logging
|
3
|
-
from abc import ABC
|
4
|
-
from collections.abc import AsyncIterator,
|
5
|
-
from
|
2
|
+
from abc import ABC, abstractmethod
|
3
|
+
from collections.abc import AsyncIterator, Callable, Coroutine
|
4
|
+
from functools import wraps
|
5
|
+
from typing import (
|
6
|
+
Any,
|
7
|
+
ClassVar,
|
8
|
+
Generic,
|
9
|
+
Protocol,
|
10
|
+
TypeVar,
|
11
|
+
cast,
|
12
|
+
final,
|
13
|
+
)
|
6
14
|
from uuid import uuid4
|
7
15
|
|
8
16
|
from pydantic import BaseModel, TypeAdapter
|
9
17
|
from pydantic import ValidationError as PydanticValidationError
|
10
18
|
|
11
|
-
from
|
19
|
+
from ..errors import (
|
12
20
|
PacketRoutingError,
|
13
21
|
ProcInputValidationError,
|
14
22
|
ProcOutputValidationError,
|
15
23
|
ProcRunError,
|
16
24
|
)
|
17
|
-
from
|
18
|
-
from
|
19
|
-
from
|
20
|
-
from
|
21
|
-
from
|
25
|
+
from ..generics_utils import AutoInstanceAttributesMixin
|
26
|
+
from ..memory import DummyMemory, MemT
|
27
|
+
from ..packet import Packet
|
28
|
+
from ..run_context import CtxT, RunContext
|
29
|
+
from ..typing.events import (
|
30
|
+
DummyEvent,
|
22
31
|
Event,
|
23
|
-
ProcPacketOutputEvent,
|
24
|
-
ProcPayloadOutputEvent,
|
25
32
|
ProcStreamingErrorData,
|
26
33
|
ProcStreamingErrorEvent,
|
27
34
|
)
|
28
|
-
from
|
29
|
-
from
|
30
|
-
from .utils import stream_concurrent
|
35
|
+
from ..typing.io import InT, OutT, ProcName
|
36
|
+
from ..typing.tool import BaseTool
|
31
37
|
|
32
38
|
logger = logging.getLogger(__name__)
|
33
39
|
|
34
40
|
_OutT_contra = TypeVar("_OutT_contra", contravariant=True)
|
35
41
|
|
42
|
+
F = TypeVar("F", bound=Callable[..., Coroutine[Any, Any, Packet[Any]]])
|
43
|
+
F_stream = TypeVar("F_stream", bound=Callable[..., AsyncIterator[Event[Any]]])
|
44
|
+
|
45
|
+
|
46
|
+
def with_retry(func: F) -> F:
    """Wrap an async processor method so that failing runs are retried.

    The wrapped coroutine is awaited at most ``self.max_retries + 1`` times.
    Each failed attempt is logged as a warning. When the final attempt also
    fails, a ``ProcRunError`` is raised, chained to the last exception.

    The ``call_id`` shown in log messages and attached to the raised error is
    read from keyword arguments only; it is ``"unknown"`` when not passed by
    keyword.
    """

    @wraps(func)
    async def wrapper(
        self: "BaseProcessor[Any, Any, Any, Any]", *args: Any, **kwargs: Any
    ) -> Packet[Any]:
        call_id = kwargs.get("call_id", "unknown")
        last_attempt = self.max_retries
        attempt = 0
        while attempt <= last_attempt:
            try:
                return await func(self, *args, **kwargs)
            except Exception as err:
                err_message = (
                    f"\nProcessor run failed [proc_name={self.name}; call_id={call_id}]"
                )
                if attempt < last_attempt:
                    # More attempts remain: log with a 1-based retry number and try again.
                    logger.warning(f"{err_message} (retry attempt {attempt + 1}):\n{err}")
                    attempt += 1
                    continue
                # Final attempt failed: wording depends on whether retries were enabled.
                suffix = ":" if self.max_retries == 0 else " after retrying:"
                logger.warning(f"{err_message}{suffix}\n{err}")
                raise ProcRunError(proc_name=self.name, call_id=call_id) from err
        # Reached only when the loop body never runs (negative max_retries);
        # the last failing attempt always raises inside the loop.
        raise ProcRunError(proc_name=self.name, call_id=call_id)

    return cast("F", wrapper)
|
71
|
+
|
72
|
+
|
73
|
+
def with_retry_stream(func: F_stream) -> F_stream:
    """Wrap an async-generator processor method so that failing runs are retried.

    The wrapped generator is re-run at most ``self.max_retries + 1`` times.
    On each failure a ``ProcStreamingErrorEvent`` is yielded to the consumer
    and a warning is logged. When the final attempt also fails, a
    ``ProcRunError`` is raised, chained to the last exception.

    Events produced by a failed attempt have already been yielded to the
    consumer before the retry starts, so a retried attempt may repeat them.

    The ``call_id`` shown in log messages, in error events and on the raised
    error is read from keyword arguments only; it is ``"unknown"`` when not
    passed by keyword.
    """

    @wraps(func)
    async def wrapper(
        self: "BaseProcessor[Any, Any, Any, Any]", *args: Any, **kwargs: Any
    ) -> AsyncIterator[Event[Any]]:
        call_id = kwargs.get("call_id", "unknown")
        for n_attempt in range(self.max_retries + 1):
            try:
                async for event in func(self, *args, **kwargs):
                    yield event
                return
            except Exception as err:
                # Surface the failure in the event stream before deciding
                # whether to retry.
                err_data = ProcStreamingErrorData(error=err, call_id=call_id)
                yield ProcStreamingErrorEvent(
                    data=err_data, proc_name=self.name, call_id=call_id
                )
                err_message = (
                    "\nStreaming processor run failed "
                    f"[proc_name={self.name}; call_id={call_id}]"
                )
                if n_attempt == self.max_retries:
                    if self.max_retries == 0:
                        logger.warning(f"{err_message}:\n{err}")
                    else:
                        logger.warning(f"{err_message} after retrying:\n{err}")
                    raise ProcRunError(proc_name=self.name, call_id=call_id) from err

                # Report 1-based retry numbers, consistent with ``with_retry``
                # (previously the first retry was logged as "retry attempt 0").
                logger.warning(
                    f"{err_message} (retry attempt {n_attempt + 1}):\n{err}"
                )
        # Reached only when the loop body never runs (negative max_retries).
        # Fail loudly, as ``with_retry`` does, instead of silently yielding
        # an empty stream.
        raise ProcRunError(proc_name=self.name, call_id=call_id)

    return cast("F_stream", wrapper)
|
103
|
+
|
36
104
|
|
37
105
|
class RecipientSelector(Protocol[_OutT_contra, CtxT]):
|
38
106
|
def __call__(
|
@@ -40,7 +108,7 @@ class RecipientSelector(Protocol[_OutT_contra, CtxT]):
|
|
40
108
|
) -> list[ProcName] | None: ...
|
41
109
|
|
42
110
|
|
43
|
-
class
|
111
|
+
class BaseProcessor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT, MemT, CtxT]):
|
44
112
|
_generic_arg_to_instance_attr_map: ClassVar[dict[int, str]] = {
|
45
113
|
0: "_in_type",
|
46
114
|
1: "_out_type",
|
@@ -66,7 +134,6 @@ class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT, MemT, CtxT]
|
|
66
134
|
|
67
135
|
self.recipient_selector: RecipientSelector[OutT, CtxT] | None
|
68
136
|
if not hasattr(type(self), "recipient_selector"):
|
69
|
-
# Set to None if not defined in the subclass
|
70
137
|
self.recipient_selector = None
|
71
138
|
|
72
139
|
@property
|
@@ -183,19 +250,6 @@ class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT, MemT, CtxT]
|
|
183
250
|
allowed_recipients=cast("list[str]", self.recipients),
|
184
251
|
)
|
185
252
|
|
186
|
-
def _validate_par_recipients(
|
187
|
-
self, out_packets: Sequence[Packet[OutT]], call_id: str
|
188
|
-
) -> None:
|
189
|
-
recipient_sets = [set(p.recipients or []) for p in out_packets]
|
190
|
-
same_recipients = all(rs == recipient_sets[0] for rs in recipient_sets)
|
191
|
-
if not same_recipients:
|
192
|
-
raise PacketRoutingError(
|
193
|
-
proc_name=self.name,
|
194
|
-
call_id=call_id,
|
195
|
-
message="Parallel runs must return the same recipients "
|
196
|
-
f"[proc_name={self.name}; call_id={call_id}]",
|
197
|
-
)
|
198
|
-
|
199
253
|
@final
|
200
254
|
def _select_recipients(
|
201
255
|
self, output: OutT, ctx: RunContext[CtxT] | None = None
|
@@ -212,108 +266,7 @@ class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT, MemT, CtxT]
|
|
212
266
|
|
213
267
|
return func
|
214
268
|
|
215
|
-
|
216
|
-
self,
|
217
|
-
chat_inputs: Any | None = None,
|
218
|
-
*,
|
219
|
-
in_args: InT | None = None,
|
220
|
-
memory: MemT,
|
221
|
-
call_id: str,
|
222
|
-
ctx: RunContext[CtxT] | None = None,
|
223
|
-
) -> OutT:
|
224
|
-
return cast("OutT", in_args)
|
225
|
-
|
226
|
-
async def _process_stream(
|
227
|
-
self,
|
228
|
-
chat_inputs: Any | None = None,
|
229
|
-
*,
|
230
|
-
in_args: InT | None = None,
|
231
|
-
memory: MemT,
|
232
|
-
call_id: str,
|
233
|
-
ctx: RunContext[CtxT] | None = None,
|
234
|
-
) -> AsyncIterator[Event[Any]]:
|
235
|
-
output = cast("OutT", in_args)
|
236
|
-
yield ProcPayloadOutputEvent(data=output, proc_name=self.name, call_id=call_id)
|
237
|
-
|
238
|
-
async def _run_single_once(
|
239
|
-
self,
|
240
|
-
chat_inputs: Any | None = None,
|
241
|
-
*,
|
242
|
-
in_args: InT | None = None,
|
243
|
-
forgetful: bool = False,
|
244
|
-
call_id: str,
|
245
|
-
ctx: RunContext[CtxT] | None = None,
|
246
|
-
) -> Packet[OutT]:
|
247
|
-
_memory = self.memory.model_copy(deep=True) if forgetful else self.memory
|
248
|
-
|
249
|
-
output = await self._process(
|
250
|
-
chat_inputs=chat_inputs,
|
251
|
-
in_args=in_args,
|
252
|
-
memory=_memory,
|
253
|
-
call_id=call_id,
|
254
|
-
ctx=ctx,
|
255
|
-
)
|
256
|
-
val_output = self._validate_output(output, call_id=call_id)
|
257
|
-
|
258
|
-
recipients = self._select_recipients(output=val_output, ctx=ctx)
|
259
|
-
self._validate_recipients(recipients, call_id=call_id)
|
260
|
-
|
261
|
-
return Packet(payloads=[val_output], sender=self.name, recipients=recipients)
|
262
|
-
|
263
|
-
async def _run_single(
|
264
|
-
self,
|
265
|
-
chat_inputs: Any | None = None,
|
266
|
-
*,
|
267
|
-
in_args: InT | None = None,
|
268
|
-
forgetful: bool = False,
|
269
|
-
call_id: str,
|
270
|
-
ctx: RunContext[CtxT] | None = None,
|
271
|
-
) -> Packet[OutT]:
|
272
|
-
n_attempt = 0
|
273
|
-
while n_attempt <= self.max_retries:
|
274
|
-
try:
|
275
|
-
return await self._run_single_once(
|
276
|
-
chat_inputs=chat_inputs,
|
277
|
-
in_args=in_args,
|
278
|
-
forgetful=forgetful,
|
279
|
-
call_id=call_id,
|
280
|
-
ctx=ctx,
|
281
|
-
)
|
282
|
-
except Exception as err:
|
283
|
-
err_message = (
|
284
|
-
f"\nProcessor run failed [proc_name={self.name}; call_id={call_id}]"
|
285
|
-
)
|
286
|
-
n_attempt += 1
|
287
|
-
if n_attempt > self.max_retries:
|
288
|
-
if n_attempt == 1:
|
289
|
-
logger.warning(f"{err_message}:\n{err}")
|
290
|
-
if n_attempt > 1:
|
291
|
-
logger.warning(f"{err_message} after retrying:\n{err}")
|
292
|
-
raise ProcRunError(proc_name=self.name, call_id=call_id) from err
|
293
|
-
|
294
|
-
logger.warning(f"{err_message} (retry attempt {n_attempt}):\n{err}")
|
295
|
-
|
296
|
-
raise ProcRunError(proc_name=self.name, call_id=call_id)
|
297
|
-
|
298
|
-
async def _run_par(
|
299
|
-
self, in_args: list[InT], call_id: str, ctx: RunContext[CtxT] | None = None
|
300
|
-
) -> Packet[OutT]:
|
301
|
-
tasks = [
|
302
|
-
self._run_single(
|
303
|
-
in_args=inp, forgetful=True, call_id=f"{call_id}/{idx}", ctx=ctx
|
304
|
-
)
|
305
|
-
for idx, inp in enumerate(in_args)
|
306
|
-
]
|
307
|
-
out_packets = await asyncio.gather(*tasks)
|
308
|
-
|
309
|
-
self._validate_par_recipients(out_packets, call_id=call_id)
|
310
|
-
|
311
|
-
return Packet(
|
312
|
-
payloads=[out_packet.payloads[0] for out_packet in out_packets],
|
313
|
-
sender=self.name,
|
314
|
-
recipients=out_packets[0].recipients,
|
315
|
-
)
|
316
|
-
|
269
|
+
@abstractmethod
|
317
270
|
async def run(
|
318
271
|
self,
|
319
272
|
chat_inputs: Any | None = None,
|
@@ -324,140 +277,9 @@ class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT, MemT, CtxT]
|
|
324
277
|
call_id: str | None = None,
|
325
278
|
ctx: RunContext[CtxT] | None = None,
|
326
279
|
) -> Packet[OutT]:
|
327
|
-
|
328
|
-
|
329
|
-
val_in_args = self._validate_inputs(
|
330
|
-
call_id=call_id,
|
331
|
-
chat_inputs=chat_inputs,
|
332
|
-
in_packet=in_packet,
|
333
|
-
in_args=in_args,
|
334
|
-
)
|
335
|
-
|
336
|
-
if val_in_args and len(val_in_args) > 1:
|
337
|
-
return await self._run_par(in_args=val_in_args, call_id=call_id, ctx=ctx)
|
338
|
-
return await self._run_single(
|
339
|
-
chat_inputs=chat_inputs,
|
340
|
-
in_args=val_in_args[0] if val_in_args else None,
|
341
|
-
forgetful=forgetful,
|
342
|
-
call_id=call_id,
|
343
|
-
ctx=ctx,
|
344
|
-
)
|
345
|
-
|
346
|
-
async def _run_single_stream_once(
|
347
|
-
self,
|
348
|
-
chat_inputs: Any | None = None,
|
349
|
-
*,
|
350
|
-
in_args: InT | None = None,
|
351
|
-
forgetful: bool = False,
|
352
|
-
call_id: str,
|
353
|
-
ctx: RunContext[CtxT] | None = None,
|
354
|
-
) -> AsyncIterator[Event[Any]]:
|
355
|
-
_memory = self.memory.model_copy(deep=True) if forgetful else self.memory
|
356
|
-
|
357
|
-
output: OutT | None = None
|
358
|
-
async for event in self._process_stream(
|
359
|
-
chat_inputs=chat_inputs,
|
360
|
-
in_args=in_args,
|
361
|
-
memory=_memory,
|
362
|
-
call_id=call_id,
|
363
|
-
ctx=ctx,
|
364
|
-
):
|
365
|
-
if isinstance(event, ProcPayloadOutputEvent):
|
366
|
-
output = event.data
|
367
|
-
yield event
|
368
|
-
|
369
|
-
assert output is not None
|
370
|
-
|
371
|
-
val_output = self._validate_output(output, call_id=call_id)
|
372
|
-
|
373
|
-
recipients = self._select_recipients(output=val_output, ctx=ctx)
|
374
|
-
self._validate_recipients(recipients, call_id=call_id)
|
375
|
-
|
376
|
-
out_packet = Packet[OutT](
|
377
|
-
payloads=[val_output], sender=self.name, recipients=recipients
|
378
|
-
)
|
379
|
-
|
380
|
-
yield ProcPacketOutputEvent(
|
381
|
-
data=out_packet, proc_name=self.name, call_id=call_id
|
382
|
-
)
|
383
|
-
|
384
|
-
async def _run_single_stream(
|
385
|
-
self,
|
386
|
-
chat_inputs: Any | None = None,
|
387
|
-
*,
|
388
|
-
in_args: InT | None = None,
|
389
|
-
forgetful: bool = False,
|
390
|
-
call_id: str,
|
391
|
-
ctx: RunContext[CtxT] | None = None,
|
392
|
-
) -> AsyncIterator[Event[Any]]:
|
393
|
-
n_attempt = 0
|
394
|
-
while n_attempt <= self.max_retries:
|
395
|
-
try:
|
396
|
-
async for event in self._run_single_stream_once(
|
397
|
-
chat_inputs=chat_inputs,
|
398
|
-
in_args=in_args,
|
399
|
-
forgetful=forgetful,
|
400
|
-
call_id=call_id,
|
401
|
-
ctx=ctx,
|
402
|
-
):
|
403
|
-
yield event
|
404
|
-
|
405
|
-
return
|
406
|
-
|
407
|
-
except Exception as err:
|
408
|
-
err_data = ProcStreamingErrorData(error=err, call_id=call_id)
|
409
|
-
yield ProcStreamingErrorEvent(
|
410
|
-
data=err_data, proc_name=self.name, call_id=call_id
|
411
|
-
)
|
412
|
-
|
413
|
-
err_message = (
|
414
|
-
"\nStreaming processor run failed "
|
415
|
-
f"[proc_name={self.name}; call_id={call_id}]"
|
416
|
-
)
|
417
|
-
|
418
|
-
n_attempt += 1
|
419
|
-
if n_attempt > self.max_retries:
|
420
|
-
if n_attempt == 1:
|
421
|
-
logger.warning(f"{err_message}:\n{err}")
|
422
|
-
if n_attempt > 1:
|
423
|
-
logger.warning(f"{err_message} after retrying:\n{err}")
|
424
|
-
raise ProcRunError(proc_name=self.name, call_id=call_id) from err
|
425
|
-
|
426
|
-
logger.warning(f"{err_message} (retry attempt {n_attempt}):\n{err}")
|
427
|
-
|
428
|
-
async def _run_par_stream(
|
429
|
-
self,
|
430
|
-
in_args: list[InT],
|
431
|
-
call_id: str,
|
432
|
-
ctx: RunContext[CtxT] | None = None,
|
433
|
-
) -> AsyncIterator[Event[Any]]:
|
434
|
-
streams = [
|
435
|
-
self._run_single_stream(
|
436
|
-
in_args=inp, forgetful=True, call_id=f"{call_id}/{idx}", ctx=ctx
|
437
|
-
)
|
438
|
-
for idx, inp in enumerate(in_args)
|
439
|
-
]
|
440
|
-
|
441
|
-
out_packets_map: dict[int, Packet[OutT]] = {}
|
442
|
-
async for idx, event in stream_concurrent(streams):
|
443
|
-
if isinstance(event, ProcPacketOutputEvent):
|
444
|
-
out_packets_map[idx] = event.data
|
445
|
-
else:
|
446
|
-
yield event
|
447
|
-
|
448
|
-
out_packet = Packet(
|
449
|
-
payloads=[
|
450
|
-
out_packet.payloads[0]
|
451
|
-
for _, out_packet in sorted(out_packets_map.items())
|
452
|
-
],
|
453
|
-
sender=self.name,
|
454
|
-
recipients=out_packets_map[0].recipients,
|
455
|
-
)
|
456
|
-
|
457
|
-
yield ProcPacketOutputEvent(
|
458
|
-
data=out_packet, proc_name=self.name, call_id=call_id
|
459
|
-
)
|
280
|
+
pass
|
460
281
|
|
282
|
+
@abstractmethod
|
461
283
|
async def run_stream(
|
462
284
|
self,
|
463
285
|
chat_inputs: Any | None = None,
|
@@ -468,27 +290,7 @@ class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT, MemT, CtxT]
|
|
468
290
|
call_id: str | None = None,
|
469
291
|
ctx: RunContext[CtxT] | None = None,
|
470
292
|
) -> AsyncIterator[Event[Any]]:
|
471
|
-
|
472
|
-
|
473
|
-
val_in_args = self._validate_inputs(
|
474
|
-
call_id=call_id,
|
475
|
-
chat_inputs=chat_inputs,
|
476
|
-
in_packet=in_packet,
|
477
|
-
in_args=in_args,
|
478
|
-
)
|
479
|
-
|
480
|
-
if val_in_args and len(val_in_args) > 1:
|
481
|
-
stream = self._run_par_stream(in_args=val_in_args, call_id=call_id, ctx=ctx)
|
482
|
-
else:
|
483
|
-
stream = self._run_single_stream(
|
484
|
-
chat_inputs=chat_inputs,
|
485
|
-
in_args=val_in_args[0] if val_in_args else None,
|
486
|
-
forgetful=forgetful,
|
487
|
-
call_id=call_id,
|
488
|
-
ctx=ctx,
|
489
|
-
)
|
490
|
-
async for event in stream:
|
491
|
-
yield event
|
293
|
+
yield DummyEvent()
|
492
294
|
|
493
295
|
@final
|
494
296
|
def as_tool(
|
@@ -510,7 +312,7 @@ class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT, MemT, CtxT]
|
|
510
312
|
|
511
313
|
async def run(self, inp: InT, ctx: RunContext[CtxT] | None = None) -> OutT:
|
512
314
|
result = await processor_instance.run(
|
513
|
-
in_args=
|
315
|
+
in_args=inp, forgetful=True, ctx=ctx
|
514
316
|
)
|
515
317
|
|
516
318
|
return result.payloads[0]
|