grasp_agents 0.5.4__py3-none-any.whl → 0.5.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,320 @@
1
+ import logging
2
+ from abc import ABC, abstractmethod
3
+ from collections.abc import AsyncIterator, Callable, Coroutine
4
+ from functools import wraps
5
+ from typing import (
6
+ Any,
7
+ ClassVar,
8
+ Generic,
9
+ Protocol,
10
+ TypeVar,
11
+ cast,
12
+ final,
13
+ )
14
+ from uuid import uuid4
15
+
16
+ from pydantic import BaseModel, TypeAdapter
17
+ from pydantic import ValidationError as PydanticValidationError
18
+
19
+ from ..errors import (
20
+ PacketRoutingError,
21
+ ProcInputValidationError,
22
+ ProcOutputValidationError,
23
+ ProcRunError,
24
+ )
25
+ from ..generics_utils import AutoInstanceAttributesMixin
26
+ from ..memory import DummyMemory, MemT
27
+ from ..packet import Packet
28
+ from ..run_context import CtxT, RunContext
29
+ from ..typing.events import (
30
+ DummyEvent,
31
+ Event,
32
+ ProcStreamingErrorData,
33
+ ProcStreamingErrorEvent,
34
+ )
35
+ from ..typing.io import InT, OutT, ProcName
36
+ from ..typing.tool import BaseTool
37
+
38
logger = logging.getLogger(__name__)

# Contravariant output type used by RecipientSelector (it consumes outputs).
_OutT_contra = TypeVar("_OutT_contra", contravariant=True)

# Type variables for the retry decorators below:
#   F        — coroutine-returning run methods producing a Packet
#   F_stream — async-generator (streaming) run methods yielding Events
F = TypeVar("F", bound=Callable[..., Coroutine[Any, Any, Packet[Any]]])
F_stream = TypeVar("F_stream", bound=Callable[..., AsyncIterator[Event[Any]]])
44
+
45
+
46
def with_retry(func: F) -> F:
    """Decorator retrying a processor's coroutine run method.

    The wrapped method is attempted up to ``self.max_retries + 1`` times.
    Intermediate failures are logged as warnings; the final failure is
    re-raised as :class:`ProcRunError` chained to the original exception.
    """

    @wraps(func)
    async def retrying_call(
        self: "BaseProcessor[Any, Any, Any, Any]", *args: Any, **kwargs: Any
    ) -> Packet[Any]:
        call_id = kwargs.get("call_id", "unknown")
        total_attempts = self.max_retries + 1
        for n_attempt in range(total_attempts):
            try:
                return await func(self, *args, **kwargs)
            except Exception as err:
                err_message = (
                    f"\nProcessor run failed [proc_name={self.name}; call_id={call_id}]"
                )
                last_attempt = n_attempt == self.max_retries
                if not last_attempt:
                    # More attempts remain: log (1-based attempt count) and retry.
                    logger.warning(f"{err_message} (retry attempt {n_attempt + 1}):\n{err}")
                    continue
                # Out of attempts: log the terminal failure and re-raise.
                if self.max_retries == 0:
                    logger.warning(f"{err_message}:\n{err}")
                else:
                    logger.warning(f"{err_message} after retrying:\n{err}")
                raise ProcRunError(proc_name=self.name, call_id=call_id) from err

        # Defensive: the loop always returns or raises before reaching here.
        raise ProcRunError(proc_name=self.name, call_id=call_id)

    return cast("F", retrying_call)
