grasp_agents 0.5.4__py3-none-any.whl → 0.5.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -4,23 +4,23 @@ from typing import Any, Generic
4
4
 
5
5
  from ..errors import WorkflowConstructionError
6
6
  from ..packet import Packet
7
- from ..processor import Processor
7
+ from ..processors.base_processor import BaseProcessor, RecipientSelector
8
8
  from ..run_context import CtxT, RunContext
9
- from ..typing.events import Event
9
+ from ..typing.events import DummyEvent, Event
10
10
  from ..typing.io import InT, OutT, ProcName
11
11
 
12
12
 
13
13
  class WorkflowProcessor(
14
- Processor[InT, OutT, Any, CtxT],
14
+ BaseProcessor[InT, OutT, Any, CtxT],
15
15
  ABC,
16
16
  Generic[InT, OutT, CtxT],
17
17
  ):
18
18
  def __init__(
19
19
  self,
20
20
  name: ProcName,
21
- subprocs: Sequence[Processor[Any, Any, Any, CtxT]],
22
- start_proc: Processor[InT, Any, Any, CtxT],
23
- end_proc: Processor[Any, OutT, Any, CtxT],
21
+ subprocs: Sequence[BaseProcessor[Any, Any, Any, CtxT]],
22
+ start_proc: BaseProcessor[InT, Any, Any, CtxT],
23
+ end_proc: BaseProcessor[Any, OutT, Any, CtxT],
24
24
  recipients: list[ProcName] | None = None,
25
25
  max_retries: int = 0,
26
26
  ) -> None:
