grasp_agents 0.5.3__py3-none-any.whl → 0.5.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,35 +1,26 @@
1
1
  import json
2
- from collections.abc import Mapping, Sequence
3
- from typing import ClassVar, Generic, Protocol, TypeAlias, TypeVar
2
+ from collections.abc import Sequence
3
+ from typing import ClassVar, Generic, Protocol, TypeAlias, TypeVar, final
4
4
 
5
5
  from pydantic import BaseModel, TypeAdapter
6
6
 
7
- from .errors import InputPromptBuilderError, SystemPromptBuilderError
7
+ from .errors import InputPromptBuilderError
8
8
  from .generics_utils import AutoInstanceAttributesMixin
9
9
  from .run_context import CtxT, RunContext
10
10
  from .typing.content import Content, ImageData
11
- from .typing.io import InT, LLMPrompt, LLMPromptArgs
11
+ from .typing.io import InT, LLMPrompt
12
12
  from .typing.message import UserMessage
13
13
 
14
14
  _InT_contra = TypeVar("_InT_contra", contravariant=True)
15
15
 
16
16
 
17
- class MakeSystemPromptHandler(Protocol[CtxT]):
18
- def __call__(
19
- self,
20
- sys_args: LLMPromptArgs | None,
21
- *,
22
- ctx: RunContext[CtxT] | None,
23
- ) -> str | None: ...
17
+ class SystemPromptBuilder(Protocol[CtxT]):
18
+ def __call__(self, ctx: RunContext[CtxT] | None) -> str | None: ...
24
19
 
25
20
 
26
- class MakeInputContentHandler(Protocol[_InT_contra, CtxT]):
21
+ class InputContentBuilder(Protocol[_InT_contra, CtxT]):
27
22
  def __call__(
28
- self,
29
- *,
30
- in_args: _InT_contra | None,
31
- usr_args: LLMPromptArgs | None,
32
- ctx: RunContext[CtxT] | None,
23
+ self, in_args: _InT_contra | None, *, ctx: RunContext[CtxT] | None
33
24
  ) -> Content: ...
34
25
 
35
26
 
@@ -40,134 +31,101 @@ class PromptBuilder(AutoInstanceAttributesMixin, Generic[InT, CtxT]):
40
31
  _generic_arg_to_instance_attr_map: ClassVar[dict[int, str]] = {0: "_in_type"}
41
32
 
42
33
  def __init__(
43
- self,
44
- agent_name: str,
45
- sys_prompt_template: LLMPrompt | None,
46
- in_prompt_template: LLMPrompt | None,
47
- sys_args_schema: type[LLMPromptArgs] | None = None,
48
- usr_args_schema: type[LLMPromptArgs] | None = None,
34
+ self, agent_name: str, sys_prompt: LLMPrompt | None, in_prompt: LLMPrompt | None
49
35
  ):
50
36
  self._in_type: type[InT]
51
37
  super().__init__()
52
38
 
53
39
  self._agent_name = agent_name
54
- self.sys_prompt_template = sys_prompt_template
55
- self.in_prompt_template = in_prompt_template
56
- self.sys_args_schema = sys_args_schema
57
- self.usr_args_schema = usr_args_schema
58
- self.make_system_prompt_impl: MakeSystemPromptHandler[CtxT] | None = None
59
- self.make_input_content_impl: MakeInputContentHandler[InT, CtxT] | None = None
40
+ self.sys_prompt = sys_prompt
41
+ self.in_prompt = in_prompt
42
+ self.system_prompt_builder: SystemPromptBuilder[CtxT] | None = None
43
+ self.input_content_builder: InputContentBuilder[InT, CtxT] | None = None
60
44
 
61
45
  self._in_args_type_adapter: TypeAdapter[InT] = TypeAdapter(self._in_type)
62
46
 
