grasp_agents 0.5.5__tar.gz → 0.5.6__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59)
  1. {grasp_agents-0.5.5 → grasp_agents-0.5.6}/PKG-INFO +1 -1
  2. {grasp_agents-0.5.5 → grasp_agents-0.5.6}/pyproject.toml +1 -1
  3. {grasp_agents-0.5.5 → grasp_agents-0.5.6}/src/grasp_agents/__init__.py +5 -1
  4. {grasp_agents-0.5.5 → grasp_agents-0.5.6}/src/grasp_agents/llm.py +5 -1
  5. {grasp_agents-0.5.5 → grasp_agents-0.5.6}/src/grasp_agents/llm_agent.py +18 -7
  6. {grasp_agents-0.5.5 → grasp_agents-0.5.6}/src/grasp_agents/packet_pool.py +6 -1
  7. {grasp_agents-0.5.5 → grasp_agents-0.5.6}/src/grasp_agents/printer.py +7 -4
  8. grasp_agents-0.5.5/src/grasp_agents/processor.py → grasp_agents-0.5.6/src/grasp_agents/processors/base_processor.py +89 -287
  9. grasp_agents-0.5.6/src/grasp_agents/processors/parallel_processor.py +244 -0
  10. grasp_agents-0.5.6/src/grasp_agents/processors/processor.py +161 -0
  11. {grasp_agents-0.5.5 → grasp_agents-0.5.6}/src/grasp_agents/runner.py +20 -1
  12. {grasp_agents-0.5.5 → grasp_agents-0.5.6}/src/grasp_agents/typing/events.py +4 -0
  13. {grasp_agents-0.5.5 → grasp_agents-0.5.6}/src/grasp_agents/workflow/looped_workflow.py +35 -27
  14. {grasp_agents-0.5.5 → grasp_agents-0.5.6}/src/grasp_agents/workflow/sequential_workflow.py +14 -3
  15. {grasp_agents-0.5.5 → grasp_agents-0.5.6}/src/grasp_agents/workflow/workflow_processor.py +21 -15
  16. {grasp_agents-0.5.5 → grasp_agents-0.5.6}/.gitignore +0 -0
  17. {grasp_agents-0.5.5 → grasp_agents-0.5.6}/LICENSE.md +0 -0
  18. {grasp_agents-0.5.5 → grasp_agents-0.5.6}/README.md +0 -0
  19. {grasp_agents-0.5.5 → grasp_agents-0.5.6}/src/grasp_agents/cloud_llm.py +0 -0
  20. {grasp_agents-0.5.5 → grasp_agents-0.5.6}/src/grasp_agents/costs_dict.yaml +0 -0
  21. {grasp_agents-0.5.5 → grasp_agents-0.5.6}/src/grasp_agents/errors.py +0 -0
  22. {grasp_agents-0.5.5 → grasp_agents-0.5.6}/src/grasp_agents/generics_utils.py +0 -0
  23. {grasp_agents-0.5.5 → grasp_agents-0.5.6}/src/grasp_agents/grasp_logging.py +0 -0
  24. {grasp_agents-0.5.5 → grasp_agents-0.5.6}/src/grasp_agents/http_client.py +0 -0
  25. {grasp_agents-0.5.5 → grasp_agents-0.5.6}/src/grasp_agents/litellm/__init__.py +0 -0
  26. {grasp_agents-0.5.5 → grasp_agents-0.5.6}/src/grasp_agents/litellm/completion_chunk_converters.py +0 -0
  27. {grasp_agents-0.5.5 → grasp_agents-0.5.6}/src/grasp_agents/litellm/completion_converters.py +0 -0
  28. {grasp_agents-0.5.5 → grasp_agents-0.5.6}/src/grasp_agents/litellm/converters.py +0 -0
  29. {grasp_agents-0.5.5 → grasp_agents-0.5.6}/src/grasp_agents/litellm/lite_llm.py +0 -0
  30. {grasp_agents-0.5.5 → grasp_agents-0.5.6}/src/grasp_agents/litellm/message_converters.py +0 -0
  31. {grasp_agents-0.5.5 → grasp_agents-0.5.6}/src/grasp_agents/llm_agent_memory.py +0 -0
  32. {grasp_agents-0.5.5 → grasp_agents-0.5.6}/src/grasp_agents/llm_policy_executor.py +0 -0
  33. {grasp_agents-0.5.5 → grasp_agents-0.5.6}/src/grasp_agents/memory.py +0 -0
  34. {grasp_agents-0.5.5 → grasp_agents-0.5.6}/src/grasp_agents/openai/__init__.py +0 -0
  35. {grasp_agents-0.5.5 → grasp_agents-0.5.6}/src/grasp_agents/openai/completion_chunk_converters.py +0 -0
  36. {grasp_agents-0.5.5 → grasp_agents-0.5.6}/src/grasp_agents/openai/completion_converters.py +0 -0
  37. {grasp_agents-0.5.5 → grasp_agents-0.5.6}/src/grasp_agents/openai/content_converters.py +0 -0
  38. {grasp_agents-0.5.5 → grasp_agents-0.5.6}/src/grasp_agents/openai/converters.py +0 -0
  39. {grasp_agents-0.5.5 → grasp_agents-0.5.6}/src/grasp_agents/openai/message_converters.py +0 -0
  40. {grasp_agents-0.5.5 → grasp_agents-0.5.6}/src/grasp_agents/openai/openai_llm.py +0 -0
  41. {grasp_agents-0.5.5 → grasp_agents-0.5.6}/src/grasp_agents/openai/tool_converters.py +0 -0
  42. {grasp_agents-0.5.5 → grasp_agents-0.5.6}/src/grasp_agents/packet.py +0 -0
  43. {grasp_agents-0.5.5 → grasp_agents-0.5.6}/src/grasp_agents/prompt_builder.py +0 -0
  44. {grasp_agents-0.5.5 → grasp_agents-0.5.6}/src/grasp_agents/rate_limiting/__init__.py +0 -0
  45. {grasp_agents-0.5.5 → grasp_agents-0.5.6}/src/grasp_agents/rate_limiting/rate_limiter_chunked.py +0 -0
  46. {grasp_agents-0.5.5 → grasp_agents-0.5.6}/src/grasp_agents/rate_limiting/types.py +0 -0
  47. {grasp_agents-0.5.5 → grasp_agents-0.5.6}/src/grasp_agents/rate_limiting/utils.py +0 -0
  48. {grasp_agents-0.5.5 → grasp_agents-0.5.6}/src/grasp_agents/run_context.py +0 -0
  49. {grasp_agents-0.5.5 → grasp_agents-0.5.6}/src/grasp_agents/typing/__init__.py +0 -0
  50. {grasp_agents-0.5.5 → grasp_agents-0.5.6}/src/grasp_agents/typing/completion.py +0 -0
  51. {grasp_agents-0.5.5 → grasp_agents-0.5.6}/src/grasp_agents/typing/completion_chunk.py +0 -0
  52. {grasp_agents-0.5.5 → grasp_agents-0.5.6}/src/grasp_agents/typing/content.py +0 -0
  53. {grasp_agents-0.5.5 → grasp_agents-0.5.6}/src/grasp_agents/typing/converters.py +0 -0
  54. {grasp_agents-0.5.5 → grasp_agents-0.5.6}/src/grasp_agents/typing/io.py +0 -0
  55. {grasp_agents-0.5.5 → grasp_agents-0.5.6}/src/grasp_agents/typing/message.py +0 -0
  56. {grasp_agents-0.5.5 → grasp_agents-0.5.6}/src/grasp_agents/typing/tool.py +0 -0
  57. {grasp_agents-0.5.5 → grasp_agents-0.5.6}/src/grasp_agents/usage_tracker.py +0 -0
  58. {grasp_agents-0.5.5 → grasp_agents-0.5.6}/src/grasp_agents/utils.py +0 -0
  59. {grasp_agents-0.5.5 → grasp_agents-0.5.6}/src/grasp_agents/workflow/__init__.py +0 -0
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: grasp_agents
-Version: 0.5.5
+Version: 0.5.6
 Summary: Grasp Agents Library
 License-File: LICENSE.md
 Requires-Python: <4,>=3.11.4
@@ -1,6 +1,6 @@
 [project]
 name = "grasp_agents"
-version = "0.5.5"
+version = "0.5.6"
 description = "Grasp Agents Library"
 readme = "README.md"
 requires-python = ">=3.11.4,<4"
@@ -6,7 +6,9 @@ from .llm_agent import LLMAgent
 from .llm_agent_memory import LLMAgentMemory
 from .memory import Memory
 from .packet import Packet
-from .processor import Processor
+from .processors.base_processor import BaseProcessor
+from .processors.parallel_processor import ParallelProcessor
+from .processors.processor import Processor
 from .run_context import RunContext
 from .typing.completion import Completion
 from .typing.content import Content, ImageData
@@ -17,6 +19,7 @@ from .typing.tool import BaseTool
 __all__ = [
     "LLM",
     "AssistantMessage",
+    "BaseProcessor",
     "BaseTool",
     "Completion",
     "Content",
@@ -29,6 +32,7 @@ __all__ = [
     "Messages",
     "Packet",
     "Packet",
+    "ParallelProcessor",
     "ProcName",
     "Processor",
     "RunContext",
@@ -19,7 +19,11 @@ from .errors import (
 )
 from .typing.completion import Completion
 from .typing.converters import Converters
-from .typing.events import CompletionChunkEvent, CompletionEvent, LLMStreamingErrorEvent
+from .typing.events import (
+    CompletionChunkEvent,
+    CompletionEvent,
+    LLMStreamingErrorEvent,
+)
 from .typing.message import Messages
 from .typing.tool import BaseTool, ToolChoice
 
@@ -11,7 +11,7 @@ from .llm_policy_executor import (
     MemoryManager,
     ToolCallLoopTerminator,
 )
-from .processor import Processor
+from .processors.parallel_processor import ParallelProcessor
 from .prompt_builder import (
     InputContentBuilder,
     PromptBuilder,
@@ -46,7 +46,7 @@ class OutputParser(Protocol[_InT_contra, _OutT_co, CtxT]):
 
 
 class LLMAgent(
-    Processor[InT, OutT, LLMAgentMemory, CtxT],
+    ParallelProcessor[InT, OutT, LLMAgentMemory, CtxT],
     Generic[InT, OutT, CtxT],
 ):
     _generic_arg_to_instance_attr_map: ClassVar[dict[int, str]] = {
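Note: the base-class swap above is what moves LLMAgent onto the new ParallelProcessor (added in processors/parallel_processor.py, whose body is not shown in this diff). A hypothetical sanity check of that wiring, on the assumption that ParallelProcessor itself subclasses BaseProcessor:

# Hypothetical check; only the imported names are taken from this diff.
from grasp_agents import BaseProcessor, LLMAgent, ParallelProcessor

assert issubclass(LLMAgent, ParallelProcessor)
assert issubclass(ParallelProcessor, BaseProcessor)  # assumed hierarchy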
@@ -196,6 +196,20 @@ class LLMAgent(
 
         return system_message, input_message
 
+    def _parse_output_default(
+        self,
+        conversation: Messages,
+        *,
+        in_args: InT | None = None,
+        ctx: RunContext[CtxT] | None = None,
+    ) -> OutT:
+        return validate_obj_from_json_or_py_string(
+            str(conversation[-1].content or ""),
+            schema=self._out_type,
+            from_substring=False,
+            strip_language_markdown=True,
+        )
+
     def _parse_output(
         self,
         conversation: Messages,
@@ -208,11 +222,8 @@
                 conversation=conversation, in_args=in_args, ctx=ctx
             )
 
-        return validate_obj_from_json_or_py_string(
-            str(conversation[-1].content or ""),
-            schema=self._out_type,
-            from_substring=False,
-            strip_language_markdown=True,
+        return self._parse_output_default(
+            conversation=conversation, in_args=in_args, ctx=ctx
         )
 
     async def _process(
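Note: factoring the JSON/Python-literal parsing into _parse_output_default lets a subclass (or a custom output parser) reuse the stock parsing instead of re-implementing it. A hypothetical subclass sketch; the class name, type arguments, and the enrichment step are illustrative, not part of the library:

# Hypothetical LLMAgent subclass; signatures mirror the hunks above.
from typing import Any

from grasp_agents import LLMAgent, Messages, RunContext


class EnrichingAgent(LLMAgent[str, dict[str, Any], Any]):
    def _parse_output(
        self,
        conversation: Messages,
        *,
        in_args: str | None = None,
        ctx: RunContext[Any] | None = None,
    ) -> dict[str, Any]:
        # Reuse the default JSON parsing of the last assistant message...
        parsed = self._parse_output_default(
            conversation=conversation, in_args=in_args, ctx=ctx
        )
        # ...then attach extra metadata before the output is validated and routed.
        parsed["n_messages"] = len(conversation)
        return parsed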
@@ -68,6 +68,11 @@ class PacketPool(Generic[CtxT]):
         finally:
             await self.shutdown()
 
+    @property
+    def final_result_ready(self) -> bool:
+        fut = self._final_result_fut
+        return fut is not None and fut.done()
+
     def register_packet_handler(
         self,
         proc_name: ProcName,
@@ -121,7 +126,7 @@ class PacketPool(Generic[CtxT]):
         queue = self._packet_queues[proc_name]
         handler = self._packet_handlers[proc_name]
 
-        while True:
+        while not self.final_result_ready:
            packet = await queue.get()
            if packet is None:
                break
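Note: together, the new final_result_ready property and the loop-condition change stop the per-processor handler loops once the pool's final-result future resolves, instead of spinning until a sentinel arrives. A stand-alone illustration of the pattern in plain asyncio (not the PacketPool internals):

import asyncio


async def consume(queue: asyncio.Queue, final_result: asyncio.Future) -> None:
    # Mirrors `while not self.final_result_ready` plus the None-sentinel check.
    while not final_result.done():
        item = await queue.get()
        if item is None:
            break
        print("handled", item)


async def main() -> None:
    loop = asyncio.get_running_loop()
    queue: asyncio.Queue = asyncio.Queue()
    final_result: asyncio.Future = loop.create_future()

    task = asyncio.create_task(consume(queue, final_result))
    await queue.put("packet-1")
    await asyncio.sleep(0)            # let the consumer handle the first packet
    final_result.set_result("done")   # the next loop check ends the consumer
    await queue.put(None)             # unblocks the get() that is already waiting
    await task


asyncio.run(main())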
@@ -232,11 +232,14 @@ async def print_event_stream(
                 text += f"<{src} output>\n"
                 for p in event.data.payloads:
                     if isinstance(p, BaseModel):
-                        for field_info in type(p).model_fields.values():
-                            if field_info.exclude:
-                                field_info.exclude = False
-                        type(p).model_rebuild(force=True)
+                        # for field_info in type(p).model_fields.values():
+                        #     if field_info.exclude:
+                        #         field_info.exclude = False
+                        #         break
+                        # type(p).model_rebuild(force=True)
                         p_str = p.model_dump_json(indent=2)
+                        # field_info.exclude = True  # type: ignore
+                        # type(p).model_rebuild(force=True)
                     else:
                         try:
                             p_str = json.dumps(p, indent=2)
@@ -1,38 +1,106 @@
-import asyncio
 import logging
-from abc import ABC
-from collections.abc import AsyncIterator, Sequence
-from typing import Any, ClassVar, Generic, Protocol, TypeVar, cast, final
+from abc import ABC, abstractmethod
+from collections.abc import AsyncIterator, Callable, Coroutine
+from functools import wraps
+from typing import (
+    Any,
+    ClassVar,
+    Generic,
+    Protocol,
+    TypeVar,
+    cast,
+    final,
+)
 from uuid import uuid4
 
 from pydantic import BaseModel, TypeAdapter
 from pydantic import ValidationError as PydanticValidationError
 
-from .errors import (
+from ..errors import (
     PacketRoutingError,
     ProcInputValidationError,
     ProcOutputValidationError,
     ProcRunError,
 )
-from .generics_utils import AutoInstanceAttributesMixin
-from .memory import DummyMemory, MemT
-from .packet import Packet
-from .run_context import CtxT, RunContext
-from .typing.events import (
+from ..generics_utils import AutoInstanceAttributesMixin
+from ..memory import DummyMemory, MemT
+from ..packet import Packet
+from ..run_context import CtxT, RunContext
+from ..typing.events import (
+    DummyEvent,
     Event,
-    ProcPacketOutputEvent,
-    ProcPayloadOutputEvent,
     ProcStreamingErrorData,
     ProcStreamingErrorEvent,
 )
-from .typing.io import InT, OutT, ProcName
-from .typing.tool import BaseTool
-from .utils import stream_concurrent
+from ..typing.io import InT, OutT, ProcName
+from ..typing.tool import BaseTool
 
 logger = logging.getLogger(__name__)
 
 _OutT_contra = TypeVar("_OutT_contra", contravariant=True)
 
+F = TypeVar("F", bound=Callable[..., Coroutine[Any, Any, Packet[Any]]])
+F_stream = TypeVar("F_stream", bound=Callable[..., AsyncIterator[Event[Any]]])
+
+
+def with_retry(func: F) -> F:
+    @wraps(func)
+    async def wrapper(
+        self: "BaseProcessor[Any, Any, Any, Any]", *args: Any, **kwargs: Any
+    ) -> Packet[Any]:
+        call_id = kwargs.get("call_id", "unknown")
+        for n_attempt in range(self.max_retries + 1):
+            try:
+                return await func(self, *args, **kwargs)
+            except Exception as err:
+                err_message = (
+                    f"\nProcessor run failed [proc_name={self.name}; call_id={call_id}]"
+                )
+                if n_attempt == self.max_retries:
+                    if self.max_retries == 0:
+                        logger.warning(f"{err_message}:\n{err}")
+                    else:
+                        logger.warning(f"{err_message} after retrying:\n{err}")
+                    raise ProcRunError(proc_name=self.name, call_id=call_id) from err
+
+                logger.warning(f"{err_message} (retry attempt {n_attempt + 1}):\n{err}")
+        # This part should not be reachable due to the raise in the loop
+        raise ProcRunError(proc_name=self.name, call_id=call_id)
+
+    return cast("F", wrapper)
+
+
+def with_retry_stream(func: F_stream) -> F_stream:
+    @wraps(func)
+    async def wrapper(
+        self: "BaseProcessor[Any, Any, Any, Any]", *args: Any, **kwargs: Any
+    ) -> AsyncIterator[Event[Any]]:
+        call_id = kwargs.get("call_id", "unknown")
+        for n_attempt in range(self.max_retries + 1):
+            try:
+                async for event in func(self, *args, **kwargs):
+                    yield event
+                return
+            except Exception as err:
+                err_data = ProcStreamingErrorData(error=err, call_id=call_id)
+                yield ProcStreamingErrorEvent(
+                    data=err_data, proc_name=self.name, call_id=call_id
+                )
+                err_message = (
+                    "\nStreaming processor run failed "
+                    f"[proc_name={self.name}; call_id={call_id}]"
+                )
+                if n_attempt == self.max_retries:
+                    if self.max_retries == 0:
+                        logger.warning(f"{err_message}:\n{err}")
+                    else:
+                        logger.warning(f"{err_message} after retrying:\n{err}")
+                    raise ProcRunError(proc_name=self.name, call_id=call_id) from err
+
+                logger.warning(f"{err_message} (retry attempt {n_attempt}):\n{err}")
+
+    return cast("F_stream", wrapper)
+
 
 class RecipientSelector(Protocol[_OutT_contra, CtxT]):
     def __call__(
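Note: these two decorators replace the hand-rolled retry loops that are deleted from _run_single and _run_single_stream further down in this diff; the wrapper only relies on self.name and self.max_retries. A stand-alone sketch of the non-streaming variant; FlakyProc is illustrative, not library code, and it assumes Packet accepts recipients=None just as the removed _run_single_once passed the possibly-None result of _select_recipients:

import asyncio

from grasp_agents.packet import Packet
from grasp_agents.processors.base_processor import with_retry


class FlakyProc:
    name = "flaky"
    max_retries = 2  # read by the with_retry wrapper

    def __init__(self) -> None:
        self.calls = 0

    @with_retry
    async def run(self, *, call_id: str = "demo") -> Packet[str]:
        self.calls += 1
        if self.calls < 3:
            raise RuntimeError("transient failure")
        # recipients=None assumed to be accepted, as in the removed code above
        return Packet(payloads=["ok"], sender=self.name, recipients=None)


async def main() -> None:
    packet = await FlakyProc().run(call_id="demo")
    print(packet.payloads)  # ["ok"] after two logged, retried failures


asyncio.run(main())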
@@ -40,7 +108,7 @@ class RecipientSelector(Protocol[_OutT_contra, CtxT]):
     ) -> list[ProcName] | None: ...
 
 
-class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT, MemT, CtxT]):
+class BaseProcessor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT, MemT, CtxT]):
     _generic_arg_to_instance_attr_map: ClassVar[dict[int, str]] = {
         0: "_in_type",
         1: "_out_type",
@@ -66,7 +134,6 @@ class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT, MemT, CtxT]
 
         self.recipient_selector: RecipientSelector[OutT, CtxT] | None
         if not hasattr(type(self), "recipient_selector"):
-            # Set to None if not defined in the subclass
            self.recipient_selector = None
 
     @property
@@ -183,19 +250,6 @@ class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT, MemT, CtxT]
                 allowed_recipients=cast("list[str]", self.recipients),
             )
 
-    def _validate_par_recipients(
-        self, out_packets: Sequence[Packet[OutT]], call_id: str
-    ) -> None:
-        recipient_sets = [set(p.recipients or []) for p in out_packets]
-        same_recipients = all(rs == recipient_sets[0] for rs in recipient_sets)
-        if not same_recipients:
-            raise PacketRoutingError(
-                proc_name=self.name,
-                call_id=call_id,
-                message="Parallel runs must return the same recipients "
-                f"[proc_name={self.name}; call_id={call_id}]",
-            )
-
     @final
     def _select_recipients(
         self, output: OutT, ctx: RunContext[CtxT] | None = None
@@ -212,108 +266,7 @@ class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT, MemT, CtxT]
 
         return func
 
-    async def _process(
-        self,
-        chat_inputs: Any | None = None,
-        *,
-        in_args: InT | None = None,
-        memory: MemT,
-        call_id: str,
-        ctx: RunContext[CtxT] | None = None,
-    ) -> OutT:
-        return cast("OutT", in_args)
-
-    async def _process_stream(
-        self,
-        chat_inputs: Any | None = None,
-        *,
-        in_args: InT | None = None,
-        memory: MemT,
-        call_id: str,
-        ctx: RunContext[CtxT] | None = None,
-    ) -> AsyncIterator[Event[Any]]:
-        output = cast("OutT", in_args)
-        yield ProcPayloadOutputEvent(data=output, proc_name=self.name, call_id=call_id)
-
-    async def _run_single_once(
-        self,
-        chat_inputs: Any | None = None,
-        *,
-        in_args: InT | None = None,
-        forgetful: bool = False,
-        call_id: str,
-        ctx: RunContext[CtxT] | None = None,
-    ) -> Packet[OutT]:
-        _memory = self.memory.model_copy(deep=True) if forgetful else self.memory
-
-        output = await self._process(
-            chat_inputs=chat_inputs,
-            in_args=in_args,
-            memory=_memory,
-            call_id=call_id,
-            ctx=ctx,
-        )
-        val_output = self._validate_output(output, call_id=call_id)
-
-        recipients = self._select_recipients(output=val_output, ctx=ctx)
-        self._validate_recipients(recipients, call_id=call_id)
-
-        return Packet(payloads=[val_output], sender=self.name, recipients=recipients)
-
-    async def _run_single(
-        self,
-        chat_inputs: Any | None = None,
-        *,
-        in_args: InT | None = None,
-        forgetful: bool = False,
-        call_id: str,
-        ctx: RunContext[CtxT] | None = None,
-    ) -> Packet[OutT]:
-        n_attempt = 0
-        while n_attempt <= self.max_retries:
-            try:
-                return await self._run_single_once(
-                    chat_inputs=chat_inputs,
-                    in_args=in_args,
-                    forgetful=forgetful,
-                    call_id=call_id,
-                    ctx=ctx,
-                )
-            except Exception as err:
-                err_message = (
-                    f"\nProcessor run failed [proc_name={self.name}; call_id={call_id}]"
-                )
-                n_attempt += 1
-                if n_attempt > self.max_retries:
-                    if n_attempt == 1:
-                        logger.warning(f"{err_message}:\n{err}")
-                    if n_attempt > 1:
-                        logger.warning(f"{err_message} after retrying:\n{err}")
-                    raise ProcRunError(proc_name=self.name, call_id=call_id) from err
-
-                logger.warning(f"{err_message} (retry attempt {n_attempt}):\n{err}")
-
-        raise ProcRunError(proc_name=self.name, call_id=call_id)
-
-    async def _run_par(
-        self, in_args: list[InT], call_id: str, ctx: RunContext[CtxT] | None = None
-    ) -> Packet[OutT]:
-        tasks = [
-            self._run_single(
-                in_args=inp, forgetful=True, call_id=f"{call_id}/{idx}", ctx=ctx
-            )
-            for idx, inp in enumerate(in_args)
-        ]
-        out_packets = await asyncio.gather(*tasks)
-
-        self._validate_par_recipients(out_packets, call_id=call_id)
-
-        return Packet(
-            payloads=[out_packet.payloads[0] for out_packet in out_packets],
-            sender=self.name,
-            recipients=out_packets[0].recipients,
-        )
-
+    @abstractmethod
     async def run(
         self,
         chat_inputs: Any | None = None,
@@ -324,140 +277,9 @@ class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT, MemT, CtxT]
         call_id: str | None = None,
         ctx: RunContext[CtxT] | None = None,
     ) -> Packet[OutT]:
-        call_id = self._generate_call_id(call_id)
-
-        val_in_args = self._validate_inputs(
-            call_id=call_id,
-            chat_inputs=chat_inputs,
-            in_packet=in_packet,
-            in_args=in_args,
-        )
-
-        if val_in_args and len(val_in_args) > 1:
-            return await self._run_par(in_args=val_in_args, call_id=call_id, ctx=ctx)
-        return await self._run_single(
-            chat_inputs=chat_inputs,
-            in_args=val_in_args[0] if val_in_args else None,
-            forgetful=forgetful,
-            call_id=call_id,
-            ctx=ctx,
-        )
-
-    async def _run_single_stream_once(
-        self,
-        chat_inputs: Any | None = None,
-        *,
-        in_args: InT | None = None,
-        forgetful: bool = False,
-        call_id: str,
-        ctx: RunContext[CtxT] | None = None,
-    ) -> AsyncIterator[Event[Any]]:
-        _memory = self.memory.model_copy(deep=True) if forgetful else self.memory
-
-        output: OutT | None = None
-        async for event in self._process_stream(
-            chat_inputs=chat_inputs,
-            in_args=in_args,
-            memory=_memory,
-            call_id=call_id,
-            ctx=ctx,
-        ):
-            if isinstance(event, ProcPayloadOutputEvent):
-                output = event.data
-            yield event
-
-        assert output is not None
-
-        val_output = self._validate_output(output, call_id=call_id)
-
-        recipients = self._select_recipients(output=val_output, ctx=ctx)
-        self._validate_recipients(recipients, call_id=call_id)
-
-        out_packet = Packet[OutT](
-            payloads=[val_output], sender=self.name, recipients=recipients
-        )
-
-        yield ProcPacketOutputEvent(
-            data=out_packet, proc_name=self.name, call_id=call_id
-        )
-
-    async def _run_single_stream(
-        self,
-        chat_inputs: Any | None = None,
-        *,
-        in_args: InT | None = None,
-        forgetful: bool = False,
-        call_id: str,
-        ctx: RunContext[CtxT] | None = None,
-    ) -> AsyncIterator[Event[Any]]:
-        n_attempt = 0
-        while n_attempt <= self.max_retries:
-            try:
-                async for event in self._run_single_stream_once(
-                    chat_inputs=chat_inputs,
-                    in_args=in_args,
-                    forgetful=forgetful,
-                    call_id=call_id,
-                    ctx=ctx,
-                ):
-                    yield event
-
-                return
-
-            except Exception as err:
-                err_data = ProcStreamingErrorData(error=err, call_id=call_id)
-                yield ProcStreamingErrorEvent(
-                    data=err_data, proc_name=self.name, call_id=call_id
-                )
-
-                err_message = (
-                    "\nStreaming processor run failed "
-                    f"[proc_name={self.name}; call_id={call_id}]"
-                )
-
-                n_attempt += 1
-                if n_attempt > self.max_retries:
-                    if n_attempt == 1:
-                        logger.warning(f"{err_message}:\n{err}")
-                    if n_attempt > 1:
-                        logger.warning(f"{err_message} after retrying:\n{err}")
-                    raise ProcRunError(proc_name=self.name, call_id=call_id) from err
-
-                logger.warning(f"{err_message} (retry attempt {n_attempt}):\n{err}")
-
-    async def _run_par_stream(
-        self,
-        in_args: list[InT],
-        call_id: str,
-        ctx: RunContext[CtxT] | None = None,
-    ) -> AsyncIterator[Event[Any]]:
-        streams = [
-            self._run_single_stream(
-                in_args=inp, forgetful=True, call_id=f"{call_id}/{idx}", ctx=ctx
-            )
-            for idx, inp in enumerate(in_args)
-        ]
-
-        out_packets_map: dict[int, Packet[OutT]] = {}
-        async for idx, event in stream_concurrent(streams):
-            if isinstance(event, ProcPacketOutputEvent):
-                out_packets_map[idx] = event.data
-            else:
-                yield event
-
-        out_packet = Packet(
-            payloads=[
-                out_packet.payloads[0]
-                for _, out_packet in sorted(out_packets_map.items())
-            ],
-            sender=self.name,
-            recipients=out_packets_map[0].recipients,
-        )
-
-        yield ProcPacketOutputEvent(
-            data=out_packet, proc_name=self.name, call_id=call_id
-        )
+        pass
 
+    @abstractmethod
     async def run_stream(
         self,
         chat_inputs: Any | None = None,
@@ -468,27 +290,7 @@ class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT, MemT, CtxT]
         call_id: str | None = None,
         ctx: RunContext[CtxT] | None = None,
     ) -> AsyncIterator[Event[Any]]:
-        call_id = self._generate_call_id(call_id)
-
-        val_in_args = self._validate_inputs(
-            call_id=call_id,
-            chat_inputs=chat_inputs,
-            in_packet=in_packet,
-            in_args=in_args,
-        )
-
-        if val_in_args and len(val_in_args) > 1:
-            stream = self._run_par_stream(in_args=val_in_args, call_id=call_id, ctx=ctx)
-        else:
-            stream = self._run_single_stream(
-                chat_inputs=chat_inputs,
-                in_args=val_in_args[0] if val_in_args else None,
-                forgetful=forgetful,
-                call_id=call_id,
-                ctx=ctx,
-            )
-        async for event in stream:
-            yield event
+        yield DummyEvent()
 
     @final
     def as_tool(
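Note: taken together, the hunks above turn BaseProcessor into a pure interface: run and run_stream are now abstract, and the single-run, parallel-run, and streaming machinery that used to live here moves to the new processors/processor.py and processors/parallel_processor.py (not shown in this diff). A rough sketch of what a concrete subclass now has to provide; the keyword names follow the signatures visible in this diff, while the defaults, the DummyMemory slot, and yielding DummyEvent are assumptions for illustration:

from collections.abc import AsyncIterator
from typing import Any

from grasp_agents.memory import DummyMemory
from grasp_agents.packet import Packet
from grasp_agents.processors.base_processor import BaseProcessor
from grasp_agents.run_context import RunContext
from grasp_agents.typing.events import DummyEvent, Event


class UppercaseProcessor(BaseProcessor[str, str, DummyMemory, Any]):
    async def run(
        self,
        chat_inputs: Any | None = None,
        *,
        in_packet: Packet[str] | None = None,
        in_args: str | None = None,
        forgetful: bool = False,
        call_id: str | None = None,
        ctx: RunContext[Any] | None = None,
    ) -> Packet[str]:
        payloads = list(in_packet.payloads) if in_packet else [in_args or ""]
        return Packet(
            payloads=[p.upper() for p in payloads],
            sender=self.name,
            recipients=None,
        )

    async def run_stream(
        self,
        chat_inputs: Any | None = None,
        *,
        in_packet: Packet[str] | None = None,
        in_args: str | None = None,
        forgetful: bool = False,
        call_id: str | None = None,
        ctx: RunContext[Any] | None = None,
    ) -> AsyncIterator[Event[Any]]:
        # A real implementation would emit ProcPayloadOutputEvent /
        # ProcPacketOutputEvent; this stub only shows the async-generator shape.
        yield DummyEvent()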
@@ -510,7 +312,7 @@ class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT, MemT, CtxT]
 
         async def run(self, inp: InT, ctx: RunContext[CtxT] | None = None) -> OutT:
             result = await processor_instance.run(
-                in_args=in_type.model_validate(inp), forgetful=True, ctx=ctx
+                in_args=inp, forgetful=True, ctx=ctx
             )
 
             return result.payloads[0]