grasp_agents 0.5.2__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,5 +1,5 @@
1
1
  import json
2
- from collections.abc import Mapping, Sequence
2
+ from collections.abc import Sequence
3
3
  from typing import ClassVar, Generic, Protocol, TypeAlias, TypeVar
4
4
 
5
5
  from pydantic import BaseModel, TypeAdapter
@@ -26,9 +26,8 @@ class MakeSystemPromptHandler(Protocol[CtxT]):
26
26
  class MakeInputContentHandler(Protocol[_InT_contra, CtxT]):
27
27
  def __call__(
28
28
  self,
29
- *,
30
29
  in_args: _InT_contra | None,
31
- usr_args: LLMPromptArgs | None,
30
+ *,
32
31
  ctx: RunContext[CtxT] | None,
33
32
  ) -> Content: ...
34
33
 
@@ -42,132 +41,129 @@ class PromptBuilder(AutoInstanceAttributesMixin, Generic[InT, CtxT]):
42
41
  def __init__(
43
42
  self,
44
43
  agent_name: str,
45
- sys_prompt_template: LLMPrompt | None,
46
- in_prompt_template: LLMPrompt | None,
44
+ sys_prompt: LLMPrompt | None,
45
+ in_prompt: LLMPrompt | None,
47
46
  sys_args_schema: type[LLMPromptArgs] | None = None,
48
- usr_args_schema: type[LLMPromptArgs] | None = None,
49
47
  ):
50
48
  self._in_type: type[InT]
51
49
  super().__init__()
52
50
 
53
51
  self._agent_name = agent_name
54
- self.sys_prompt_template = sys_prompt_template
55
- self.in_prompt_template = in_prompt_template
52
+ self.sys_prompt = sys_prompt
53
+ self.in_prompt = in_prompt
56
54
  self.sys_args_schema = sys_args_schema
57
- self.usr_args_schema = usr_args_schema
58
55
  self.make_system_prompt_impl: MakeSystemPromptHandler[CtxT] | None = None
59
56
  self.make_input_content_impl: MakeInputContentHandler[InT, CtxT] | None = None
60
57
 
61
58
  self._in_args_type_adapter: TypeAdapter[InT] = TypeAdapter(self._in_type)
62
59
 
63
- def make_system_prompt(
64
- self, sys_args: LLMPromptArgs | None = None, ctx: RunContext[CtxT] | None = None
65
- ) -> str | None:
66
- if self.sys_prompt_template is None:
67
- return None
68
-
60
+ def _validate_sys_args(
61
+ self, sys_args: LLMPromptArgs | None
62
+ ) -> LLMPromptArgs | None:
69
63
  val_sys_args = sys_args
70
64
  if sys_args is not None:
71
65
  if self.sys_args_schema is not None:
72
66
  val_sys_args = self.sys_args_schema.model_validate(sys_args)
73
67
  else:
74
68
  raise SystemPromptBuilderError(
75
- "System prompt template and arguments is set, but system arguments "
76
- "schema is not provided."
69
+ proc_name=self._agent_name,
70
+ message="System prompt template and arguments is set, "
71
+ "but system arguments schema is not provided "
72
+ f"[agent_name={self._agent_name}]",
77
73
  )
78
74
 
75
+ return val_sys_args
76
+
77
+ def make_system_prompt(
78
+ self,
79
+ sys_args: LLMPromptArgs | None = None,
80
+ ctx: RunContext[CtxT] | None = None,
81
+ ) -> str | None:
82
+ if self.sys_prompt is None:
83
+ return None
84
+
85
+ val_sys_args = self._validate_sys_args(sys_args=sys_args)
86
+
79
87
  if self.make_system_prompt_impl:
80
88
  return self.make_system_prompt_impl(sys_args=val_sys_args, ctx=ctx)
81
89
 
82
90
  sys_args_dict = val_sys_args.model_dump() if val_sys_args else {}
83
91
 