63
- def make_system_prompt(
64
- self, sys_args: LLMPromptArgs | None = None, ctx: RunContext[CtxT] | None = None
65
- ) -> str | None:
66
- if self.sys_prompt_template is None:
67
- return None
68
-
69
- val_sys_args = sys_args
70
- if sys_args is not None:
71
- if self.sys_args_schema is not None:
72
- val_sys_args = self.sys_args_schema.model_validate(sys_args)
73
- else:
74
- raise SystemPromptBuilderError(
75
- "System prompt template and arguments is set, but system arguments "
76
- "schema is not provided."
77
- )
47
+ @final
48
+ def build_system_prompt(self, ctx: RunContext[CtxT] | None = None) -> str | None:
49
+ if self.system_prompt_builder:
50
+ return self.system_prompt_builder(ctx=ctx)
78
51
 
79
- if self.make_system_prompt_impl:
80
- return self.make_system_prompt_impl(sys_args=val_sys_args, ctx=ctx)
52
+ return self.sys_prompt
81
53
 
82
- sys_args_dict = val_sys_args.model_dump() if val_sys_args else {}
54
+ def _validate_input_args(self, in_args: InT) -> InT:
55
+ val_in_args = self._in_args_type_adapter.validate_python(in_args)
56
+ if isinstance(val_in_args, BaseModel):
57
+ has_image = self._has_image_data(val_in_args)
58
+ if has_image and self.in_prompt is None:
59
+ raise InputPromptBuilderError(
60
+ proc_name=self._agent_name,
61
+ message="BaseModel input arguments contain ImageData, "
62
+ "but input prompt template is not set "
63
+ f"[agent_name={self._agent_name}]. Cannot format input arguments.",
64
+ )
65
+ elif self.in_prompt is not None:
66
+ raise InputPromptBuilderError(
67
+ proc_name=self._agent_name,
68
+ message="Cannot use the input prompt template with "
69
+ f"non-BaseModel input arguments [agent_name={self._agent_name}]",
70
+ )
83
71
 
84
- return self.sys_prompt_template.format(**sys_args_dict)
72
+ return val_in_args
85
73
 