71
+
72
+
73
def with_retry_stream(func: F_stream) -> F_stream:
    """Decorator retrying a processor's streaming (async generator) run method.

    The wrapped generator is attempted up to ``self.max_retries + 1`` times.
    On each failure a ``ProcStreamingErrorEvent`` is yielded so downstream
    consumers see the error; the final failure is re-raised as
    :class:`ProcRunError` chained to the original exception.

    NOTE(review): events yielded before a mid-stream failure are yielded
    again when the stream is retried — consumers may observe duplicates.
    This mirrors the original behavior; confirm it is intended.
    """

    @wraps(func)
    async def wrapper(
        self: "BaseProcessor[Any, Any, Any, Any]", *args: Any, **kwargs: Any
    ) -> AsyncIterator[Event[Any]]:
        call_id = kwargs.get("call_id", "unknown")
        for n_attempt in range(self.max_retries + 1):
            try:
                async for event in func(self, *args, **kwargs):
                    yield event
                return
            except Exception as err:
                # Surface the error to stream consumers before deciding
                # whether to retry or give up.
                err_data = ProcStreamingErrorData(error=err, call_id=call_id)
                yield ProcStreamingErrorEvent(
                    data=err_data, proc_name=self.name, call_id=call_id
                )
                err_message = (
                    "\nStreaming processor run failed "
                    f"[proc_name={self.name}; call_id={call_id}]"
                )
                if n_attempt == self.max_retries:
                    if self.max_retries == 0:
                        logger.warning(f"{err_message}:\n{err}")
                    else:
                        logger.warning(f"{err_message} after retrying:\n{err}")
                    raise ProcRunError(proc_name=self.name, call_id=call_id) from err

                # 1-based attempt count, consistent with with_retry's logging
                # (previously logged the 0-based loop index).
                logger.warning(f"{err_message} (retry attempt {n_attempt + 1}):\n{err}")

    return cast("F_stream", wrapper)
103
+
104
+
105
class RecipientSelector(Protocol[_OutT_contra, CtxT]):
    """Structural type for callables that choose recipients for an output.

    Given a run's output payload and the optional run context, returns the
    list of recipient processor names, or None for no explicit selection.
    """

    def __call__(
        self, output: _OutT_contra, ctx: RunContext[CtxT] | None
    ) -> list[ProcName] | None: ...