84
- return self.sys_prompt_template.format(**sys_args_dict)
92
+ return self.sys_prompt.format(**sys_args_dict)
93
+
94
+ def _validate_input_args(self, in_args: InT) -> InT:
95
+ val_in_args = self._in_args_type_adapter.validate_python(in_args)
96
+ if isinstance(val_in_args, BaseModel):
97
+ has_image = self._has_image_data(val_in_args)
98
+ if has_image and self.in_prompt is None:
99
+ raise InputPromptBuilderError(
100
+ proc_name=self._agent_name,
101
+ message="BaseModel input arguments contain ImageData, "
102
+ "but input prompt template is not set "
103
+ f"[agent_name={self._agent_name}]. Cannot format input arguments.",
104
+ )
105
+ elif self.in_prompt is not None:
106
+ raise InputPromptBuilderError(
107
+ proc_name=self._agent_name,
108
+ message="Cannot use the input prompt template with "
109
+ f"non-BaseModel input arguments [agent_name={self._agent_name}]",
110
+ )
111
+
112
+ return val_in_args
85
113
 
86
114
  def make_input_content(
87
115
  self,
88
- *,
89
- in_args: InT | None,
90
- usr_args: LLMPromptArgs | None,
116
+ in_args: InT | None = None,
91
117
  ctx: RunContext[CtxT] | None = None,
92
118
  ) -> Content:
93
- val_in_args, val_usr_args = self._validate_prompt_args(
94
- in_args=in_args, usr_args=usr_args
95
- )
119
+ val_in_args = in_args
120
+ if in_args is not None:
121
+ val_in_args = self._validate_input_args(in_args=in_args)
96
122
 
97
123
  if self.make_input_content_impl:
98
- return self.make_input_content_impl(
99
- in_args=val_in_args, usr_args=val_usr_args, ctx=ctx
100
- )
101
-
102
- combined_args = self._combine_args(in_args=val_in_args, usr_args=val_usr_args)
103
- if isinstance(combined_args, str):
104
- return Content.from_text(combined_args)
124
+ return self.make_input_content_impl(in_args=val_in_args, ctx=ctx)
105
125
 