86
- def make_input_content(
74
+ @final
75
+ def _build_input_content(
87
76
  self,
88
- *,
89
- in_args: InT | None,
90
- usr_args: LLMPromptArgs | None,
77
+ in_args: InT | None = None,
91
78
  ctx: RunContext[CtxT] | None = None,
92
79
  ) -> Content:
93
- val_in_args, val_usr_args = self._validate_prompt_args(
94
- in_args=in_args, usr_args=usr_args
95
- )
96
-
97
- if self.make_input_content_impl:
98
- return self.make_input_content_impl(
99
- in_args=val_in_args, usr_args=val_usr_args, ctx=ctx
100
- )
80
+ val_in_args = in_args
81
+ if in_args is not None:
82
+ val_in_args = self._validate_input_args(in_args=in_args)
101
83
 
102
- combined_args = self._combine_args(in_args=val_in_args, usr_args=val_usr_args)
103
- if isinstance(combined_args, str):
104
- return Content.from_text(combined_args)
84
+ if self.input_content_builder:
85
+ return self.input_content_builder(in_args=val_in_args, ctx=ctx)
105
86
 
106
- if self.in_prompt_template is not None:
107
- return Content.from_formatted_prompt(
108
- self.in_prompt_template, **combined_args
87
+ if val_in_args is None:
88
+ raise InputPromptBuilderError(
89
+ proc_name=self._agent_name,
90
+ message="Input arguments are not provided, "
91
+ f"but input content is required [agent_name={self._agent_name}]",
109
92
  )
110
93
 
111
- return Content.from_text(json.dumps(combined_args, indent=2))
94
+ if issubclass(self._in_type, BaseModel) and isinstance(val_in_args, BaseModel):
95
+ val_in_args_map = self._format_pydantic_prompt_args(val_in_args)
96
+ if self.in_prompt is not None:
97
+ return Content.from_formatted_prompt(self.in_prompt, **val_in_args_map)
98
+ return Content.from_text(json.dumps(val_in_args_map, indent=2))
99
+
100
+ fmt_in_args = self._in_args_type_adapter.dump_json(
101
+ val_in_args, indent=2, warnings="error"
102
+ ).decode("utf-8")
103
+ return Content.from_text(fmt_in_args)
112
104
 
113
- def make_input_message(
105
+ @final
106
+ def build_input_message(
114
107
  self,
115
108
  chat_inputs: LLMPrompt | Sequence[str | ImageData] | None = None,
116
109
  in_args: InT | None = None,
117
- usr_args: LLMPromptArgs | None = None,
118
110
  ctx: RunContext[CtxT] | None = None,
119
111
  ) -> UserMessage | None:
120
- if chat_inputs is None and in_args is None and usr_args is None:
121
- return None
112
+ if chat_inputs is not None and in_args is not None:
113
+ raise InputPromptBuilderError(
114
+ proc_name=self._agent_name,
115
+ message="Cannot use both chat inputs and input arguments "
116
+ f"at the same time [agent_name={self._agent_name}]",
117
+ )
122
118
 
123
119
  if chat_inputs:
124
120
  if isinstance(chat_inputs, LLMPrompt):
125
121
  return UserMessage.from_text(chat_inputs, name=self._agent_name)
126
122
  return UserMessage.from_content_parts(chat_inputs, name=self._agent_name)
127
123
 
128
- in_content = self.make_input_content(
129
- in_args=in_args, usr_args=usr_args, ctx=ctx
124
+ return UserMessage(
125
+ content=self._build_input_content(in_args=in_args, ctx=ctx),
126
+ name=self._agent_name,
130
127
  )
131
128
 
132
- return UserMessage(content=in_content, name=self._agent_name)
133
-
134
- def _validate_prompt_args(
135
- self,
136
- *,
137
- in_args: InT | None,
138
- usr_args: LLMPromptArgs | None,
139
- ) -> tuple[InT | None, LLMPromptArgs | None]:
140
- val_usr_args = usr_args
141
- if usr_args is not None:
142
- if self.in_prompt_template is None:
143
- raise InputPromptBuilderError(
144
- "Input prompt template is not set, but user arguments are provided."
145
- )
146
- if self.usr_args_schema is None:
147
- raise InputPromptBuilderError(
148
- "User arguments schema is not provided, but user arguments are "
149
- "given."
150
- )
151
- val_usr_args = self.usr_args_schema.model_validate(usr_args)
152
-
153
- val_in_args = in_args
154
- if in_args is not None:
155
- val_in_args = self._in_args_type_adapter.validate_python(in_args)
156
- if isinstance(val_in_args, BaseModel):
157
- has_image = self._has_image_data(val_in_args)
158
- if has_image and self.in_prompt_template is None:
159
- raise InputPromptBuilderError(
160
- "BaseModel input arguments contain ImageData, but input prompt "
161
- "template is not set. Cannot format input arguments."
162
- )
163
- elif self.in_prompt_template is not None:
164
- raise InputPromptBuilderError(
165
- "Cannot use the input prompt template with "
166
- "non-BaseModel input arguments."
167
- )
168
-
169
- return val_in_args, val_usr_args
170
-
171
129
  @staticmethod
172
130
  def _has_image_data(inp: BaseModel) -> bool:
173
131
  contains_image_data = False
@@ -180,39 +138,18 @@ class PromptBuilder(AutoInstanceAttributesMixin, Generic[InT, CtxT]):
180
138
  @staticmethod
181
139
  def _format_pydantic_prompt_args(inp: BaseModel) -> dict[str, PromptArgumentType]:
182
140
  formatted_args: dict[str, PromptArgumentType] = {}
183
- for field in type(inp).model_fields:
184
- if field == "selected_recipients":
141
+ for field_name, field_info in type(inp).model_fields.items():
142
+ if field_info.exclude:
185
143
  continue
186
144
 
187
- val = getattr(inp, field)
145
+ val = getattr(inp, field_name)
188
146
  if isinstance(val, (int, str, bool, ImageData)):
189
- formatted_args[field] = val
147
+ formatted_args[field_name] = val
190
148
  else:
191
- formatted_args[field] = (
149
+ formatted_args[field_name] = (
192
150
  TypeAdapter(type(val)) # type: ignore[return-value]
193
151
  .dump_json(val, indent=2, warnings="error")
194
152
  .decode("utf-8")
195
153
  )
196
154
 
197
155
  return formatted_args
198
-
199
- def _combine_args(
200
- self, *, in_args: InT | None, usr_args: LLMPromptArgs | None
201
- ) -> Mapping[str, PromptArgumentType] | str:
202
- fmt_usr_args = self._format_pydantic_prompt_args(usr_args) if usr_args else {}
203
-
204
- if in_args is None:
205
- return fmt_usr_args
206
-
207
- if isinstance(in_args, BaseModel):
208
- fmt_in_args = self._format_pydantic_prompt_args(in_args)
209
- return fmt_in_args | fmt_usr_args
210
-
211
- combined_args_str = self._in_args_type_adapter.dump_json(
212
- in_args, indent=2, warnings="error"
213
- ).decode("utf-8")
214
- if usr_args is not None:
215
- fmt_usr_args_str = usr_args.model_dump_json(indent=2, warnings="error")
216
- combined_args_str += "\n" + fmt_usr_args_str
217
-
218
- return combined_args_str
@@ -6,39 +6,26 @@ from pydantic import BaseModel, ConfigDict, Field
6
6
  from grasp_agents.typing.completion import Completion
7
7
 
8
8
  from .printer import ColoringMode, Printer
9
- from .typing.io import LLMPromptArgs, ProcName
9
+ from .typing.io import ProcName
10
10
  from .usage_tracker import UsageTracker
11
11
 
12
-
13
- class RunArgs(BaseModel):
14
- sys: LLMPromptArgs | None = None
15
- usr: LLMPromptArgs | None = None
16
-
17
- model_config = ConfigDict(extra="forbid")
18
-
19
-
20
12
  CtxT = TypeVar("CtxT")
21
13
 
22
14
 
23
15
  class RunContext(BaseModel, Generic[CtxT]):
24
16
  state: CtxT | None = None
25
17
 
26
- run_args: dict[ProcName, RunArgs] = Field(default_factory=dict)
27
-
28
- is_streaming: bool = False
29
- result: Any | None = None
30
-
31
18
  completions: dict[ProcName, list[Completion]] = Field(
32
19
  default_factory=lambda: defaultdict(list)
33
20
  )
34
21
  usage_tracker: UsageTracker = Field(default_factory=UsageTracker)
35
22
 
36
23
  printer: Printer | None = None
37
- print_messages: bool = False
24
+ log_messages: bool = False
38
25
  color_messages_by: ColoringMode = "role"
39
26
 
40
27
  def model_post_init(self, context: Any) -> None: # noqa: ARG002
41
- if self.print_messages:
28
+ if self.log_messages:
42
29
  self.printer = Printer(color_by=self.color_messages_by)
43
30
 
44
31
  model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
grasp_agents/runner.py CHANGED
@@ -1,42 +1,131 @@
1
1
  from collections.abc import AsyncIterator, Sequence
2
+ from functools import partial
2
3
  from typing import Any, Generic
3
4
 
4
- from .comm_processor import CommProcessor
5
+ from .errors import RunnerError
6
+ from .packet import Packet, StartPacket
7
+ from .packet_pool import END_PROC_NAME, PacketPool
8
+ from .processor import Processor
5
9
  from .run_context import CtxT, RunContext
6
- from .typing.events import Event
10
+ from .typing.events import Event, ProcPacketOutputEvent, RunResultEvent
11
+ from .typing.io import OutT
7
12
 
8
13
 
9
- class Runner(Generic[CtxT]):
14
+ class Runner(Generic[OutT, CtxT]):
10
15
  def __init__(
11
16
  self,
12
- start_proc: CommProcessor[Any, Any, Any, CtxT],
13
- procs: Sequence[CommProcessor[Any, Any, Any, CtxT]],
17
+ entry_proc: Processor[Any, Any, Any, CtxT],
18
+ procs: Sequence[Processor[Any, Any, Any, CtxT]],
14
19
  ctx: RunContext[CtxT] | None = None,
15
20
  ) -> None:
16
- if start_proc not in procs:
17
- raise ValueError(
18
- f"Start processor {start_proc.name} must be in the list of processors: "
21
+ if entry_proc not in procs:
22
+ raise RunnerError(
23
+ f"Entry processor {entry_proc.name} must be in the list of processors: "
19
24
  f"{', '.join(proc.name for proc in procs)}"
20
25
  )
21
- self._start_proc = start_proc
26
+ if sum(1 for proc in procs if END_PROC_NAME in (proc.recipients or [])) != 1:
27
+ raise RunnerError(
28
+ "There must be exactly one processor with recipient 'END'."
29
+ )
30
+
31
+ self._entry_proc = entry_proc
22
32
  self._procs = procs
23
33
  self._ctx = ctx or RunContext[CtxT]()
34
+ self._packet_pool: PacketPool[CtxT] = PacketPool()
24
35
 
25
36
  @property
26
37
  def ctx(self) -> RunContext[CtxT]:
27
38
  return self._ctx
28
39
 
29
- async def run(self, **run_args: Any) -> Any:
30
- self._ctx.is_streaming = False
31
- for proc in self._procs:
32
- proc.start_listening(ctx=self._ctx, **run_args)
33
- await self._start_proc.run(**run_args, ctx=self._ctx)
40
+ def _unpack_packet(
41
+ self, packet: Packet[Any] | None
42
+ ) -> tuple[Packet[Any] | None, Any | None]:
43
+ if isinstance(packet, StartPacket):
44
+ return None, packet.chat_inputs
45
+ return packet, None
46
+
47
+ async def _packet_handler(
48
+ self,
49
+ proc: Processor[Any, Any, Any, CtxT],
50
+ pool: PacketPool[CtxT],
51
+ packet: Packet[Any],
52
+ ctx: RunContext[CtxT],
53
+ **run_kwargs: Any,
54
+ ) -> None:
55
+ _in_packet, _chat_inputs = self._unpack_packet(packet)
56
+ out_packet = await proc.run(
57
+ chat_inputs=_chat_inputs, in_packet=_in_packet, ctx=ctx, **run_kwargs
58
+ )
59
+ await pool.post(out_packet)
60
+
61
+ async def _packet_handler_stream(
62
+ self,
63
+ proc: Processor[Any, Any, Any, CtxT],
64
+ pool: PacketPool[CtxT],
65
+ packet: Packet[Any],
66
+ ctx: RunContext[CtxT],
67
+ **run_kwargs: Any,
68
+ ) -> None:
69
+ _in_packet, _chat_inputs = self._unpack_packet(packet)
34
70
 
35
- return self._ctx.result
71
+ out_packet: Packet[Any] | None = None
72
+ async for event in proc.run_stream(
73
+ chat_inputs=_chat_inputs, in_packet=_in_packet, ctx=ctx, **run_kwargs
74
+ ):
75
+ if isinstance(event, ProcPacketOutputEvent):
76
+ out_packet = event.data
77
+ await pool.push_event(event)
36
78
 
37
- async def run_stream(self, **run_args: Any) -> AsyncIterator[Event[Any]]:
38
- self._ctx.is_streaming = True
39
- for proc in self._procs:
40
- proc.start_listening(ctx=self._ctx, **run_args)
41
- async for event in self._start_proc.run_stream(**run_args, ctx=self._ctx):
42
- yield event
79
+ assert out_packet is not None
80
+
81
+ await pool.post(out_packet)
82
+
83
+ async def run(
84
+ self,
85
+ chat_input: Any = "start",
86
+ **run_args: Any,
87
+ ) -> Packet[OutT]:
88
+ async with PacketPool[CtxT]() as pool:
89
+ for proc in self._procs:
90
+ pool.register_packet_handler(
91
+ proc_name=proc.name,
92
+ handler=partial(self._packet_handler, proc, pool),
93
+ ctx=self._ctx,
94
+ **run_args,
95
+ )
96
+ await pool.post(
97
+ StartPacket[Any](
98
+ recipients=[self._entry_proc.name], chat_inputs=chat_input
99
+ )
100
+ )
101
+ return await pool.final_result()
102
+
103
+ async def run_stream(
104
+ self,
105
+ chat_input: Any = "start",
106
+ **run_args: Any,
107
+ ) -> AsyncIterator[Event[Any]]:
108
+ async with PacketPool[CtxT]() as pool:
109
+ for proc in self._procs:
110
+ pool.register_packet_handler(
111
+ proc_name=proc.name,
112
+ handler=partial(self._packet_handler_stream, proc, pool),
113
+ ctx=self._ctx,
114
+ **run_args,
115
+ )
116
+ await pool.post(
117
+ StartPacket[Any](
118
+ recipients=[self._entry_proc.name], chat_inputs=chat_input
119
+ )
120
+ )
121
+ async for event in pool.stream_events():
122
+ if isinstance(
123
+ event, ProcPacketOutputEvent
124
+ ) and event.data.recipients == [END_PROC_NAME]:
125
+ yield RunResultEvent(
126
+ data=event.data,
127
+ proc_name=event.proc_name,
128
+ call_id=event.call_id,
129
+ )
130
+ else:
131
+ yield event
@@ -136,16 +136,20 @@ class ProcPayloadOutputEvent(Event[Any], frozen=True):
136
136
 
137
137
 
138
138
  class ProcPacketOutputEvent(Event[Packet[Any]], frozen=True):
139
- type: Literal[EventType.PACKET_OUT] = EventType.PACKET_OUT
140
- source: Literal[EventSourceType.PROC] = EventSourceType.PROC
139
+ type: Literal[EventType.PACKET_OUT, EventType.WORKFLOW_RES, EventType.RUN_RES] = (
140
+ EventType.PACKET_OUT
141
+ )
142
+ source: Literal[
143
+ EventSourceType.PROC, EventSourceType.WORKFLOW, EventSourceType.RUN
144
+ ] = EventSourceType.PROC
141
145
 
142
146
 
143
- class WorkflowResultEvent(Event[Packet[Any]], frozen=True):
147
+ class WorkflowResultEvent(ProcPacketOutputEvent, frozen=True):
144
148
  type: Literal[EventType.WORKFLOW_RES] = EventType.WORKFLOW_RES
145
149
  source: Literal[EventSourceType.WORKFLOW] = EventSourceType.WORKFLOW
146
150
 
147
151
 
148
- class RunResultEvent(Event[Packet[Any]], frozen=True):
152
+ class RunResultEvent(ProcPacketOutputEvent, frozen=True):
149
153
  type: Literal[EventType.RUN_RES] = EventType.RUN_RES
150
154
  source: Literal[EventSourceType.RUN] = EventSourceType.RUN
151
155
 
grasp_agents/typing/io.py CHANGED
@@ -1,16 +1,9 @@
1
1
  from typing import TypeAlias, TypeVar
2
2
 
3
- from pydantic import BaseModel
4
-
5
3
  ProcName: TypeAlias = str
6
4
 
7
5
 
8
6
  InT = TypeVar("InT")
9
- OutT_co = TypeVar("OutT_co", covariant=True)
10
-
11
-
12
- class LLMPromptArgs(BaseModel):
13
- pass
14
-
7
+ OutT = TypeVar("OutT")
15
8
 
16
9
  LLMPrompt: TypeAlias = str
@@ -4,11 +4,11 @@ from logging import getLogger
4
4
  from typing import Any, Generic, Protocol, TypeVar, cast, final
5
5
 
6
6
  from ..errors import WorkflowConstructionError
7
- from ..packet_pool import Packet, PacketPool
7
+ from ..packet_pool import Packet
8
8
  from ..processor import Processor
9
9
  from ..run_context import CtxT, RunContext
10
10
  from ..typing.events import Event, ProcPacketOutputEvent, WorkflowResultEvent
11
- from ..typing.io import InT, OutT_co, ProcName
11
+ from ..typing.io import InT, OutT, ProcName
12
12
  from .workflow_processor import WorkflowProcessor
13
13
 
14
14
  logger = getLogger(__name__)
@@ -25,15 +25,12 @@ class ExitWorkflowLoopHandler(Protocol[_OutT_contra, CtxT]):
25
25
  ) -> bool: ...
26
26
 
27
27
 
28
- class LoopedWorkflow(
29
- WorkflowProcessor[InT, OutT_co, CtxT], Generic[InT, OutT_co, CtxT]
30
- ):
28
+ class LoopedWorkflow(WorkflowProcessor[InT, OutT, CtxT], Generic[InT, OutT, CtxT]):
31
29
  def __init__(
32
30
  self,
33
31
  name: ProcName,
34
32
  subprocs: Sequence[Processor[Any, Any, Any, CtxT]],
35
- exit_proc: Processor[Any, OutT_co, Any, CtxT],
36
- packet_pool: PacketPool[CtxT] | None = None,
33
+ exit_proc: Processor[Any, OutT, Any, CtxT],
37
34
  recipients: list[ProcName] | None = None,
38
35
  max_retries: int = 0,
39
36
  max_iterations: int = 10,
@@ -43,7 +40,6 @@ class LoopedWorkflow(
43
40
  name=name,
44
41
  start_proc=subprocs[0],
45
42
  end_proc=exit_proc,
46
- packet_pool=packet_pool,
47
43
  recipients=recipients,
48
44
  max_retries=max_retries,
49
45
  )
@@ -64,24 +60,22 @@ class LoopedWorkflow(
64
60
 
65
61
  self._max_iterations = max_iterations
66
62
 
67
- self._exit_workflow_loop_impl: ExitWorkflowLoopHandler[OutT_co, CtxT] | None = (
68
- None
69
- )
63
+ self._exit_workflow_loop_impl: ExitWorkflowLoopHandler[OutT, CtxT] | None = None
70
64
 
71
65
  @property
72
66
  def max_iterations(self) -> int:
73
67
  return self._max_iterations
74
68
 
75
69
  def exit_workflow_loop(
76
- self, func: ExitWorkflowLoopHandler[OutT_co, CtxT]
77
- ) -> ExitWorkflowLoopHandler[OutT_co, CtxT]:
70
+ self, func: ExitWorkflowLoopHandler[OutT, CtxT]
71
+ ) -> ExitWorkflowLoopHandler[OutT, CtxT]:
78
72
  self._exit_workflow_loop_impl = func
79
73
 
80
74
  return func
81
75
 
82
76
  def _exit_workflow_loop(
83
77
  self,
84
- out_packet: Packet[OutT_co],
78
+ out_packet: Packet[OutT],
85
79
  *,
86
80
  ctx: RunContext[CtxT] | None = None,
87
81
  **kwargs: Any,
@@ -101,12 +95,12 @@ class LoopedWorkflow(
101
95
  call_id: str | None = None,
102
96
  forgetful: bool = False,
103
97
  ctx: RunContext[CtxT] | None = None,
104
- ) -> Packet[OutT_co]:
98
+ ) -> Packet[OutT]:
105
99
  call_id = self._generate_call_id(call_id)
106
100
 
107
101
  packet = in_packet
108
102
  num_iterations = 0
109
- exit_packet: Packet[OutT_co] | None = None
103
+ exit_packet: Packet[OutT] | None = None
110
104
 
111
105
  while True:
112
106
  for subproc in self.subprocs:
@@ -121,7 +115,7 @@ class LoopedWorkflow(
121
115
 
122
116
  if subproc is self._end_proc:
123
117
  num_iterations += 1
124
- exit_packet = cast("Packet[OutT_co]", packet)
118
+ exit_packet = cast("Packet[OutT]", packet)
125
119
  if self._exit_workflow_loop(exit_packet, ctx=ctx):
126
120
  return exit_packet
127
121
  if num_iterations >= self._max_iterations:
@@ -149,7 +143,7 @@ class LoopedWorkflow(
149
143
 
150
144
  packet = in_packet
151
145
  num_iterations = 0
152
- exit_packet: Packet[OutT_co] | None = None
146
+ exit_packet: Packet[OutT] | None = None
153
147
 
154
148
  while True:
155
149
  for subproc in self.subprocs:
@@ -167,7 +161,7 @@ class LoopedWorkflow(
167
161
 
168
162
  if subproc is self._end_proc:
169
163
  num_iterations += 1
170
- exit_packet = cast("Packet[OutT_co]", packet)
164
+ exit_packet = cast("Packet[OutT]", packet)
171
165
  if self._exit_workflow_loop(exit_packet, ctx=ctx):
172
166
  yield WorkflowResultEvent(
173
167
  data=exit_packet, proc_name=self.name, call_id=call_id