grasp_agents 0.5.4__py3-none-any.whl → 0.5.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- grasp_agents/__init__.py +6 -3
- grasp_agents/llm.py +5 -1
- grasp_agents/llm_agent.py +100 -107
- grasp_agents/llm_agent_memory.py +1 -1
- grasp_agents/llm_policy_executor.py +15 -13
- grasp_agents/packet_pool.py +6 -1
- grasp_agents/printer.py +8 -5
- grasp_agents/processors/base_processor.py +320 -0
- grasp_agents/processors/parallel_processor.py +244 -0
- grasp_agents/processors/processor.py +161 -0
- grasp_agents/prompt_builder.py +22 -60
- grasp_agents/run_context.py +3 -8
- grasp_agents/runner.py +20 -1
- grasp_agents/typing/events.py +4 -0
- grasp_agents/typing/io.py +0 -7
- grasp_agents/workflow/looped_workflow.py +35 -27
- grasp_agents/workflow/sequential_workflow.py +14 -3
- grasp_agents/workflow/workflow_processor.py +32 -11
- {grasp_agents-0.5.4.dist-info → grasp_agents-0.5.6.dist-info}/METADATA +1 -1
- {grasp_agents-0.5.4.dist-info → grasp_agents-0.5.6.dist-info}/RECORD +22 -20
- grasp_agents/processor.py +0 -512
- {grasp_agents-0.5.4.dist-info → grasp_agents-0.5.6.dist-info}/WHEEL +0 -0
- {grasp_agents-0.5.4.dist-info → grasp_agents-0.5.6.dist-info}/licenses/LICENSE.md +0 -0
@@ -4,23 +4,23 @@ from typing import Any, Generic
|
|
4
4
|
|
5
5
|
from ..errors import WorkflowConstructionError
|
6
6
|
from ..packet import Packet
|
7
|
-
from ..
|
7
|
+
from ..processors.base_processor import BaseProcessor, RecipientSelector
|
8
8
|
from ..run_context import CtxT, RunContext
|
9
|
-
from ..typing.events import Event
|
9
|
+
from ..typing.events import DummyEvent, Event
|
10
10
|
from ..typing.io import InT, OutT, ProcName
|
11
11
|
|
12
12
|
|
13
13
|
class WorkflowProcessor(
|
14
|
-
|
14
|
+
BaseProcessor[InT, OutT, Any, CtxT],
|
15
15
|
ABC,
|
16
16
|
Generic[InT, OutT, CtxT],
|
17
17
|
):
|
18
18
|
def __init__(
|
19
19
|
self,
|
20
20
|
name: ProcName,
|
21
|
-
subprocs: Sequence[
|
22
|
-
start_proc:
|
23
|
-
end_proc:
|
21
|
+
subprocs: Sequence[BaseProcessor[Any, Any, Any, CtxT]],
|
22
|
+
start_proc: BaseProcessor[InT, Any, Any, CtxT],
|
23
|
+
end_proc: BaseProcessor[Any, OutT, Any, CtxT],
|
24
24
|
recipients: list[ProcName] | None = None,
|
25
25
|
max_retries: int = 0,
|
26
26
|
) -> None:
|
@@ -44,20 +44,41 @@ class WorkflowProcessor(
|
|
44
44
|
self._start_proc = start_proc
|
45
45
|
self._end_proc = end_proc
|
46
46
|
|
47
|
+
self.recipients = recipients
|
48
|
+
if hasattr(type(self), "recipient_selector"):
|
49
|
+
self._end_proc.recipient_selector = self.recipient_selector
|
50
|
+
|
51
|
+
def add_recipient_selector(
|
52
|
+
self, func: RecipientSelector[OutT, CtxT]
|
53
|
+
) -> RecipientSelector[OutT, CtxT]:
|
54
|
+
self._end_proc.recipient_selector = func
|
55
|
+
self.recipient_selector = func
|
56
|
+
|
57
|
+
return func
|
58
|
+
|
47
59
|
@property
|
48
|
-
def
|
60
|
+
def recipients(self) -> list[ProcName] | None:
|
61
|
+
return self._end_proc.recipients
|
62
|
+
|
63
|
+
@recipients.setter
|
64
|
+
def recipients(self, value: list[ProcName] | None) -> None:
|
65
|
+
if hasattr(self, "_end_proc"):
|
66
|
+
self._end_proc.recipients = value
|
67
|
+
|
68
|
+
@property
|
69
|
+
def subprocs(self) -> Sequence[BaseProcessor[Any, Any, Any, CtxT]]:
|
49
70
|
return self._subprocs
|
50
71
|
|
51
72
|
@property
|
52
|
-
def start_proc(self) ->
|
73
|
+
def start_proc(self) -> BaseProcessor[InT, Any, Any, CtxT]:
|
53
74
|
return self._start_proc
|
54
75
|
|
55
76
|
@property
|
56
|
-
def end_proc(self) ->
|
77
|
+
def end_proc(self) -> BaseProcessor[Any, OutT, Any, CtxT]:
|
57
78
|
return self._end_proc
|
58
79
|
|
59
80
|
def _generate_subproc_call_id(
|
60
|
-
self, call_id: str | None, subproc:
|
81
|
+
self, call_id: str | None, subproc: BaseProcessor[Any, Any, Any, CtxT]
|
61
82
|
) -> str | None:
|
62
83
|
return f"{self._generate_call_id(call_id)}/{subproc.name}"
|
63
84
|
|
@@ -85,4 +106,4 @@ class WorkflowProcessor(
|
|
85
106
|
forgetful: bool = False,
|
86
107
|
call_id: str | None = None,
|
87
108
|
) -> AsyncIterator[Event[Any]]:
|
88
|
-
|
109
|
+
yield DummyEvent()
|
@@ -1,22 +1,21 @@
|
|
1
|
-
grasp_agents/__init__.py,sha256=
|
1
|
+
grasp_agents/__init__.py,sha256=Z3a_j2Etiap9H6lvE8-PQP_OIGMUcHNPeJAJO12B8kY,1031
|
2
2
|
grasp_agents/cloud_llm.py,sha256=C6xrKYhiQb4tNXK_rFo2pNlVTXPS_gYd5uevAnpLFeE,13119
|
3
3
|
grasp_agents/costs_dict.yaml,sha256=2MFNWtkv5W5WSCcv1Cj13B1iQLVv5Ot9pS_KW2Gu2DA,2510
|
4
4
|
grasp_agents/errors.py,sha256=K-22TCM1Klhsej47Rg5eTqnGiGPaXgKOpdOZZ7cPipw,4633
|
5
5
|
grasp_agents/generics_utils.py,sha256=5Pw3I9dlnKC2VGqYKC4ZZUO3Z_vTNT-NPFovNfPkl6I,6542
|
6
6
|
grasp_agents/grasp_logging.py,sha256=H1GYhXdQvVkmauFDZ-KDwvVmPQHZUUm9sRqX_ObK2xI,1111
|
7
7
|
grasp_agents/http_client.py,sha256=Es8NXGDkp4Nem7g24-jW0KFGA9Hp_o2Cv3cOvjup-iU,859
|
8
|
-
grasp_agents/llm.py,sha256=
|
9
|
-
grasp_agents/llm_agent.py,sha256=
|
10
|
-
grasp_agents/llm_agent_memory.py,sha256=
|
11
|
-
grasp_agents/llm_policy_executor.py,sha256=
|
8
|
+
grasp_agents/llm.py,sha256=ZkAeGEkpMsOY6T_zj2pL3ZWkP0mwDN_e9ArN99qZfmY,6574
|
9
|
+
grasp_agents/llm_agent.py,sha256=hX3T2Y5qiTt5CrsahNo5t08HFCyBWEiurzYnFykJN9Y,13513
|
10
|
+
grasp_agents/llm_agent_memory.py,sha256=gQwH3g4Ib3ciW2jrBiW13ttwax_pcPobH5RhXRmbc0E,1842
|
11
|
+
grasp_agents/llm_policy_executor.py,sha256=glusWe4wTIh_rl3bXPbAMueJ5eDfJxPYBTzoecrHPYg,16849
|
12
12
|
grasp_agents/memory.py,sha256=keHuNEZNSxHT9FKpMohHOCNi7UAz_oRIc91IQEuzaWE,1162
|
13
13
|
grasp_agents/packet.py,sha256=EmE-W4ZSMVZoqClECGFe7OGqrT4FSJ8IVGICrdjtdEY,1462
|
14
|
-
grasp_agents/packet_pool.py,sha256=
|
15
|
-
grasp_agents/printer.py,sha256=
|
16
|
-
grasp_agents/
|
17
|
-
grasp_agents/
|
18
|
-
grasp_agents/
|
19
|
-
grasp_agents/runner.py,sha256=nZXK5OuM7QbLyzNA-rfXD0YDtdftTK5v2cCpeujhe2k,4479
|
14
|
+
grasp_agents/packet_pool.py,sha256=i0g4O_fnpsYTU3LNjKgM4TiiiJkHA3YeZP9Y5GatM_I,5082
|
15
|
+
grasp_agents/printer.py,sha256=wtCH75DgWAwjJntr57kK2e7dagcfliUAfefYHpiwWi0,11203
|
16
|
+
grasp_agents/prompt_builder.py,sha256=UuQNnvjrhzd3_NVnvHlCs-NRNRMo4jsMePgZeIxmzSY,5894
|
17
|
+
grasp_agents/run_context.py,sha256=ikakNK1khm0UBEIPETB508BL0IlOKbOUPuq0FZ-iQHQ,942
|
18
|
+
grasp_agents/runner.py,sha256=sRuKkX8Iopxv_MdGZf22b_AsuH4FIPY28rl15RqwGEY,5005
|
20
19
|
grasp_agents/usage_tracker.py,sha256=ZQfVUUpG0C89hyPWT_JgXnjQOxoYmumcQ9t-aCfcMo8,3561
|
21
20
|
grasp_agents/utils.py,sha256=qKmGBwrQHw1-BgqRLuGTPKGs3J_zbrpk3nxnP1iZBiQ,6152
|
22
21
|
grasp_agents/litellm/__init__.py,sha256=wD8RZBYokFDfbS9Cs7nO_zKb3w7RIVwEGj7g2D5CJH0,4510
|
@@ -33,6 +32,9 @@ grasp_agents/openai/converters.py,sha256=CXHF2GehEHLEzjL45HywZ_1qaB3N29-lbac5oBD
|
|
33
32
|
grasp_agents/openai/message_converters.py,sha256=fhSN81uK51EGbLyM2-f0MvPX_UBrMy7SF3JQPo-dkXg,4686
|
34
33
|
grasp_agents/openai/openai_llm.py,sha256=uJbbCytqpv8OCncKdzpiUdkVh3mJWgo95Y9Xetk_Ptg,10556
|
35
34
|
grasp_agents/openai/tool_converters.py,sha256=IotZvpe3xMQcBfcjUTfAsn4LtZljj3zkU9bfpcoiqPw,1177
|
35
|
+
grasp_agents/processors/base_processor.py,sha256=j2_QY6HUjckdxfsf7yAF0xRDp_V-DNDb7hIRMRKUyWw,10685
|
36
|
+
grasp_agents/processors/parallel_processor.py,sha256=EoB1sqtnuv362S1MTUyn3-TUSO4M5Vpe86Q9HlBmbok,7816
|
37
|
+
grasp_agents/processors/processor.py,sha256=v7Bf6CGVsjb43XuOtTMuev9UedMy_lBTGifzzCZwh4Q,5157
|
36
38
|
grasp_agents/rate_limiting/__init__.py,sha256=KRgtF_E7R3YfA2cpYcFcZ7wycV0pWVJ0xRQC7YhiIEQ,158
|
37
39
|
grasp_agents/rate_limiting/rate_limiter_chunked.py,sha256=BPgkUXvhmZhTpZs2T6uujNFuxH_kYHiISuf6_-eNhUc,5544
|
38
40
|
grasp_agents/rate_limiting/types.py,sha256=PbnNhEAcYedQdIpPJWud8HUVcxa_xZS2RDZu4c5jr40,1003
|
@@ -42,15 +44,15 @@ grasp_agents/typing/completion.py,sha256=PHJ01m7WI2KYQL8w7W2ti6hMsKEZnzYGaxbNcBC
|
|
42
44
|
grasp_agents/typing/completion_chunk.py,sha256=t6PvkDWQxRN5xA4efBdeou46RifMFodBmZc45Sx7qxQ,7610
|
43
45
|
grasp_agents/typing/content.py,sha256=XFmLpNWkGhkw5JujO6UsYwhzTHkU67PfhzaXH2waLcQ,3659
|
44
46
|
grasp_agents/typing/converters.py,sha256=kHlocHQS8QnduZOzNPbj3aRD8JpvJd53oudYqWdOxKE,2978
|
45
|
-
grasp_agents/typing/events.py,sha256=
|
46
|
-
grasp_agents/typing/io.py,sha256=
|
47
|
+
grasp_agents/typing/events.py,sha256=UB0G2mQLmyLHYdEp1YnFi8XF2jpWeUZ4TLFnY64lK0I,5450
|
48
|
+
grasp_agents/typing/io.py,sha256=MGEoUjAwKH1AHYglFkKNpHiielw-NFf13Epg3B4Q7Iw,139
|
47
49
|
grasp_agents/typing/message.py,sha256=o7bN84AgrC5Fm3Wx20gqL9ArAMcEtYvnHnXbb04ngCs,3224
|
48
50
|
grasp_agents/typing/tool.py,sha256=4N-k_GvHVPAFyVyEq7z_LYKkA24iQlGoVYiWCzGTgT4,1786
|
49
51
|
grasp_agents/workflow/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
50
|
-
grasp_agents/workflow/looped_workflow.py,sha256=
|
51
|
-
grasp_agents/workflow/sequential_workflow.py,sha256=
|
52
|
-
grasp_agents/workflow/workflow_processor.py,sha256=
|
53
|
-
grasp_agents-0.5.
|
54
|
-
grasp_agents-0.5.
|
55
|
-
grasp_agents-0.5.
|
56
|
-
grasp_agents-0.5.
|
52
|
+
grasp_agents/workflow/looped_workflow.py,sha256=WHp9O3Za2sBVfY_BLOdvPvtY20XsjZQaWSO2-oAFvOY,6806
|
53
|
+
grasp_agents/workflow/sequential_workflow.py,sha256=e3BIWzy_2novmEWNwIteyMbrzvl1-evHrTBE3r3SpU8,3648
|
54
|
+
grasp_agents/workflow/workflow_processor.py,sha256=yrxqAGfznmdkbP5zScKKJguxATfU4ObmA6BDR7YCBNU,3549
|
55
|
+
grasp_agents-0.5.6.dist-info/METADATA,sha256=c0CNKc5v1lF7iPtHiLfanUt0UKQUztGVJsMcq-vco_0,6865
|
56
|
+
grasp_agents-0.5.6.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
57
|
+
grasp_agents-0.5.6.dist-info/licenses/LICENSE.md,sha256=-nNNdWqGB8gJ2O-peFQ2Irshv5tW5pHKyTcYkwvH7CE,1201
|
58
|
+
grasp_agents-0.5.6.dist-info/RECORD,,
|
grasp_agents/processor.py
DELETED
@@ -1,512 +0,0 @@
|
|
1
|
-
import asyncio
|
2
|
-
import logging
|
3
|
-
from abc import ABC
|
4
|
-
from collections.abc import AsyncIterator, Sequence
|
5
|
-
from typing import Any, ClassVar, Generic, Protocol, TypeVar, cast, final
|
6
|
-
from uuid import uuid4
|
7
|
-
|
8
|
-
from pydantic import BaseModel, TypeAdapter
|
9
|
-
from pydantic import ValidationError as PydanticValidationError
|
10
|
-
|
11
|
-
from .errors import (
|
12
|
-
PacketRoutingError,
|
13
|
-
ProcInputValidationError,
|
14
|
-
ProcOutputValidationError,
|
15
|
-
ProcRunError,
|
16
|
-
)
|
17
|
-
from .generics_utils import AutoInstanceAttributesMixin
|
18
|
-
from .memory import DummyMemory, MemT
|
19
|
-
from .packet import Packet
|
20
|
-
from .run_context import CtxT, RunContext
|
21
|
-
from .typing.events import (
|
22
|
-
Event,
|
23
|
-
ProcPacketOutputEvent,
|
24
|
-
ProcPayloadOutputEvent,
|
25
|
-
ProcStreamingErrorData,
|
26
|
-
ProcStreamingErrorEvent,
|
27
|
-
)
|
28
|
-
from .typing.io import InT, OutT, ProcName
|
29
|
-
from .typing.tool import BaseTool
|
30
|
-
from .utils import stream_concurrent
|
31
|
-
|
32
|
-
logger = logging.getLogger(__name__)
|
33
|
-
|
34
|
-
_OutT_contra = TypeVar("_OutT_contra", contravariant=True)
|
35
|
-
|
36
|
-
|
37
|
-
class SelectRecipientsHandler(Protocol[_OutT_contra, CtxT]):
|
38
|
-
def __call__(
|
39
|
-
self, output: _OutT_contra, ctx: RunContext[CtxT] | None
|
40
|
-
) -> list[ProcName] | None: ...
|
41
|
-
|
42
|
-
|
43
|
-
class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT, MemT, CtxT]):
|
44
|
-
_generic_arg_to_instance_attr_map: ClassVar[dict[int, str]] = {
|
45
|
-
0: "_in_type",
|
46
|
-
1: "_out_type",
|
47
|
-
}
|
48
|
-
|
49
|
-
def __init__(
|
50
|
-
self,
|
51
|
-
name: ProcName,
|
52
|
-
max_retries: int = 0,
|
53
|
-
recipients: list[ProcName] | None = None,
|
54
|
-
**kwargs: Any,
|
55
|
-
) -> None:
|
56
|
-
self._in_type: type[InT]
|
57
|
-
self._out_type: type[OutT]
|
58
|
-
|
59
|
-
super().__init__()
|
60
|
-
|
61
|
-
self._name: ProcName = name
|
62
|
-
self._memory: MemT = cast("MemT", DummyMemory())
|
63
|
-
self._max_retries: int = max_retries
|
64
|
-
self.recipients = recipients
|
65
|
-
self.select_recipients_impl: SelectRecipientsHandler[OutT, CtxT] | None = None
|
66
|
-
|
67
|
-
@property
|
68
|
-
def in_type(self) -> type[InT]:
|
69
|
-
return self._in_type
|
70
|
-
|
71
|
-
@property
|
72
|
-
def out_type(self) -> type[OutT]:
|
73
|
-
return self._out_type
|
74
|
-
|
75
|
-
@property
|
76
|
-
def name(self) -> ProcName:
|
77
|
-
return self._name
|
78
|
-
|
79
|
-
@property
|
80
|
-
def memory(self) -> MemT:
|
81
|
-
return self._memory
|
82
|
-
|
83
|
-
@property
|
84
|
-
def max_retries(self) -> int:
|
85
|
-
return self._max_retries
|
86
|
-
|
87
|
-
def _generate_call_id(self, call_id: str | None) -> str:
|
88
|
-
if call_id is None:
|
89
|
-
return str(uuid4())[:6] + "_" + self.name
|
90
|
-
return call_id
|
91
|
-
|
92
|
-
def _validate_inputs(
|
93
|
-
self,
|
94
|
-
call_id: str,
|
95
|
-
chat_inputs: Any | None = None,
|
96
|
-
in_packet: Packet[InT] | None = None,
|
97
|
-
in_args: InT | Sequence[InT] | None = None,
|
98
|
-
) -> Sequence[InT] | None:
|
99
|
-
mult_inputs_err_message = (
|
100
|
-
"Only one of chat_inputs, in_args, or in_message must be provided."
|
101
|
-
)
|
102
|
-
err_kwargs = {"proc_name": self.name, "call_id": call_id}
|
103
|
-
|
104
|
-
if chat_inputs is not None and in_args is not None:
|
105
|
-
raise ProcInputValidationError(
|
106
|
-
message=mult_inputs_err_message, **err_kwargs
|
107
|
-
)
|
108
|
-
if chat_inputs is not None and in_packet is not None:
|
109
|
-
raise ProcInputValidationError(
|
110
|
-
message=mult_inputs_err_message, **err_kwargs
|
111
|
-
)
|
112
|
-
if in_args is not None and in_packet is not None:
|
113
|
-
raise ProcInputValidationError(
|
114
|
-
message=mult_inputs_err_message, **err_kwargs
|
115
|
-
)
|
116
|
-
|
117
|
-
if in_packet is not None and not in_packet.payloads:
|
118
|
-
raise ProcInputValidationError(
|
119
|
-
message="in_packet must contain at least one payload.", **err_kwargs
|
120
|
-
)
|
121
|
-
if in_args is not None and not in_args:
|
122
|
-
raise ProcInputValidationError(
|
123
|
-
message="in_args must contain at least one argument.", **err_kwargs
|
124
|
-
)
|
125
|
-
|
126
|
-
if chat_inputs is not None:
|
127
|
-
return None
|
128
|
-
|
129
|
-
resolved_args: Sequence[InT]
|
130
|
-
|
131
|
-
if isinstance(in_args, Sequence):
|
132
|
-
_in_args = cast("Sequence[Any]", in_args)
|
133
|
-
if all(isinstance(x, self.in_type) for x in _in_args):
|
134
|
-
resolved_args = cast("Sequence[InT]", _in_args)
|
135
|
-
elif isinstance(_in_args, self.in_type):
|
136
|
-
resolved_args = cast("Sequence[InT]", [_in_args])
|
137
|
-
else:
|
138
|
-
raise ProcInputValidationError(
|
139
|
-
message=f"in_args are neither of type {self.in_type} "
|
140
|
-
f"nor a sequence of {self.in_type}.",
|
141
|
-
**err_kwargs,
|
142
|
-
)
|
143
|
-
|
144
|
-
elif in_args is not None:
|
145
|
-
resolved_args = cast("Sequence[InT]", [in_args])
|
146
|
-
|
147
|
-
else:
|
148
|
-
assert in_packet is not None
|
149
|
-
resolved_args = in_packet.payloads
|
150
|
-
|
151
|
-
try:
|
152
|
-
for args in resolved_args:
|
153
|
-
TypeAdapter(self._in_type).validate_python(args)
|
154
|
-
except PydanticValidationError as err:
|
155
|
-
raise ProcInputValidationError(message=str(err), **err_kwargs) from err
|
156
|
-
|
157
|
-
return resolved_args
|
158
|
-
|
159
|
-
def _validate_output(self, out_payload: OutT, call_id: str) -> OutT:
|
160
|
-
if out_payload is None:
|
161
|
-
return out_payload
|
162
|
-
try:
|
163
|
-
return TypeAdapter(self._out_type).validate_python(out_payload)
|
164
|
-
except PydanticValidationError as err:
|
165
|
-
raise ProcOutputValidationError(
|
166
|
-
schema=self._out_type, proc_name=self.name, call_id=call_id
|
167
|
-
) from err
|
168
|
-
|
169
|
-
def _validate_recipients(
|
170
|
-
self, recipients: Sequence[ProcName] | None, call_id: str
|
171
|
-
) -> None:
|
172
|
-
for r in recipients or []:
|
173
|
-
if r not in (self.recipients or []):
|
174
|
-
raise PacketRoutingError(
|
175
|
-
proc_name=self.name,
|
176
|
-
call_id=call_id,
|
177
|
-
selected_recipient=r,
|
178
|
-
allowed_recipients=cast("list[str]", self.recipients),
|
179
|
-
)
|
180
|
-
|
181
|
-
def _validate_par_recipients(
|
182
|
-
self, out_packets: Sequence[Packet[OutT]], call_id: str
|
183
|
-
) -> None:
|
184
|
-
recipient_sets = [set(p.recipients or []) for p in out_packets]
|
185
|
-
same_recipients = all(rs == recipient_sets[0] for rs in recipient_sets)
|
186
|
-
if not same_recipients:
|
187
|
-
raise PacketRoutingError(
|
188
|
-
proc_name=self.name,
|
189
|
-
call_id=call_id,
|
190
|
-
message="Parallel runs must return the same recipients "
|
191
|
-
f"[proc_name={self.name}; call_id={call_id}]",
|
192
|
-
)
|
193
|
-
|
194
|
-
async def _process(
|
195
|
-
self,
|
196
|
-
chat_inputs: Any | None = None,
|
197
|
-
*,
|
198
|
-
in_args: InT | None = None,
|
199
|
-
memory: MemT,
|
200
|
-
call_id: str,
|
201
|
-
ctx: RunContext[CtxT] | None = None,
|
202
|
-
) -> OutT:
|
203
|
-
return cast("OutT", in_args)
|
204
|
-
|
205
|
-
async def _process_stream(
|
206
|
-
self,
|
207
|
-
chat_inputs: Any | None = None,
|
208
|
-
*,
|
209
|
-
in_args: InT | None = None,
|
210
|
-
memory: MemT,
|
211
|
-
call_id: str,
|
212
|
-
ctx: RunContext[CtxT] | None = None,
|
213
|
-
) -> AsyncIterator[Event[Any]]:
|
214
|
-
output = cast("OutT", in_args)
|
215
|
-
yield ProcPayloadOutputEvent(data=output, proc_name=self.name, call_id=call_id)
|
216
|
-
|
217
|
-
async def _run_single_once(
|
218
|
-
self,
|
219
|
-
chat_inputs: Any | None = None,
|
220
|
-
*,
|
221
|
-
in_args: InT | None = None,
|
222
|
-
forgetful: bool = False,
|
223
|
-
call_id: str,
|
224
|
-
ctx: RunContext[CtxT] | None = None,
|
225
|
-
) -> Packet[OutT]:
|
226
|
-
_memory = self.memory.model_copy(deep=True) if forgetful else self.memory
|
227
|
-
|
228
|
-
output = await self._process(
|
229
|
-
chat_inputs=chat_inputs,
|
230
|
-
in_args=in_args,
|
231
|
-
memory=_memory,
|
232
|
-
call_id=call_id,
|
233
|
-
ctx=ctx,
|
234
|
-
)
|
235
|
-
val_output = self._validate_output(output, call_id=call_id)
|
236
|
-
|
237
|
-
recipients = self._select_recipients(output=val_output, ctx=ctx)
|
238
|
-
self._validate_recipients(recipients, call_id=call_id)
|
239
|
-
|
240
|
-
return Packet(payloads=[val_output], sender=self.name, recipients=recipients)
|
241
|
-
|
242
|
-
async def _run_single(
|
243
|
-
self,
|
244
|
-
chat_inputs: Any | None = None,
|
245
|
-
*,
|
246
|
-
in_args: InT | None = None,
|
247
|
-
forgetful: bool = False,
|
248
|
-
call_id: str,
|
249
|
-
ctx: RunContext[CtxT] | None = None,
|
250
|
-
) -> Packet[OutT]:
|
251
|
-
n_attempt = 0
|
252
|
-
while n_attempt <= self.max_retries:
|
253
|
-
try:
|
254
|
-
return await self._run_single_once(
|
255
|
-
chat_inputs=chat_inputs,
|
256
|
-
in_args=in_args,
|
257
|
-
forgetful=forgetful,
|
258
|
-
call_id=call_id,
|
259
|
-
ctx=ctx,
|
260
|
-
)
|
261
|
-
except Exception as err:
|
262
|
-
err_message = (
|
263
|
-
f"\nProcessor run failed [proc_name={self.name}; call_id={call_id}]"
|
264
|
-
)
|
265
|
-
n_attempt += 1
|
266
|
-
if n_attempt > self.max_retries:
|
267
|
-
if n_attempt == 1:
|
268
|
-
logger.warning(f"{err_message}:\n{err}")
|
269
|
-
if n_attempt > 1:
|
270
|
-
logger.warning(f"{err_message} after retrying:\n{err}")
|
271
|
-
raise ProcRunError(proc_name=self.name, call_id=call_id) from err
|
272
|
-
|
273
|
-
logger.warning(f"{err_message} (retry attempt {n_attempt}):\n{err}")
|
274
|
-
|
275
|
-
raise ProcRunError(proc_name=self.name, call_id=call_id)
|
276
|
-
|
277
|
-
async def _run_par(
|
278
|
-
self, in_args: Sequence[InT], call_id: str, ctx: RunContext[CtxT] | None = None
|
279
|
-
) -> Packet[OutT]:
|
280
|
-
tasks = [
|
281
|
-
self._run_single(
|
282
|
-
in_args=inp, forgetful=True, call_id=f"{call_id}/{idx}", ctx=ctx
|
283
|
-
)
|
284
|
-
for idx, inp in enumerate(in_args)
|
285
|
-
]
|
286
|
-
out_packets = await asyncio.gather(*tasks)
|
287
|
-
|
288
|
-
self._validate_par_recipients(out_packets, call_id=call_id)
|
289
|
-
|
290
|
-
return Packet(
|
291
|
-
payloads=[out_packet.payloads[0] for out_packet in out_packets],
|
292
|
-
sender=self.name,
|
293
|
-
recipients=out_packets[0].recipients,
|
294
|
-
)
|
295
|
-
|
296
|
-
async def run(
|
297
|
-
self,
|
298
|
-
chat_inputs: Any | None = None,
|
299
|
-
*,
|
300
|
-
in_packet: Packet[InT] | None = None,
|
301
|
-
in_args: InT | Sequence[InT] | None = None,
|
302
|
-
forgetful: bool = False,
|
303
|
-
call_id: str | None = None,
|
304
|
-
ctx: RunContext[CtxT] | None = None,
|
305
|
-
) -> Packet[OutT]:
|
306
|
-
call_id = self._generate_call_id(call_id)
|
307
|
-
|
308
|
-
val_in_args = self._validate_inputs(
|
309
|
-
call_id=call_id,
|
310
|
-
chat_inputs=chat_inputs,
|
311
|
-
in_packet=in_packet,
|
312
|
-
in_args=in_args,
|
313
|
-
)
|
314
|
-
|
315
|
-
if val_in_args and len(val_in_args) > 1:
|
316
|
-
return await self._run_par(in_args=val_in_args, call_id=call_id, ctx=ctx)
|
317
|
-
return await self._run_single(
|
318
|
-
chat_inputs=chat_inputs,
|
319
|
-
in_args=val_in_args[0] if val_in_args else None,
|
320
|
-
forgetful=forgetful,
|
321
|
-
call_id=call_id,
|
322
|
-
ctx=ctx,
|
323
|
-
)
|
324
|
-
|
325
|
-
async def _run_single_stream_once(
|
326
|
-
self,
|
327
|
-
chat_inputs: Any | None = None,
|
328
|
-
*,
|
329
|
-
in_args: InT | None = None,
|
330
|
-
forgetful: bool = False,
|
331
|
-
call_id: str,
|
332
|
-
ctx: RunContext[CtxT] | None = None,
|
333
|
-
) -> AsyncIterator[Event[Any]]:
|
334
|
-
_memory = self.memory.model_copy(deep=True) if forgetful else self.memory
|
335
|
-
|
336
|
-
output: OutT | None = None
|
337
|
-
async for event in self._process_stream(
|
338
|
-
chat_inputs=chat_inputs,
|
339
|
-
in_args=in_args,
|
340
|
-
memory=_memory,
|
341
|
-
call_id=call_id,
|
342
|
-
ctx=ctx,
|
343
|
-
):
|
344
|
-
if isinstance(event, ProcPayloadOutputEvent):
|
345
|
-
output = event.data
|
346
|
-
yield event
|
347
|
-
|
348
|
-
assert output is not None
|
349
|
-
|
350
|
-
val_output = self._validate_output(output, call_id=call_id)
|
351
|
-
|
352
|
-
recipients = self._select_recipients(output=val_output, ctx=ctx)
|
353
|
-
self._validate_recipients(recipients, call_id=call_id)
|
354
|
-
|
355
|
-
out_packet = Packet[OutT](
|
356
|
-
payloads=[val_output], sender=self.name, recipients=recipients
|
357
|
-
)
|
358
|
-
|
359
|
-
yield ProcPacketOutputEvent(
|
360
|
-
data=out_packet, proc_name=self.name, call_id=call_id
|
361
|
-
)
|
362
|
-
|
363
|
-
async def _run_single_stream(
|
364
|
-
self,
|
365
|
-
chat_inputs: Any | None = None,
|
366
|
-
*,
|
367
|
-
in_args: InT | None = None,
|
368
|
-
forgetful: bool = False,
|
369
|
-
call_id: str,
|
370
|
-
ctx: RunContext[CtxT] | None = None,
|
371
|
-
) -> AsyncIterator[Event[Any]]:
|
372
|
-
n_attempt = 0
|
373
|
-
while n_attempt <= self.max_retries:
|
374
|
-
try:
|
375
|
-
async for event in self._run_single_stream_once(
|
376
|
-
chat_inputs=chat_inputs,
|
377
|
-
in_args=in_args,
|
378
|
-
forgetful=forgetful,
|
379
|
-
call_id=call_id,
|
380
|
-
ctx=ctx,
|
381
|
-
):
|
382
|
-
yield event
|
383
|
-
|
384
|
-
return
|
385
|
-
|
386
|
-
except Exception as err:
|
387
|
-
err_data = ProcStreamingErrorData(error=err, call_id=call_id)
|
388
|
-
yield ProcStreamingErrorEvent(
|
389
|
-
data=err_data, proc_name=self.name, call_id=call_id
|
390
|
-
)
|
391
|
-
|
392
|
-
err_message = (
|
393
|
-
"\nStreaming processor run failed "
|
394
|
-
f"[proc_name={self.name}; call_id={call_id}]"
|
395
|
-
)
|
396
|
-
|
397
|
-
n_attempt += 1
|
398
|
-
if n_attempt > self.max_retries:
|
399
|
-
if n_attempt == 1:
|
400
|
-
logger.warning(f"{err_message}:\n{err}")
|
401
|
-
if n_attempt > 1:
|
402
|
-
logger.warning(f"{err_message} after retrying:\n{err}")
|
403
|
-
raise ProcRunError(proc_name=self.name, call_id=call_id) from err
|
404
|
-
|
405
|
-
logger.warning(f"{err_message} (retry attempt {n_attempt}):\n{err}")
|
406
|
-
|
407
|
-
async def _run_par_stream(
|
408
|
-
self,
|
409
|
-
in_args: Sequence[InT],
|
410
|
-
call_id: str,
|
411
|
-
ctx: RunContext[CtxT] | None = None,
|
412
|
-
) -> AsyncIterator[Event[Any]]:
|
413
|
-
streams = [
|
414
|
-
self._run_single_stream(
|
415
|
-
in_args=inp, forgetful=True, call_id=f"{call_id}/{idx}", ctx=ctx
|
416
|
-
)
|
417
|
-
for idx, inp in enumerate(in_args)
|
418
|
-
]
|
419
|
-
|
420
|
-
out_packets_map: dict[int, Packet[OutT]] = {}
|
421
|
-
async for idx, event in stream_concurrent(streams):
|
422
|
-
if isinstance(event, ProcPacketOutputEvent):
|
423
|
-
out_packets_map[idx] = event.data
|
424
|
-
else:
|
425
|
-
yield event
|
426
|
-
|
427
|
-
out_packet = Packet(
|
428
|
-
payloads=[
|
429
|
-
out_packet.payloads[0]
|
430
|
-
for _, out_packet in sorted(out_packets_map.items())
|
431
|
-
],
|
432
|
-
sender=self.name,
|
433
|
-
recipients=out_packets_map[0].recipients,
|
434
|
-
)
|
435
|
-
|
436
|
-
yield ProcPacketOutputEvent(
|
437
|
-
data=out_packet, proc_name=self.name, call_id=call_id
|
438
|
-
)
|
439
|
-
|
440
|
-
async def run_stream(
|
441
|
-
self,
|
442
|
-
chat_inputs: Any | None = None,
|
443
|
-
*,
|
444
|
-
in_packet: Packet[InT] | None = None,
|
445
|
-
in_args: InT | Sequence[InT] | None = None,
|
446
|
-
forgetful: bool = False,
|
447
|
-
call_id: str | None = None,
|
448
|
-
ctx: RunContext[CtxT] | None = None,
|
449
|
-
) -> AsyncIterator[Event[Any]]:
|
450
|
-
call_id = self._generate_call_id(call_id)
|
451
|
-
|
452
|
-
val_in_args = self._validate_inputs(
|
453
|
-
call_id=call_id,
|
454
|
-
chat_inputs=chat_inputs,
|
455
|
-
in_packet=in_packet,
|
456
|
-
in_args=in_args,
|
457
|
-
)
|
458
|
-
|
459
|
-
if val_in_args and len(val_in_args) > 1:
|
460
|
-
stream = self._run_par_stream(in_args=val_in_args, call_id=call_id, ctx=ctx)
|
461
|
-
else:
|
462
|
-
stream = self._run_single_stream(
|
463
|
-
chat_inputs=chat_inputs,
|
464
|
-
in_args=val_in_args[0] if val_in_args else None,
|
465
|
-
forgetful=forgetful,
|
466
|
-
call_id=call_id,
|
467
|
-
ctx=ctx,
|
468
|
-
)
|
469
|
-
async for event in stream:
|
470
|
-
yield event
|
471
|
-
|
472
|
-
def _select_recipients(
|
473
|
-
self, output: OutT, ctx: RunContext[CtxT] | None = None
|
474
|
-
) -> list[ProcName] | None:
|
475
|
-
if self.select_recipients_impl:
|
476
|
-
return self.select_recipients_impl(output=output, ctx=ctx)
|
477
|
-
|
478
|
-
return self.recipients
|
479
|
-
|
480
|
-
def select_recipients(
|
481
|
-
self, func: SelectRecipientsHandler[OutT, CtxT]
|
482
|
-
) -> SelectRecipientsHandler[OutT, CtxT]:
|
483
|
-
self.select_recipients_impl = func
|
484
|
-
|
485
|
-
return func
|
486
|
-
|
487
|
-
@final
|
488
|
-
def as_tool(
|
489
|
-
self, tool_name: str, tool_description: str
|
490
|
-
) -> BaseTool[InT, OutT, Any]: # type: ignore[override]
|
491
|
-
# TODO: stream tools
|
492
|
-
processor_instance = self
|
493
|
-
in_type = processor_instance.in_type
|
494
|
-
out_type = processor_instance.out_type
|
495
|
-
if not issubclass(in_type, BaseModel):
|
496
|
-
raise TypeError(
|
497
|
-
"Cannot create a tool from an agent with "
|
498
|
-
f"non-BaseModel input type: {in_type}"
|
499
|
-
)
|
500
|
-
|
501
|
-
class ProcessorTool(BaseTool[in_type, out_type, Any]):
|
502
|
-
name: str = tool_name
|
503
|
-
description: str = tool_description
|
504
|
-
|
505
|
-
async def run(self, inp: InT, ctx: RunContext[CtxT] | None = None) -> OutT:
|
506
|
-
result = await processor_instance.run(
|
507
|
-
in_args=in_type.model_validate(inp), forgetful=True, ctx=ctx
|
508
|
-
)
|
509
|
-
|
510
|
-
return result.payloads[0]
|
511
|
-
|
512
|
-
return ProcessorTool()
|
File without changes
|
File without changes
|