106
- if self.in_prompt_template is not None:
107
- return Content.from_formatted_prompt(
108
- self.in_prompt_template, **combined_args
126
+ if val_in_args is None:
127
+ raise InputPromptBuilderError(
128
+ proc_name=self._agent_name,
129
+ message="Input arguments are not provided, "
130
+ f"but input content is required [agent_name={self._agent_name}]",
109
131
  )
110
132
 
111
- return Content.from_text(json.dumps(combined_args, indent=2))
133
+ if issubclass(self._in_type, BaseModel) and isinstance(val_in_args, BaseModel):
134
+ val_in_args_map = self._format_pydantic_prompt_args(val_in_args)
135
+ if self.in_prompt is not None:
136
+ return Content.from_formatted_prompt(self.in_prompt, **val_in_args_map)
137
+ return Content.from_text(json.dumps(val_in_args_map, indent=2))
138
+
139
+ fmt_in_args = self._in_args_type_adapter.dump_json(
140
+ val_in_args, indent=2, warnings="error"
141
+ ).decode("utf-8")
142
+ return Content.from_text(fmt_in_args)
112
143
 
113
144
  def make_input_message(
114
145
  self,
115
146
  chat_inputs: LLMPrompt | Sequence[str | ImageData] | None = None,
116
147
  in_args: InT | None = None,
117
- usr_args: LLMPromptArgs | None = None,
118
148
  ctx: RunContext[CtxT] | None = None,
119
149
  ) -> UserMessage | None:
120
- if chat_inputs is None and in_args is None and usr_args is None:
121
- return None
150
+ if chat_inputs is not None and in_args is not None:
151
+ raise InputPromptBuilderError(
152
+ proc_name=self._agent_name,
153
+ message="Cannot use both chat inputs and input arguments "
154
+ f"at the same time [agent_name={self._agent_name}]",
155
+ )
122
156
 
123
157
  if chat_inputs:
124
158
  if isinstance(chat_inputs, LLMPrompt):
125
159
  return UserMessage.from_text(chat_inputs, name=self._agent_name)
126
160
  return UserMessage.from_content_parts(chat_inputs, name=self._agent_name)
127
161
 
128
- in_content = self.make_input_content(
129
- in_args=in_args, usr_args=usr_args, ctx=ctx
162
+ return UserMessage(
163
+ content=self.make_input_content(in_args=in_args, ctx=ctx),
164
+ name=self._agent_name,
130
165
  )
131
166
 
132
- return UserMessage(content=in_content, name=self._agent_name)
133
-
134
- def _validate_prompt_args(
135
- self,
136
- *,
137
- in_args: InT | None,
138
- usr_args: LLMPromptArgs | None,
139
- ) -> tuple[InT | None, LLMPromptArgs | None]:
140
- val_usr_args = usr_args
141
- if usr_args is not None:
142
- if self.in_prompt_template is None:
143
- raise InputPromptBuilderError(
144
- "Input prompt template is not set, but user arguments are provided."
145
- )
146
- if self.usr_args_schema is None:
147
- raise InputPromptBuilderError(
148
- "User arguments schema is not provided, but user arguments are "
149
- "given."
150
- )
151
- val_usr_args = self.usr_args_schema.model_validate(usr_args)
152
-
153
- val_in_args = in_args
154
- if in_args is not None:
155
- val_in_args = self._in_args_type_adapter.validate_python(in_args)
156
- if isinstance(val_in_args, BaseModel):
157
- has_image = self._has_image_data(val_in_args)
158
- if has_image and self.in_prompt_template is None:
159
- raise InputPromptBuilderError(
160
- "BaseModel input arguments contain ImageData, but input prompt "
161
- "template is not set. Cannot format input arguments."
162
- )
163
- elif self.in_prompt_template is not None:
164
- raise InputPromptBuilderError(
165
- "Cannot use the input prompt template with "
166
- "non-BaseModel input arguments."
167
- )
168
-
169
- return val_in_args, val_usr_args
170
-
171
167
  @staticmethod
172
168
  def _has_image_data(inp: BaseModel) -> bool:
173
169
  contains_image_data = False
@@ -180,39 +176,18 @@ class PromptBuilder(AutoInstanceAttributesMixin, Generic[InT, CtxT]):
180
176
  @staticmethod
181
177
  def _format_pydantic_prompt_args(inp: BaseModel) -> dict[str, PromptArgumentType]:
182
178
  formatted_args: dict[str, PromptArgumentType] = {}
183
- for field in type(inp).model_fields:
184
- if field == "selected_recipients":
179
+ for field_name, field_info in type(inp).model_fields.items():
180
+ if field_info.exclude:
185
181
  continue
186
182
 
187
- val = getattr(inp, field)
183
+ val = getattr(inp, field_name)
188
184
  if isinstance(val, (int, str, bool, ImageData)):
189
- formatted_args[field] = val
185
+ formatted_args[field_name] = val
190
186
  else:
191
- formatted_args[field] = (
187
+ formatted_args[field_name] = (
192
188
  TypeAdapter(type(val)) # type: ignore[return-value]
193
189
  .dump_json(val, indent=2, warnings="error")
194
190
  .decode("utf-8")
195
191
  )
196
192
 
197
193
  return formatted_args
198
-
199
- def _combine_args(
200
- self, *, in_args: InT | None, usr_args: LLMPromptArgs | None
201
- ) -> Mapping[str, PromptArgumentType] | str:
202
- fmt_usr_args = self._format_pydantic_prompt_args(usr_args) if usr_args else {}
203
-
204
- if in_args is None:
205
- return fmt_usr_args
206
-
207
- if isinstance(in_args, BaseModel):
208
- fmt_in_args = self._format_pydantic_prompt_args(in_args)
209
- return fmt_in_args | fmt_usr_args
210
-
211
- combined_args_str = self._in_args_type_adapter.dump_json(
212
- in_args, indent=2, warnings="error"
213
- ).decode("utf-8")
214
- if usr_args is not None:
215
- fmt_usr_args_str = usr_args.model_dump_json(indent=2, warnings="error")
216
- combined_args_str += "\n" + fmt_usr_args_str
217
-
218
- return combined_args_str
@@ -9,21 +9,13 @@ from .printer import ColoringMode, Printer
9
9
  from .typing.io import LLMPromptArgs, ProcName
10
10
  from .usage_tracker import UsageTracker
11
11
 
12
-
13
- class RunArgs(BaseModel):
14
- sys: LLMPromptArgs | None = None
15
- usr: LLMPromptArgs | None = None
16
-
17
- model_config = ConfigDict(extra="forbid")
18
-
19
-
20
12
  CtxT = TypeVar("CtxT")
21
13
 
22
14
 
23
15
  class RunContext(BaseModel, Generic[CtxT]):
24
16
  state: CtxT | None = None
25
17
 
26
- run_args: dict[ProcName, RunArgs] = Field(default_factory=dict)
18
+ sys_args: dict[ProcName, LLMPromptArgs] = Field(default_factory=dict)
27
19
 
28
20
  is_streaming: bool = False
29
21
  result: Any | None = None
grasp_agents/runner.py CHANGED
@@ -1,42 +1,131 @@
1
1
  from collections.abc import AsyncIterator, Sequence
2
+ from functools import partial
2
3
  from typing import Any, Generic
3
4
 
4
- from .comm_processor import CommProcessor
5
+ from .errors import RunnerError
6
+ from .packet import Packet, StartPacket
7
+ from .packet_pool import END_PROC_NAME, PacketPool
8
+ from .processor import Processor
5
9
  from .run_context import CtxT, RunContext
6
- from .typing.events import Event
10
+ from .typing.events import Event, ProcPacketOutputEvent, RunResultEvent
11
+ from .typing.io import OutT
7
12
 
8
13
 
9
- class Runner(Generic[CtxT]):
14
+ class Runner(Generic[OutT, CtxT]):
10
15
  def __init__(
11
16
  self,
12
- start_proc: CommProcessor[Any, Any, Any, CtxT],
13
- procs: Sequence[CommProcessor[Any, Any, Any, CtxT]],
17
+ entry_proc: Processor[Any, Any, Any, CtxT],
18
+ procs: Sequence[Processor[Any, Any, Any, CtxT]],
14
19
  ctx: RunContext[CtxT] | None = None,
15
20
  ) -> None:
16
- if start_proc not in procs:
17
- raise ValueError(
18
- f"Start processor {start_proc.name} must be in the list of processors: "
21
+ if entry_proc not in procs:
22
+ raise RunnerError(
23
+ f"Entry processor {entry_proc.name} must be in the list of processors: "
19
24
  f"{', '.join(proc.name for proc in procs)}"
20
25
  )
21
- self._start_proc = start_proc
26
+ if sum(1 for proc in procs if END_PROC_NAME in (proc.recipients or [])) != 1:
27
+ raise RunnerError(
28
+ "There must be exactly one processor with recipient 'END'."
29
+ )
30
+
31
+ self._entry_proc = entry_proc
22
32
  self._procs = procs
23
33
  self._ctx = ctx or RunContext[CtxT]()
34
+ self._packet_pool: PacketPool[CtxT] = PacketPool()
24
35
 
25
36
  @property
26
37
  def ctx(self) -> RunContext[CtxT]:
27
38
  return self._ctx
28
39
 
29
- async def run(self, **run_args: Any) -> Any:
30
- self._ctx.is_streaming = False
31
- for proc in self._procs:
32
- proc.start_listening(ctx=self._ctx, **run_args)
33
- await self._start_proc.run(**run_args, ctx=self._ctx)
40
+ def _unpack_packet(
41
+ self, packet: Packet[Any] | None
42
+ ) -> tuple[Packet[Any] | None, Any | None]:
43
+ if isinstance(packet, StartPacket):
44
+ return None, packet.chat_inputs
45
+ return packet, None
46
+
47
+ async def _packet_handler(
48
+ self,
49
+ proc: Processor[Any, Any, Any, CtxT],
50
+ pool: PacketPool[CtxT],
51
+ packet: Packet[Any],
52
+ ctx: RunContext[CtxT],
53
+ **run_kwargs: Any,
54
+ ) -> None:
55
+ _in_packet, _chat_inputs = self._unpack_packet(packet)
56
+ out_packet = await proc.run(
57
+ chat_inputs=_chat_inputs, in_packet=_in_packet, ctx=ctx, **run_kwargs
58
+ )
59
+ await pool.post(out_packet)
60
+
61
+ async def _packet_handler_stream(
62
+ self,
63
+ proc: Processor[Any, Any, Any, CtxT],
64
+ pool: PacketPool[CtxT],
65
+ packet: Packet[Any],
66
+ ctx: RunContext[CtxT],
67
+ **run_kwargs: Any,
68
+ ) -> None:
69
+ _in_packet, _chat_inputs = self._unpack_packet(packet)
34
70
 
35
- return self._ctx.result
71
+ out_packet: Packet[Any] | None = None
72
+ async for event in proc.run_stream(
73
+ chat_inputs=_chat_inputs, in_packet=_in_packet, ctx=ctx, **run_kwargs
74
+ ):
75
+ if isinstance(event, ProcPacketOutputEvent):
76
+ out_packet = event.data
77
+ await pool.push_event(event)
36
78
 
37
- async def run_stream(self, **run_args: Any) -> AsyncIterator[Event[Any]]:
38
- self._ctx.is_streaming = True
39
- for proc in self._procs:
40
- proc.start_listening(ctx=self._ctx, **run_args)
41
- async for event in self._start_proc.run_stream(**run_args, ctx=self._ctx):
42
- yield event
79
+ assert out_packet is not None
80
+
81
+ await pool.post(out_packet)
82
+
83
+ async def run(
84
+ self,
85
+ chat_input: Any = "start",
86
+ **run_args: Any,
87
+ ) -> Packet[OutT]:
88
+ async with PacketPool[CtxT]() as pool:
89
+ for proc in self._procs:
90
+ pool.register_packet_handler(
91
+ proc_name=proc.name,
92
+ handler=partial(self._packet_handler, proc, pool),
93
+ ctx=self._ctx,
94
+ **run_args,
95
+ )
96
+ await pool.post(
97
+ StartPacket[Any](
98
+ recipients=[self._entry_proc.name], chat_inputs=chat_input
99
+ )
100
+ )
101
+ return await pool.final_result()
102
+
103
+ async def run_stream(
104
+ self,
105
+ chat_input: Any = "start",
106
+ **run_args: Any,
107
+ ) -> AsyncIterator[Event[Any]]:
108
+ async with PacketPool[CtxT]() as pool:
109
+ for proc in self._procs:
110
+ pool.register_packet_handler(
111
+ proc_name=proc.name,
112
+ handler=partial(self._packet_handler_stream, proc, pool),
113
+ ctx=self._ctx,
114
+ **run_args,
115
+ )
116
+ await pool.post(
117
+ StartPacket[Any](
118
+ recipients=[self._entry_proc.name], chat_inputs=chat_input
119
+ )
120
+ )
121
+ async for event in pool.stream_events():
122
+ if isinstance(
123
+ event, ProcPacketOutputEvent
124
+ ) and event.data.recipients == [END_PROC_NAME]:
125
+ yield RunResultEvent(
126
+ data=event.data,
127
+ proc_name=event.proc_name,
128
+ call_id=event.call_id,
129
+ )
130
+ else:
131
+ yield event
@@ -136,16 +136,20 @@ class ProcPayloadOutputEvent(Event[Any], frozen=True):
136
136
 
137
137
 
138
138
  class ProcPacketOutputEvent(Event[Packet[Any]], frozen=True):
139
- type: Literal[EventType.PACKET_OUT] = EventType.PACKET_OUT
140
- source: Literal[EventSourceType.PROC] = EventSourceType.PROC
139
+ type: Literal[EventType.PACKET_OUT, EventType.WORKFLOW_RES, EventType.RUN_RES] = (
140
+ EventType.PACKET_OUT
141
+ )
142
+ source: Literal[
143
+ EventSourceType.PROC, EventSourceType.WORKFLOW, EventSourceType.RUN
144
+ ] = EventSourceType.PROC
141
145
 
142
146
 
143
- class WorkflowResultEvent(Event[Packet[Any]], frozen=True):
147
+ class WorkflowResultEvent(ProcPacketOutputEvent, frozen=True):
144
148
  type: Literal[EventType.WORKFLOW_RES] = EventType.WORKFLOW_RES
145
149
  source: Literal[EventSourceType.WORKFLOW] = EventSourceType.WORKFLOW
146
150
 
147
151
 
148
- class RunResultEvent(Event[Packet[Any]], frozen=True):
152
+ class RunResultEvent(ProcPacketOutputEvent, frozen=True):
149
153
  type: Literal[EventType.RUN_RES] = EventType.RUN_RES
150
154
  source: Literal[EventSourceType.RUN] = EventSourceType.RUN
151
155
 
grasp_agents/typing/io.py CHANGED
@@ -6,7 +6,7 @@ ProcName: TypeAlias = str
6
6
 
7
7
 
8
8
  InT = TypeVar("InT")
9
- OutT_co = TypeVar("OutT_co", covariant=True)
9
+ OutT = TypeVar("OutT")
10
10
 
11
11
 
12
12
  class LLMPromptArgs(BaseModel):
@@ -4,11 +4,11 @@ from logging import getLogger
4
4
  from typing import Any, Generic, Protocol, TypeVar, cast, final
5
5
 
6
6
  from ..errors import WorkflowConstructionError
7
- from ..packet_pool import Packet, PacketPool
7
+ from ..packet_pool import Packet
8
8
  from ..processor import Processor
9
9
  from ..run_context import CtxT, RunContext
10
10
  from ..typing.events import Event, ProcPacketOutputEvent, WorkflowResultEvent
11
- from ..typing.io import InT, OutT_co, ProcName
11
+ from ..typing.io import InT, OutT, ProcName
12
12
  from .workflow_processor import WorkflowProcessor
13
13
 
14
14
  logger = getLogger(__name__)
@@ -25,15 +25,12 @@ class ExitWorkflowLoopHandler(Protocol[_OutT_contra, CtxT]):
25
25
  ) -> bool: ...
26
26
 
27
27
 
28
- class LoopedWorkflow(
29
- WorkflowProcessor[InT, OutT_co, CtxT], Generic[InT, OutT_co, CtxT]
30
- ):
28
+ class LoopedWorkflow(WorkflowProcessor[InT, OutT, CtxT], Generic[InT, OutT, CtxT]):
31
29
  def __init__(
32
30
  self,
33
31
  name: ProcName,
34
32
  subprocs: Sequence[Processor[Any, Any, Any, CtxT]],
35
- exit_proc: Processor[Any, OutT_co, Any, CtxT],
36
- packet_pool: PacketPool[CtxT] | None = None,
33
+ exit_proc: Processor[Any, OutT, Any, CtxT],
37
34
  recipients: list[ProcName] | None = None,
38
35
  max_retries: int = 0,
39
36
  max_iterations: int = 10,
@@ -43,7 +40,6 @@ class LoopedWorkflow(
43
40
  name=name,
44
41
  start_proc=subprocs[0],
45
42
  end_proc=exit_proc,
46
- packet_pool=packet_pool,
47
43
  recipients=recipients,
48
44
  max_retries=max_retries,
49
45
  )
@@ -64,24 +60,22 @@ class LoopedWorkflow(
64
60
 
65
61
  self._max_iterations = max_iterations
66
62
 
67
- self._exit_workflow_loop_impl: ExitWorkflowLoopHandler[OutT_co, CtxT] | None = (
68
- None
69
- )
63
+ self._exit_workflow_loop_impl: ExitWorkflowLoopHandler[OutT, CtxT] | None = None
70
64
 
71
65
  @property
72
66
  def max_iterations(self) -> int:
73
67
  return self._max_iterations
74
68
 
75
69
  def exit_workflow_loop(
76
- self, func: ExitWorkflowLoopHandler[OutT_co, CtxT]
77
- ) -> ExitWorkflowLoopHandler[OutT_co, CtxT]:
70
+ self, func: ExitWorkflowLoopHandler[OutT, CtxT]
71
+ ) -> ExitWorkflowLoopHandler[OutT, CtxT]:
78
72
  self._exit_workflow_loop_impl = func
79
73
 
80
74
  return func
81
75
 
82
76
  def _exit_workflow_loop(
83
77
  self,
84
- out_packet: Packet[OutT_co],
78
+ out_packet: Packet[OutT],
85
79
  *,
86
80
  ctx: RunContext[CtxT] | None = None,
87
81
  **kwargs: Any,
@@ -101,12 +95,12 @@ class LoopedWorkflow(
101
95
  call_id: str | None = None,
102
96
  forgetful: bool = False,
103
97
  ctx: RunContext[CtxT] | None = None,
104
- ) -> Packet[OutT_co]:
98
+ ) -> Packet[OutT]:
105
99
  call_id = self._generate_call_id(call_id)
106
100
 
107
101
  packet = in_packet
108
102
  num_iterations = 0
109
- exit_packet: Packet[OutT_co] | None = None
103
+ exit_packet: Packet[OutT] | None = None
110
104
 
111
105
  while True:
112
106
  for subproc in self.subprocs:
@@ -121,7 +115,7 @@ class LoopedWorkflow(
121
115
 
122
116
  if subproc is self._end_proc:
123
117
  num_iterations += 1
124
- exit_packet = cast("Packet[OutT_co]", packet)
118
+ exit_packet = cast("Packet[OutT]", packet)
125
119
  if self._exit_workflow_loop(exit_packet, ctx=ctx):
126
120
  return exit_packet
127
121
  if num_iterations >= self._max_iterations:
@@ -149,38 +143,39 @@ class LoopedWorkflow(
149
143
 
150
144
  packet = in_packet
151
145
  num_iterations = 0
152
- exit_packet: Packet[OutT_co] | None = None
153
-
154
- for subproc in self.subprocs:
155
- async for event in subproc.run_stream(
156
- chat_inputs=chat_inputs,
157
- in_packet=packet,
158
- in_args=in_args,
159
- forgetful=forgetful,
160
- call_id=f"{call_id}/{subproc.name}",
161
- ctx=ctx,
162
- ):
163
- if isinstance(event, ProcPacketOutputEvent):
164
- packet = event.data
165
- yield event
166
-
167
- if subproc is self._end_proc:
168
- num_iterations += 1
169
- exit_packet = cast("Packet[OutT_co]", packet)
170
- if self._exit_workflow_loop(exit_packet, ctx=ctx):
171
- yield WorkflowResultEvent(
172
- data=exit_packet, proc_name=self.name, call_id=call_id
173
- )
174
- return
175
- if num_iterations >= self._max_iterations:
176
- logger.info(
177
- f"Max iterations reached ({self._max_iterations}). "
178
- "Exiting loop."
179
- )
180
- yield WorkflowResultEvent(
181
- data=exit_packet, proc_name=self.name, call_id=call_id
182
- )
183
- return
184
-
185
- chat_inputs = None
186
- in_args = None
146
+ exit_packet: Packet[OutT] | None = None
147
+
148
+ while True:
149
+ for subproc in self.subprocs:
150
+ async for event in subproc.run_stream(
151
+ chat_inputs=chat_inputs,
152
+ in_packet=packet,
153
+ in_args=in_args,
154
+ forgetful=forgetful,
155
+ call_id=f"{call_id}/{subproc.name}",
156
+ ctx=ctx,
157
+ ):
158
+ if isinstance(event, ProcPacketOutputEvent):
159
+ packet = event.data
160
+ yield event
161
+
162
+ if subproc is self._end_proc:
163
+ num_iterations += 1
164
+ exit_packet = cast("Packet[OutT]", packet)
165
+ if self._exit_workflow_loop(exit_packet, ctx=ctx):
166
+ yield WorkflowResultEvent(
167
+ data=exit_packet, proc_name=self.name, call_id=call_id
168
+ )
169
+ return
170
+ if num_iterations >= self._max_iterations:
171
+ logger.info(
172
+ f"Max iterations reached ({self._max_iterations}). "
173
+ "Exiting loop."
174
+ )
175
+ yield WorkflowResultEvent(
176
+ data=exit_packet, proc_name=self.name, call_id=call_id
177
+ )
178
+ return
179
+
180
+ chat_inputs = None
181
+ in_args = None