109
+
110
+
111
class BaseProcessor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT, MemT, CtxT]):
    """Abstract base class for processors.

    A processor validates its inputs and outputs against the runtime types
    resolved from its generic parameters, selects recipient processors for
    its output packet, and can be exposed to an LLM as a tool via
    :meth:`as_tool`.

    Generic parameters:
        InT: input payload type.
        OutT: output payload type.
        MemT: memory type.
        CtxT: run-context state type.
    """

    # Maps generic-argument positions to the instance attributes that
    # AutoInstanceAttributesMixin populates with the resolved runtime types.
    _generic_arg_to_instance_attr_map: ClassVar[dict[int, str]] = {
        0: "_in_type",
        1: "_out_type",
    }

    def __init__(
        self,
        name: ProcName,
        max_retries: int = 0,
        recipients: list[ProcName] | None = None,
        **kwargs: Any,
    ) -> None:
        """Initialize the processor.

        Args:
            name: Processor name used for routing and error reporting.
            max_retries: Retry count used by the retry decorators
                (0 means a single attempt).
            recipients: Allowed recipient processor names, or None.
            **kwargs: Accepted for cooperative subclass initialization;
                unused here.
        """
        # Filled in by AutoInstanceAttributesMixin from the generic args.
        self._in_type: type[InT]
        self._out_type: type[OutT]

        super().__init__()

        self._name: ProcName = name
        # Placeholder memory; subclasses are expected to install a real one.
        self._memory: MemT = cast("MemT", DummyMemory())
        self._max_retries: int = max_retries

        self.recipients = recipients

        # Only default the instance attribute when no subclass defined a
        # recipient_selector at class level.
        self.recipient_selector: RecipientSelector[OutT, CtxT] | None
        if not hasattr(type(self), "recipient_selector"):
            self.recipient_selector = None

    @property
    def in_type(self) -> type[InT]:
        """Resolved runtime type of input payloads."""
        return self._in_type

    @property
    def out_type(self) -> type[OutT]:
        """Resolved runtime type of output payloads."""
        return self._out_type

    @property
    def name(self) -> ProcName:
        """Processor name used for routing and diagnostics."""
        return self._name

    @property
    def memory(self) -> MemT:
        """Processor memory object."""
        return self._memory

    @property
    def max_retries(self) -> int:
        """Maximum number of retries used by the retry decorators."""
        return self._max_retries

    def _generate_call_id(self, call_id: str | None) -> str:
        """Return ``call_id`` unchanged, or generate ``'<uuid6>_<name>'``."""
        if call_id is None:
            return str(uuid4())[:6] + "_" + self.name
        return call_id

    def _validate_inputs(
        self,
        call_id: str,
        chat_inputs: Any | None = None,
        in_packet: Packet[InT] | None = None,
        in_args: InT | list[InT] | None = None,
    ) -> list[InT] | None:
        """Ensure exactly one input source is given and normalize it.

        Returns:
            A list of type-validated input payloads, or None when
            ``chat_inputs`` is used (chat inputs bypass payload validation
            and are handled by subclasses).

        Raises:
            ProcInputValidationError: if multiple input sources are given,
                an input container is empty, or a payload fails type
                validation.
        """
        # Fix: the original message referenced a nonexistent "in_message"
        # parameter; the actual parameter is "in_packet".
        mult_inputs_err_message = (
            "Only one of chat_inputs, in_args, or in_packet must be provided."
        )
        err_kwargs = {"proc_name": self.name, "call_id": call_id}

        if chat_inputs is not None and in_args is not None:
            raise ProcInputValidationError(
                message=mult_inputs_err_message, **err_kwargs
            )
        if chat_inputs is not None and in_packet is not None:
            raise ProcInputValidationError(
                message=mult_inputs_err_message, **err_kwargs
            )
        if in_args is not None and in_packet is not None:
            raise ProcInputValidationError(
                message=mult_inputs_err_message, **err_kwargs
            )

        if in_packet is not None and not in_packet.payloads:
            raise ProcInputValidationError(
                message="in_packet must contain at least one payload.", **err_kwargs
            )
        if in_args is not None and not in_args:
            raise ProcInputValidationError(
                message="in_args must contain at least one argument.", **err_kwargs
            )

        if chat_inputs is not None:
            return None

        resolved_args: list[InT]

        if isinstance(in_args, list):
            _in_args = cast("list[Any]", in_args)
            if all(isinstance(x, self.in_type) for x in _in_args):
                resolved_args = cast("list[InT]", _in_args)
            elif isinstance(_in_args, self.in_type):
                # The list itself is a single input (InT is a list type).
                resolved_args = cast("list[InT]", [_in_args])
            else:
                raise ProcInputValidationError(
                    message=f"in_args are neither of type {self.in_type} "
                    f"nor a sequence of {self.in_type}.",
                    **err_kwargs,
                )

        elif in_args is not None:
            resolved_args = cast("list[InT]", [in_args])

        else:
            assert in_packet is not None
            resolved_args = cast("list[InT]", in_packet.payloads)

        try:
            for args in resolved_args:
                TypeAdapter(self._in_type).validate_python(args)
        except PydanticValidationError as err:
            raise ProcInputValidationError(message=str(err), **err_kwargs) from err

        return resolved_args

    def _validate_output(self, out_payload: OutT, call_id: str) -> OutT:
        """Validate an output payload against the resolved output type.

        None payloads are passed through unvalidated.

        Raises:
            ProcOutputValidationError: if the payload does not conform to
                the output type.
        """
        if out_payload is None:
            return out_payload
        try:
            return TypeAdapter(self._out_type).validate_python(out_payload)
        except PydanticValidationError as err:
            raise ProcOutputValidationError(
                schema=self._out_type, proc_name=self.name, call_id=call_id
            ) from err

    def _validate_recipients(
        self, recipients: list[ProcName] | None, call_id: str
    ) -> None:
        """Check that every selected recipient is in the allowed set.

        Raises:
            PacketRoutingError: on the first recipient not listed in
                ``self.recipients``.
        """
        for r in recipients or []:
            if r not in (self.recipients or []):
                raise PacketRoutingError(
                    proc_name=self.name,
                    call_id=call_id,
                    selected_recipient=r,
                    allowed_recipients=cast("list[str]", self.recipients),
                )

    @final
    def _select_recipients(
        self, output: OutT, ctx: RunContext[CtxT] | None = None
    ) -> list[ProcName] | None:
        """Pick recipients via the selector if set, else the static list."""
        if self.recipient_selector:
            return self.recipient_selector(output=output, ctx=ctx)

        return self.recipients

    def add_recipient_selector(
        self, func: RecipientSelector[OutT, CtxT]
    ) -> RecipientSelector[OutT, CtxT]:
        """Register a recipient selector; usable as a decorator."""
        self.recipient_selector = func

        return func

    @abstractmethod
    async def run(
        self,
        chat_inputs: Any | None = None,
        *,
        in_packet: Packet[InT] | None = None,
        in_args: InT | list[InT] | None = None,
        forgetful: bool = False,
        call_id: str | None = None,
        ctx: RunContext[CtxT] | None = None,
    ) -> Packet[OutT]:
        """Run the processor and return its output packet."""
        pass

    @abstractmethod
    async def run_stream(
        self,
        chat_inputs: Any | None = None,
        *,
        in_packet: Packet[InT] | None = None,
        in_args: InT | list[InT] | None = None,
        forgetful: bool = False,
        call_id: str | None = None,
        ctx: RunContext[CtxT] | None = None,
    ) -> AsyncIterator[Event[Any]]:
        """Run the processor, yielding events as they are produced."""
        # The yield makes this an async generator so overrides share the
        # same shape; DummyEvent is never meaningfully emitted.
        yield DummyEvent()

    @final
    def as_tool(
        self, tool_name: str, tool_description: str
    ) -> BaseTool[InT, OutT, Any]:  # type: ignore[override]
        """Wrap this processor as a tool with the given name/description.

        Raises:
            TypeError: if the input type is not a pydantic BaseModel
                (a model schema is required for tool arguments).
        """
        # TODO: stream tools
        processor_instance = self
        in_type = processor_instance.in_type
        out_type = processor_instance.out_type
        if not issubclass(in_type, BaseModel):
            raise TypeError(
                "Cannot create a tool from an agent with "
                f"non-BaseModel input type: {in_type}"
            )

        class ProcessorTool(BaseTool[in_type, out_type, Any]):
            name: str = tool_name
            description: str = tool_description

            async def run(self, inp: InT, ctx: RunContext[CtxT] | None = None) -> OutT:
                # forgetful=True so tool calls do not mutate shared memory
                # (ParallelProcessor runs forgetful calls on a deep copy).
                result = await processor_instance.run(
                    in_args=inp, forgetful=True, ctx=ctx
                )

                return result.payloads[0]

        return ProcessorTool()