@@ -44,20 +44,41 @@ class WorkflowProcessor(
44
44
  self._start_proc = start_proc
45
45
  self._end_proc = end_proc
46
46
 
47
+ self.recipients = recipients
48
+ if hasattr(type(self), "recipient_selector"):
49
+ self._end_proc.recipient_selector = self.recipient_selector
50
+
51
+ def add_recipient_selector(
52
+ self, func: RecipientSelector[OutT, CtxT]
53
+ ) -> RecipientSelector[OutT, CtxT]:
54
+ self._end_proc.recipient_selector = func
55
+ self.recipient_selector = func
56
+
57
+ return func
58
+
47
59
  @property
48
- def subprocs(self) -> Sequence[Processor[Any, Any, Any, CtxT]]:
60
+ def recipients(self) -> list[ProcName] | None:
61
+ return self._end_proc.recipients
62
+
63
+ @recipients.setter
64
+ def recipients(self, value: list[ProcName] | None) -> None:
65
+ if hasattr(self, "_end_proc"):
66
+ self._end_proc.recipients = value
67
+
68
+ @property
69
+ def subprocs(self) -> Sequence[BaseProcessor[Any, Any, Any, CtxT]]:
49
70
  return self._subprocs
50
71
 
51
72
  @property
52
- def start_proc(self) -> Processor[InT, Any, Any, CtxT]:
73
+ def start_proc(self) -> BaseProcessor[InT, Any, Any, CtxT]:
53
74
  return self._start_proc
54
75
 
55
76
  @property
56
- def end_proc(self) -> Processor[Any, OutT, Any, CtxT]:
77
+ def end_proc(self) -> BaseProcessor[Any, OutT, Any, CtxT]:
57
78
  return self._end_proc
58
79
 
59
80
  def _generate_subproc_call_id(
60
- self, call_id: str | None, subproc: Processor[Any, Any, Any, CtxT]
81
+ self, call_id: str | None, subproc: BaseProcessor[Any, Any, Any, CtxT]
61
82
  ) -> str | None:
62
83
  return f"{self._generate_call_id(call_id)}/{subproc.name}"
63
84
 
@@ -85,4 +106,4 @@ class WorkflowProcessor(
85
106
  forgetful: bool = False,
86
107
  call_id: str | None = None,
87
108
  ) -> AsyncIterator[Event[Any]]:
88
- pass
109
+ yield DummyEvent()
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: grasp_agents
3
- Version: 0.5.4
3
+ Version: 0.5.6
4
4
  Summary: Grasp Agents Library
5
5
  License-File: LICENSE.md
6
6
  Requires-Python: <4,>=3.11.4
@@ -1,22 +1,21 @@
1
- grasp_agents/__init__.py,sha256=NfnMiLttHb5iudzeDoLAK-wHCK0QBo0guPcupM99ISU,896
1
+ grasp_agents/__init__.py,sha256=Z3a_j2Etiap9H6lvE8-PQP_OIGMUcHNPeJAJO12B8kY,1031
2
2
  grasp_agents/cloud_llm.py,sha256=C6xrKYhiQb4tNXK_rFo2pNlVTXPS_gYd5uevAnpLFeE,13119
3
3
  grasp_agents/costs_dict.yaml,sha256=2MFNWtkv5W5WSCcv1Cj13B1iQLVv5Ot9pS_KW2Gu2DA,2510
4
4
  grasp_agents/errors.py,sha256=K-22TCM1Klhsej47Rg5eTqnGiGPaXgKOpdOZZ7cPipw,4633
5
5
  grasp_agents/generics_utils.py,sha256=5Pw3I9dlnKC2VGqYKC4ZZUO3Z_vTNT-NPFovNfPkl6I,6542
6
6
  grasp_agents/grasp_logging.py,sha256=H1GYhXdQvVkmauFDZ-KDwvVmPQHZUUm9sRqX_ObK2xI,1111
7
7
  grasp_agents/http_client.py,sha256=Es8NXGDkp4Nem7g24-jW0KFGA9Hp_o2Cv3cOvjup-iU,859
8
- grasp_agents/llm.py,sha256=YCzHlb2gdWVy77DmQpcu0dyEzGp4VQa3UMRJKodCYSs,6557
9
- grasp_agents/llm_agent.py,sha256=oUMNe5yXt0JKHmVNGSpR1nyb5KyYDkrz9Nyn8sBxx1g,13682
10
- grasp_agents/llm_agent_memory.py,sha256=GZ2Z66_JC_sR10vATvR53ui62xxY4lDDtL0XvL_AQNk,1846
11
- grasp_agents/llm_policy_executor.py,sha256=jYJENUooaEz5u-ZkgCijcGIJhA2CFkXQH2TUz5KwgB4,16822
8
+ grasp_agents/llm.py,sha256=ZkAeGEkpMsOY6T_zj2pL3ZWkP0mwDN_e9ArN99qZfmY,6574
9
+ grasp_agents/llm_agent.py,sha256=hX3T2Y5qiTt5CrsahNo5t08HFCyBWEiurzYnFykJN9Y,13513
10
+ grasp_agents/llm_agent_memory.py,sha256=gQwH3g4Ib3ciW2jrBiW13ttwax_pcPobH5RhXRmbc0E,1842
11
+ grasp_agents/llm_policy_executor.py,sha256=glusWe4wTIh_rl3bXPbAMueJ5eDfJxPYBTzoecrHPYg,16849
12
12
  grasp_agents/memory.py,sha256=keHuNEZNSxHT9FKpMohHOCNi7UAz_oRIc91IQEuzaWE,1162
13
13
  grasp_agents/packet.py,sha256=EmE-W4ZSMVZoqClECGFe7OGqrT4FSJ8IVGICrdjtdEY,1462
14
- grasp_agents/packet_pool.py,sha256=A5hqEvh7XwaPJVkfbItMm5FQqxujVzXDId7HnSUqluI,4919
15
- grasp_agents/printer.py,sha256=RvUXPtbD6pdhGxoKZiddKlfMRH6WaSAIdoJSEkmpId8,11041
16
- grasp_agents/processor.py,sha256=tJTD2T1cmuQ-m_7bsVFsF7k4oZafh_f0p2uqdng2-sA,17053
17
- grasp_agents/prompt_builder.py,sha256=IFvwvl_YqwL_LGfmauf4Gr1DUGZWkHsKk0ZABFnR6HI,7124
18
- grasp_agents/run_context.py,sha256=f8PasaPDCBiR59AFTHRIp1wxIPAQSzJGvv_kglWOGlk,1098
19
- grasp_agents/runner.py,sha256=nZXK5OuM7QbLyzNA-rfXD0YDtdftTK5v2cCpeujhe2k,4479
14
+ grasp_agents/packet_pool.py,sha256=i0g4O_fnpsYTU3LNjKgM4TiiiJkHA3YeZP9Y5GatM_I,5082
15
+ grasp_agents/printer.py,sha256=wtCH75DgWAwjJntr57kK2e7dagcfliUAfefYHpiwWi0,11203
16
+ grasp_agents/prompt_builder.py,sha256=UuQNnvjrhzd3_NVnvHlCs-NRNRMo4jsMePgZeIxmzSY,5894
17
+ grasp_agents/run_context.py,sha256=ikakNK1khm0UBEIPETB508BL0IlOKbOUPuq0FZ-iQHQ,942
18
+ grasp_agents/runner.py,sha256=sRuKkX8Iopxv_MdGZf22b_AsuH4FIPY28rl15RqwGEY,5005
20
19
  grasp_agents/usage_tracker.py,sha256=ZQfVUUpG0C89hyPWT_JgXnjQOxoYmumcQ9t-aCfcMo8,3561
21
20
  grasp_agents/utils.py,sha256=qKmGBwrQHw1-BgqRLuGTPKGs3J_zbrpk3nxnP1iZBiQ,6152
22
21
  grasp_agents/litellm/__init__.py,sha256=wD8RZBYokFDfbS9Cs7nO_zKb3w7RIVwEGj7g2D5CJH0,4510
@@ -33,6 +32,9 @@ grasp_agents/openai/converters.py,sha256=CXHF2GehEHLEzjL45HywZ_1qaB3N29-lbac5oBD
33
32
  grasp_agents/openai/message_converters.py,sha256=fhSN81uK51EGbLyM2-f0MvPX_UBrMy7SF3JQPo-dkXg,4686
34
33
  grasp_agents/openai/openai_llm.py,sha256=uJbbCytqpv8OCncKdzpiUdkVh3mJWgo95Y9Xetk_Ptg,10556
35
34
  grasp_agents/openai/tool_converters.py,sha256=IotZvpe3xMQcBfcjUTfAsn4LtZljj3zkU9bfpcoiqPw,1177
35
+ grasp_agents/processors/base_processor.py,sha256=j2_QY6HUjckdxfsf7yAF0xRDp_V-DNDb7hIRMRKUyWw,10685
36
+ grasp_agents/processors/parallel_processor.py,sha256=EoB1sqtnuv362S1MTUyn3-TUSO4M5Vpe86Q9HlBmbok,7816
37
+ grasp_agents/processors/processor.py,sha256=v7Bf6CGVsjb43XuOtTMuev9UedMy_lBTGifzzCZwh4Q,5157
36
38
  grasp_agents/rate_limiting/__init__.py,sha256=KRgtF_E7R3YfA2cpYcFcZ7wycV0pWVJ0xRQC7YhiIEQ,158
37
39
  grasp_agents/rate_limiting/rate_limiter_chunked.py,sha256=BPgkUXvhmZhTpZs2T6uujNFuxH_kYHiISuf6_-eNhUc,5544
38
40
  grasp_agents/rate_limiting/types.py,sha256=PbnNhEAcYedQdIpPJWud8HUVcxa_xZS2RDZu4c5jr40,1003
@@ -42,15 +44,15 @@ grasp_agents/typing/completion.py,sha256=PHJ01m7WI2KYQL8w7W2ti6hMsKEZnzYGaxbNcBC
42
44
  grasp_agents/typing/completion_chunk.py,sha256=t6PvkDWQxRN5xA4efBdeou46RifMFodBmZc45Sx7qxQ,7610
43
45
  grasp_agents/typing/content.py,sha256=XFmLpNWkGhkw5JujO6UsYwhzTHkU67PfhzaXH2waLcQ,3659
44
46
  grasp_agents/typing/converters.py,sha256=kHlocHQS8QnduZOzNPbj3aRD8JpvJd53oudYqWdOxKE,2978
45
- grasp_agents/typing/events.py,sha256=1n6EVS8vU0RLlQHMHUjiV5i4PYLovT7RcOmpcx0ztlc,5256
46
- grasp_agents/typing/io.py,sha256=WmFfAVnqd4uNQygNAmlo0BX3q8focS6KLvCChmM0ep8,215
47
+ grasp_agents/typing/events.py,sha256=UB0G2mQLmyLHYdEp1YnFi8XF2jpWeUZ4TLFnY64lK0I,5450
48
+ grasp_agents/typing/io.py,sha256=MGEoUjAwKH1AHYglFkKNpHiielw-NFf13Epg3B4Q7Iw,139
47
49
  grasp_agents/typing/message.py,sha256=o7bN84AgrC5Fm3Wx20gqL9ArAMcEtYvnHnXbb04ngCs,3224
48
50
  grasp_agents/typing/tool.py,sha256=4N-k_GvHVPAFyVyEq7z_LYKkA24iQlGoVYiWCzGTgT4,1786
49
51
  grasp_agents/workflow/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
50
- grasp_agents/workflow/looped_workflow.py,sha256=J-mdy1gxsMUC622j8W-_fDX3oAyKHg9TTzCXEksMS-Q,6311
51
- grasp_agents/workflow/sequential_workflow.py,sha256=7xBl6YtH97y6jSv6eAHtSIeKjoqiUWKu63C-gKp6Eus,3295
52
- grasp_agents/workflow/workflow_processor.py,sha256=56XC6KDRQuRftwhWnALmPUlQWCGfuTPRgIznyV2-tOg,2776
53
- grasp_agents-0.5.4.dist-info/METADATA,sha256=ePDw81wa-FDQu1jj46FN4tO4yUkgCt7lvn_0mDcwaDA,6865
54
- grasp_agents-0.5.4.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
55
- grasp_agents-0.5.4.dist-info/licenses/LICENSE.md,sha256=-nNNdWqGB8gJ2O-peFQ2Irshv5tW5pHKyTcYkwvH7CE,1201
56
- grasp_agents-0.5.4.dist-info/RECORD,,
52
+ grasp_agents/workflow/looped_workflow.py,sha256=WHp9O3Za2sBVfY_BLOdvPvtY20XsjZQaWSO2-oAFvOY,6806
53
+ grasp_agents/workflow/sequential_workflow.py,sha256=e3BIWzy_2novmEWNwIteyMbrzvl1-evHrTBE3r3SpU8,3648
54
+ grasp_agents/workflow/workflow_processor.py,sha256=yrxqAGfznmdkbP5zScKKJguxATfU4ObmA6BDR7YCBNU,3549
55
+ grasp_agents-0.5.6.dist-info/METADATA,sha256=c0CNKc5v1lF7iPtHiLfanUt0UKQUztGVJsMcq-vco_0,6865
56
+ grasp_agents-0.5.6.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
57
+ grasp_agents-0.5.6.dist-info/licenses/LICENSE.md,sha256=-nNNdWqGB8gJ2O-peFQ2Irshv5tW5pHKyTcYkwvH7CE,1201
58
+ grasp_agents-0.5.6.dist-info/RECORD,,
grasp_agents/processor.py DELETED
@@ -1,512 +0,0 @@
1
- import asyncio
2
- import logging
3
- from abc import ABC
4
- from collections.abc import AsyncIterator, Sequence
5
- from typing import Any, ClassVar, Generic, Protocol, TypeVar, cast, final
6
- from uuid import uuid4
7
-
8
- from pydantic import BaseModel, TypeAdapter
9
- from pydantic import ValidationError as PydanticValidationError
10
-
11
- from .errors import (
12
- PacketRoutingError,
13
- ProcInputValidationError,
14
- ProcOutputValidationError,
15
- ProcRunError,
16
- )
17
- from .generics_utils import AutoInstanceAttributesMixin
18
- from .memory import DummyMemory, MemT
19
- from .packet import Packet
20
- from .run_context import CtxT, RunContext
21
- from .typing.events import (
22
- Event,
23
- ProcPacketOutputEvent,
24
- ProcPayloadOutputEvent,
25
- ProcStreamingErrorData,
26
- ProcStreamingErrorEvent,
27
- )
28
- from .typing.io import InT, OutT, ProcName
29
- from .typing.tool import BaseTool
30
- from .utils import stream_concurrent
31
-
32
- logger = logging.getLogger(__name__)
33
-
34
- _OutT_contra = TypeVar("_OutT_contra", contravariant=True)
35
-
36
-
37
- class SelectRecipientsHandler(Protocol[_OutT_contra, CtxT]):
38
- def __call__(
39
- self, output: _OutT_contra, ctx: RunContext[CtxT] | None
40
- ) -> list[ProcName] | None: ...
41
-
42
-
43
- class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT, MemT, CtxT]):
44
- _generic_arg_to_instance_attr_map: ClassVar[dict[int, str]] = {
45
- 0: "_in_type",
46
- 1: "_out_type",
47
- }
48
-
49
- def __init__(
50
- self,
51
- name: ProcName,
52
- max_retries: int = 0,
53
- recipients: list[ProcName] | None = None,
54
- **kwargs: Any,
55
- ) -> None:
56
- self._in_type: type[InT]
57
- self._out_type: type[OutT]
58
-
59
- super().__init__()
60
-
61
- self._name: ProcName = name
62
- self._memory: MemT = cast("MemT", DummyMemory())
63
- self._max_retries: int = max_retries
64
- self.recipients = recipients
65
- self.select_recipients_impl: SelectRecipientsHandler[OutT, CtxT] | None = None
66
-
67
- @property
68
- def in_type(self) -> type[InT]:
69
- return self._in_type
70
-
71
- @property
72
- def out_type(self) -> type[OutT]:
73
- return self._out_type
74
-
75
- @property
76
- def name(self) -> ProcName:
77
- return self._name
78
-
79
- @property
80
- def memory(self) -> MemT:
81
- return self._memory
82
-
83
- @property
84
- def max_retries(self) -> int:
85
- return self._max_retries
86
-
87
- def _generate_call_id(self, call_id: str | None) -> str:
88
- if call_id is None:
89
- return str(uuid4())[:6] + "_" + self.name
90
- return call_id
91
-
92
- def _validate_inputs(
93
- self,
94
- call_id: str,
95
- chat_inputs: Any | None = None,
96
- in_packet: Packet[InT] | None = None,
97
- in_args: InT | Sequence[InT] | None = None,
98
- ) -> Sequence[InT] | None:
99
- mult_inputs_err_message = (
100
- "Only one of chat_inputs, in_args, or in_message must be provided."
101
- )
102
- err_kwargs = {"proc_name": self.name, "call_id": call_id}
103
-
104
- if chat_inputs is not None and in_args is not None:
105
- raise ProcInputValidationError(
106
- message=mult_inputs_err_message, **err_kwargs
107
- )
108
- if chat_inputs is not None and in_packet is not None:
109
- raise ProcInputValidationError(
110
- message=mult_inputs_err_message, **err_kwargs
111
- )
112
- if in_args is not None and in_packet is not None:
113
- raise ProcInputValidationError(
114
- message=mult_inputs_err_message, **err_kwargs
115
- )
116
-
117
- if in_packet is not None and not in_packet.payloads:
118
- raise ProcInputValidationError(
119
- message="in_packet must contain at least one payload.", **err_kwargs
120
- )
121
- if in_args is not None and not in_args:
122
- raise ProcInputValidationError(
123
- message="in_args must contain at least one argument.", **err_kwargs
124
- )
125
-
126
- if chat_inputs is not None:
127
- return None
128
-
129
- resolved_args: Sequence[InT]
130
-
131
- if isinstance(in_args, Sequence):
132
- _in_args = cast("Sequence[Any]", in_args)
133
- if all(isinstance(x, self.in_type) for x in _in_args):
134
- resolved_args = cast("Sequence[InT]", _in_args)
135
- elif isinstance(_in_args, self.in_type):
136
- resolved_args = cast("Sequence[InT]", [_in_args])
137
- else:
138
- raise ProcInputValidationError(
139
- message=f"in_args are neither of type {self.in_type} "
140
- f"nor a sequence of {self.in_type}.",
141
- **err_kwargs,
142
- )
143
-
144
- elif in_args is not None:
145
- resolved_args = cast("Sequence[InT]", [in_args])
146
-
147
- else:
148
- assert in_packet is not None
149
- resolved_args = in_packet.payloads
150
-
151
- try:
152
- for args in resolved_args:
153
- TypeAdapter(self._in_type).validate_python(args)
154
- except PydanticValidationError as err:
155
- raise ProcInputValidationError(message=str(err), **err_kwargs) from err
156
-
157
- return resolved_args
158
-
159
- def _validate_output(self, out_payload: OutT, call_id: str) -> OutT:
160
- if out_payload is None:
161
- return out_payload
162
- try:
163
- return TypeAdapter(self._out_type).validate_python(out_payload)
164
- except PydanticValidationError as err:
165
- raise ProcOutputValidationError(
166
- schema=self._out_type, proc_name=self.name, call_id=call_id
167
- ) from err
168
-
169
- def _validate_recipients(
170
- self, recipients: Sequence[ProcName] | None, call_id: str
171
- ) -> None:
172
- for r in recipients or []:
173
- if r not in (self.recipients or []):
174
- raise PacketRoutingError(
175
- proc_name=self.name,
176
- call_id=call_id,
177
- selected_recipient=r,
178
- allowed_recipients=cast("list[str]", self.recipients),
179
- )
180
-
181
- def _validate_par_recipients(
182
- self, out_packets: Sequence[Packet[OutT]], call_id: str
183
- ) -> None:
184
- recipient_sets = [set(p.recipients or []) for p in out_packets]
185
- same_recipients = all(rs == recipient_sets[0] for rs in recipient_sets)
186
- if not same_recipients:
187
- raise PacketRoutingError(
188
- proc_name=self.name,
189
- call_id=call_id,
190
- message="Parallel runs must return the same recipients "
191
- f"[proc_name={self.name}; call_id={call_id}]",
192
- )
193
-
194
- async def _process(
195
- self,
196
- chat_inputs: Any | None = None,
197
- *,
198
- in_args: InT | None = None,
199
- memory: MemT,
200
- call_id: str,
201
- ctx: RunContext[CtxT] | None = None,
202
- ) -> OutT:
203
- return cast("OutT", in_args)
204
-
205
- async def _process_stream(
206
- self,
207
- chat_inputs: Any | None = None,
208
- *,
209
- in_args: InT | None = None,
210
- memory: MemT,
211
- call_id: str,
212
- ctx: RunContext[CtxT] | None = None,
213
- ) -> AsyncIterator[Event[Any]]:
214
- output = cast("OutT", in_args)
215
- yield ProcPayloadOutputEvent(data=output, proc_name=self.name, call_id=call_id)
216
-
217
- async def _run_single_once(
218
- self,
219
- chat_inputs: Any | None = None,
220
- *,
221
- in_args: InT | None = None,
222
- forgetful: bool = False,
223
- call_id: str,
224
- ctx: RunContext[CtxT] | None = None,
225
- ) -> Packet[OutT]:
226
- _memory = self.memory.model_copy(deep=True) if forgetful else self.memory
227
-
228
- output = await self._process(
229
- chat_inputs=chat_inputs,
230
- in_args=in_args,
231
- memory=_memory,
232
- call_id=call_id,
233
- ctx=ctx,
234
- )
235
- val_output = self._validate_output(output, call_id=call_id)
236
-
237
- recipients = self._select_recipients(output=val_output, ctx=ctx)
238
- self._validate_recipients(recipients, call_id=call_id)
239
-
240
- return Packet(payloads=[val_output], sender=self.name, recipients=recipients)
241
-
242
- async def _run_single(
243
- self,
244
- chat_inputs: Any | None = None,
245
- *,
246
- in_args: InT | None = None,
247
- forgetful: bool = False,
248
- call_id: str,
249
- ctx: RunContext[CtxT] | None = None,
250
- ) -> Packet[OutT]:
251
- n_attempt = 0
252
- while n_attempt <= self.max_retries:
253
- try:
254
- return await self._run_single_once(
255
- chat_inputs=chat_inputs,
256
- in_args=in_args,
257
- forgetful=forgetful,
258
- call_id=call_id,
259
- ctx=ctx,
260
- )
261
- except Exception as err:
262
- err_message = (
263
- f"\nProcessor run failed [proc_name={self.name}; call_id={call_id}]"
264
- )
265
- n_attempt += 1
266
- if n_attempt > self.max_retries:
267
- if n_attempt == 1:
268
- logger.warning(f"{err_message}:\n{err}")
269
- if n_attempt > 1:
270
- logger.warning(f"{err_message} after retrying:\n{err}")
271
- raise ProcRunError(proc_name=self.name, call_id=call_id) from err
272
-
273
- logger.warning(f"{err_message} (retry attempt {n_attempt}):\n{err}")
274
-
275
- raise ProcRunError(proc_name=self.name, call_id=call_id)
276
-
277
- async def _run_par(
278
- self, in_args: Sequence[InT], call_id: str, ctx: RunContext[CtxT] | None = None
279
- ) -> Packet[OutT]:
280
- tasks = [
281
- self._run_single(
282
- in_args=inp, forgetful=True, call_id=f"{call_id}/{idx}", ctx=ctx
283
- )
284
- for idx, inp in enumerate(in_args)
285
- ]
286
- out_packets = await asyncio.gather(*tasks)
287
-
288
- self._validate_par_recipients(out_packets, call_id=call_id)
289
-
290
- return Packet(
291
- payloads=[out_packet.payloads[0] for out_packet in out_packets],
292
- sender=self.name,
293
- recipients=out_packets[0].recipients,
294
- )
295
-
296
- async def run(
297
- self,
298
- chat_inputs: Any | None = None,
299
- *,
300
- in_packet: Packet[InT] | None = None,
301
- in_args: InT | Sequence[InT] | None = None,
302
- forgetful: bool = False,
303
- call_id: str | None = None,
304
- ctx: RunContext[CtxT] | None = None,
305
- ) -> Packet[OutT]:
306
- call_id = self._generate_call_id(call_id)
307
-
308
- val_in_args = self._validate_inputs(
309
- call_id=call_id,
310
- chat_inputs=chat_inputs,
311
- in_packet=in_packet,
312
- in_args=in_args,
313
- )
314
-
315
- if val_in_args and len(val_in_args) > 1:
316
- return await self._run_par(in_args=val_in_args, call_id=call_id, ctx=ctx)
317
- return await self._run_single(
318
- chat_inputs=chat_inputs,
319
- in_args=val_in_args[0] if val_in_args else None,
320
- forgetful=forgetful,
321
- call_id=call_id,
322
- ctx=ctx,
323
- )
324
-
325
- async def _run_single_stream_once(
326
- self,
327
- chat_inputs: Any | None = None,
328
- *,
329
- in_args: InT | None = None,
330
- forgetful: bool = False,
331
- call_id: str,
332
- ctx: RunContext[CtxT] | None = None,
333
- ) -> AsyncIterator[Event[Any]]:
334
- _memory = self.memory.model_copy(deep=True) if forgetful else self.memory
335
-
336
- output: OutT | None = None
337
- async for event in self._process_stream(
338
- chat_inputs=chat_inputs,
339
- in_args=in_args,
340
- memory=_memory,
341
- call_id=call_id,
342
- ctx=ctx,
343
- ):
344
- if isinstance(event, ProcPayloadOutputEvent):
345
- output = event.data
346
- yield event
347
-
348
- assert output is not None
349
-
350
- val_output = self._validate_output(output, call_id=call_id)
351
-
352
- recipients = self._select_recipients(output=val_output, ctx=ctx)
353
- self._validate_recipients(recipients, call_id=call_id)
354
-
355
- out_packet = Packet[OutT](
356
- payloads=[val_output], sender=self.name, recipients=recipients
357
- )
358
-
359
- yield ProcPacketOutputEvent(
360
- data=out_packet, proc_name=self.name, call_id=call_id
361
- )
362
-
363
- async def _run_single_stream(
364
- self,
365
- chat_inputs: Any | None = None,
366
- *,
367
- in_args: InT | None = None,
368
- forgetful: bool = False,
369
- call_id: str,
370
- ctx: RunContext[CtxT] | None = None,
371
- ) -> AsyncIterator[Event[Any]]:
372
- n_attempt = 0
373
- while n_attempt <= self.max_retries:
374
- try:
375
- async for event in self._run_single_stream_once(
376
- chat_inputs=chat_inputs,
377
- in_args=in_args,
378
- forgetful=forgetful,
379
- call_id=call_id,
380
- ctx=ctx,
381
- ):
382
- yield event
383
-
384
- return
385
-
386
- except Exception as err:
387
- err_data = ProcStreamingErrorData(error=err, call_id=call_id)
388
- yield ProcStreamingErrorEvent(
389
- data=err_data, proc_name=self.name, call_id=call_id
390
- )
391
-
392
- err_message = (
393
- "\nStreaming processor run failed "
394
- f"[proc_name={self.name}; call_id={call_id}]"
395
- )
396
-
397
- n_attempt += 1
398
- if n_attempt > self.max_retries:
399
- if n_attempt == 1:
400
- logger.warning(f"{err_message}:\n{err}")
401
- if n_attempt > 1:
402
- logger.warning(f"{err_message} after retrying:\n{err}")
403
- raise ProcRunError(proc_name=self.name, call_id=call_id) from err
404
-
405
- logger.warning(f"{err_message} (retry attempt {n_attempt}):\n{err}")
406
-
407
- async def _run_par_stream(
408
- self,
409
- in_args: Sequence[InT],
410
- call_id: str,
411
- ctx: RunContext[CtxT] | None = None,
412
- ) -> AsyncIterator[Event[Any]]:
413
- streams = [
414
- self._run_single_stream(
415
- in_args=inp, forgetful=True, call_id=f"{call_id}/{idx}", ctx=ctx
416
- )
417
- for idx, inp in enumerate(in_args)
418
- ]
419
-
420
- out_packets_map: dict[int, Packet[OutT]] = {}
421
- async for idx, event in stream_concurrent(streams):
422
- if isinstance(event, ProcPacketOutputEvent):
423
- out_packets_map[idx] = event.data
424
- else:
425
- yield event
426
-
427
- out_packet = Packet(
428
- payloads=[
429
- out_packet.payloads[0]
430
- for _, out_packet in sorted(out_packets_map.items())
431
- ],
432
- sender=self.name,
433
- recipients=out_packets_map[0].recipients,
434
- )
435
-
436
- yield ProcPacketOutputEvent(
437
- data=out_packet, proc_name=self.name, call_id=call_id
438
- )
439
-
440
- async def run_stream(
441
- self,
442
- chat_inputs: Any | None = None,
443
- *,
444
- in_packet: Packet[InT] | None = None,
445
- in_args: InT | Sequence[InT] | None = None,
446
- forgetful: bool = False,
447
- call_id: str | None = None,
448
- ctx: RunContext[CtxT] | None = None,
449
- ) -> AsyncIterator[Event[Any]]:
450
- call_id = self._generate_call_id(call_id)
451
-
452
- val_in_args = self._validate_inputs(
453
- call_id=call_id,
454
- chat_inputs=chat_inputs,
455
- in_packet=in_packet,
456
- in_args=in_args,
457
- )
458
-
459
- if val_in_args and len(val_in_args) > 1:
460
- stream = self._run_par_stream(in_args=val_in_args, call_id=call_id, ctx=ctx)
461
- else:
462
- stream = self._run_single_stream(
463
- chat_inputs=chat_inputs,
464
- in_args=val_in_args[0] if val_in_args else None,
465
- forgetful=forgetful,
466
- call_id=call_id,
467
- ctx=ctx,
468
- )
469
- async for event in stream:
470
- yield event
471
-
472
- def _select_recipients(
473
- self, output: OutT, ctx: RunContext[CtxT] | None = None
474
- ) -> list[ProcName] | None:
475
- if self.select_recipients_impl:
476
- return self.select_recipients_impl(output=output, ctx=ctx)
477
-
478
- return self.recipients
479
-
480
- def select_recipients(
481
- self, func: SelectRecipientsHandler[OutT, CtxT]
482
- ) -> SelectRecipientsHandler[OutT, CtxT]:
483
- self.select_recipients_impl = func
484
-
485
- return func
486
-
487
- @final
488
- def as_tool(
489
- self, tool_name: str, tool_description: str
490
- ) -> BaseTool[InT, OutT, Any]: # type: ignore[override]
491
- # TODO: stream tools
492
- processor_instance = self
493
- in_type = processor_instance.in_type
494
- out_type = processor_instance.out_type
495
- if not issubclass(in_type, BaseModel):
496
- raise TypeError(
497
- "Cannot create a tool from an agent with "
498
- f"non-BaseModel input type: {in_type}"
499
- )
500
-
501
- class ProcessorTool(BaseTool[in_type, out_type, Any]):
502
- name: str = tool_name
503
- description: str = tool_description
504
-
505
- async def run(self, inp: InT, ctx: RunContext[CtxT] | None = None) -> OutT:
506
- result = await processor_instance.run(
507
- in_args=in_type.model_validate(inp), forgetful=True, ctx=ctx
508
- )
509
-
510
- return result.payloads[0]
511
-
512
- return ProcessorTool()