@@ -0,0 +1,244 @@
1
+ import asyncio
2
+ import logging
3
+ from collections.abc import AsyncIterator, Sequence
4
+ from typing import Any, ClassVar, Generic, cast
5
+
6
+
7
+ from ..errors import PacketRoutingError
8
+ from ..memory import MemT
9
+ from ..packet import Packet
10
+ from ..run_context import CtxT, RunContext
11
+ from ..typing.events import (
12
+ Event, ProcPacketOutputEvent, ProcPayloadOutputEvent
13
+ )
14
+ from ..typing.io import InT, OutT
15
+ from ..utils import stream_concurrent
16
+ from .base_processor import BaseProcessor, with_retry, with_retry_stream
17
+
18
# Module-level logger for parallel-processor diagnostics.
logger = logging.getLogger(__name__)
19
+
20
+
21
class ParallelProcessor(BaseProcessor[InT, OutT, MemT, CtxT], Generic[InT, OutT, MemT, CtxT]):
    """Processor that fans out over multiple inputs and runs them concurrently.

    A single input runs via ``_run_single`` / ``_run_single_stream``; a list
    of more than one input is fanned out with one forgetful sub-run per
    payload, and the per-run packets are merged into one output packet.
    """

    _generic_arg_to_instance_attr_map: ClassVar[dict[int, str]] = {
        0: "_in_type",
        1: "_out_type",
    }

    async def _process(
        self,
        chat_inputs: Any | None = None,
        *,
        in_args: InT | None = None,
        memory: MemT,
        call_id: str,
        ctx: RunContext[CtxT] | None = None,
    ) -> OutT:
        """Default processing step: identity pass-through of the input.

        Subclasses override this with real work. NOTE(review): the default
        assumes InT is compatible with OutT (and in_args may be None) —
        confirm overrides always replace it where that does not hold.
        """
        return cast(OutT, in_args)

    async def _process_stream(
        self,
        chat_inputs: Any | None = None,
        *,
        in_args: InT | None = None,
        memory: MemT,
        call_id: str,
        ctx: RunContext[CtxT] | None = None,
    ) -> AsyncIterator[Event[Any]]:
        """Streaming counterpart of ``_process``: yields one payload event."""
        output = cast(OutT, in_args)
        yield ProcPayloadOutputEvent(data=output, proc_name=self.name, call_id=call_id)

    def _validate_parallel_recipients(
        self, out_packets: Sequence[Packet[OutT]], call_id: str
    ) -> None:
        """Require all parallel sub-runs to select the same recipient set.

        Raises:
            PacketRoutingError: if any packet's recipients differ from the
                first packet's.
        """
        recipient_sets = [set(p.recipients or []) for p in out_packets]
        same_recipients = all(rs == recipient_sets[0] for rs in recipient_sets)
        if not same_recipients:
            raise PacketRoutingError(
                proc_name=self.name,
                call_id=call_id,
                message="Parallel runs must return the same recipients "
                f"[proc_name={self.name}; call_id={call_id}]",
            )

    @with_retry
    async def _run_single(
        self,
        chat_inputs: Any | None = None,
        *,
        in_args: InT | None = None,
        forgetful: bool = False,
        call_id: str,
        ctx: RunContext[CtxT] | None = None,
    ) -> Packet[OutT]:
        """Run one input to one single-payload packet (retried on failure).

        ``forgetful`` runs against a deep copy of memory so shared memory
        is not mutated.
        """
        memory = self.memory.model_copy(deep=True) if forgetful else self.memory

        output = await self._process(
            chat_inputs=chat_inputs,
            in_args=in_args,
            memory=memory,
            call_id=call_id,
            ctx=ctx,
        )
        val_output = self._validate_output(output, call_id=call_id)

        recipients = self._select_recipients(output=val_output, ctx=ctx)
        self._validate_recipients(recipients, call_id=call_id)

        return Packet(payloads=[val_output], sender=self.name, recipients=recipients)


    async def _run_parallel(
        self, in_args: list[InT], call_id: str, ctx: RunContext[CtxT] | None = None
    ) -> Packet[OutT]:
        """Fan out one forgetful sub-run per input and merge the results.

        Sub-run call ids are suffixed "<call_id>/<index>"; payload order
        follows input order.
        """
        tasks = [
            self._run_single(
                in_args=inp, forgetful=True, call_id=f"{call_id}/{idx}", ctx=ctx
            )
            for idx, inp in enumerate(in_args)
        ]
        out_packets = await asyncio.gather(*tasks)
        self._validate_parallel_recipients(out_packets, call_id=call_id)

        return Packet(
            payloads=[out_packet.payloads[0] for out_packet in out_packets],
            sender=self.name,
            recipients=out_packets[0].recipients,
        )

    async def run(
        self,
        chat_inputs: Any | None = None,
        *,
        in_packet: Packet[InT] | None = None,
        in_args: InT | list[InT] | None = None,
        forgetful: bool = False,
        call_id: str | None = None,
        ctx: RunContext[CtxT] | None = None,
    ) -> Packet[OutT]:
        """Validate inputs and dispatch to single or parallel execution.

        Multiple validated inputs trigger the parallel path (which is always
        forgetful); otherwise a single run honors ``forgetful``.
        """
        call_id = self._generate_call_id(call_id)

        val_in_args = self._validate_inputs(
            call_id=call_id,
            chat_inputs=chat_inputs,
            in_packet=in_packet,
            in_args=in_args,
        )

        if val_in_args and len(val_in_args) > 1:
            return await self._run_parallel(in_args=val_in_args, call_id=call_id, ctx=ctx)

        return await self._run_single(
            chat_inputs=chat_inputs,
            in_args=val_in_args[0] if val_in_args else None,
            forgetful=forgetful,
            call_id=call_id,
            ctx=ctx,
        )

    @with_retry_stream
    async def _run_single_stream(
        self,
        chat_inputs: Any | None = None,
        *,
        in_args: InT | None = None,
        forgetful: bool = False,
        call_id: str,
        ctx: RunContext[CtxT] | None = None,
    ) -> AsyncIterator[Event[Any]]:
        """Stream one input's events, ending with a ProcPacketOutputEvent.

        The last ProcPayloadOutputEvent seen is taken as the run's output;
        the run fails (and is retried by the decorator) if none is emitted.
        """
        memory = self.memory.model_copy(deep=True) if forgetful else self.memory

        output: OutT | None = None
        async for event in self._process_stream(
            chat_inputs=chat_inputs,
            in_args=in_args,
            memory=memory,
            call_id=call_id,
            ctx=ctx,
        ):
            if isinstance(event, ProcPayloadOutputEvent):
                output = event.data
            yield event

        assert output is not None

        val_output = self._validate_output(output, call_id=call_id)

        recipients = self._select_recipients(output=val_output, ctx=ctx)
        self._validate_recipients(recipients, call_id=call_id)

        out_packet = Packet[OutT](
            payloads=[val_output], sender=self.name, recipients=recipients
        )

        yield ProcPacketOutputEvent(
            data=out_packet, proc_name=self.name, call_id=call_id
        )

    async def _run_parallel_stream(
        self,
        in_args: list[InT],
        call_id: str,
        ctx: RunContext[CtxT] | None = None,
    ) -> AsyncIterator[Event[Any]]:
        """Stream events from concurrent sub-runs, then one merged packet.

        Intermediate events are forwarded as they arrive (interleaved across
        sub-runs); each sub-run's final packet is captured by its index so
        merged payloads follow input order.
        """
        streams = [
            self._run_single_stream(
                in_args=inp, forgetful=True, call_id=f"{call_id}/{idx}", ctx=ctx
            )
            for idx, inp in enumerate(in_args)
        ]

        out_packets_map: dict[int, Packet[OutT]] = {}
        # stream_concurrent yields (stream_index, event) pairs.
        async for idx, event in stream_concurrent(streams):
            if isinstance(event, ProcPacketOutputEvent):
                out_packets_map[idx] = event.data
            else:
                yield event

        self._validate_parallel_recipients(
            out_packets=list(out_packets_map.values()), call_id=call_id
        )

        out_packet = Packet(
            payloads=[
                out_packet.payloads[0]
                for _, out_packet in sorted(out_packets_map.items())
            ],
            sender=self.name,
            # Recipients are identical across sub-runs (validated above).
            recipients=out_packets_map[0].recipients,
        )

        yield ProcPacketOutputEvent(
            data=out_packet, proc_name=self.name, call_id=call_id
        )

    async def run_stream(
        self,
        chat_inputs: Any | None = None,
        *,
        in_packet: Packet[InT] | None = None,
        in_args: InT | list[InT] | None = None,
        forgetful: bool = False,
        call_id: str | None = None,
        ctx: RunContext[CtxT] | None = None,
    ) -> AsyncIterator[Event[Any]]:
        """Streaming counterpart of :meth:`run`: dispatch and forward events."""
        call_id = self._generate_call_id(call_id)

        val_in_args = self._validate_inputs(
            call_id=call_id,
            chat_inputs=chat_inputs,
            in_packet=in_packet,
            in_args=in_args,
        )

        if val_in_args and len(val_in_args) > 1:
            stream = self._run_parallel_stream(in_args=val_in_args, call_id=call_id, ctx=ctx)
        else:
            stream = self._run_single_stream(
                chat_inputs=chat_inputs,
                in_args=val_in_args[0] if val_in_args else None,
                forgetful=forgetful,
                call_id=call_id,
                ctx=ctx,
            )
        async for event in stream:
            yield event