grasp_agents 0.5.5__py3-none-any.whl → 0.5.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,244 @@
1
+ import asyncio
2
+ import logging
3
+ from collections.abc import AsyncIterator, Sequence
4
+ from typing import Any, ClassVar, Generic, cast
5
+
6
+
7
+ from ..errors import PacketRoutingError
8
+ from ..memory import MemT
9
+ from ..packet import Packet
10
+ from ..run_context import CtxT, RunContext
11
+ from ..typing.events import (
12
+ Event, ProcPacketOutputEvent, ProcPayloadOutputEvent
13
+ )
14
+ from ..typing.io import InT, OutT
15
+ from ..utils import stream_concurrent
16
+ from .base_processor import BaseProcessor, with_retry, with_retry_stream
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+
21
class ParallelProcessor(
    BaseProcessor[InT, OutT, MemT, CtxT], Generic[InT, OutT, MemT, CtxT]
):
    """
    Processor that can fan a batch of inputs out into concurrent runs.

    ``run`` / ``run_stream`` validate the inputs first. When more than one
    input remains, every input is handled by its own single-input run. These
    runs are always forgetful (they work on a deep copy of the memory) and do
    not receive ``chat_inputs``. Their payloads are merged into one output
    packet in input order. All parallel runs must select the same set of
    recipients. With zero or one input, a single run is performed using the
    caller's ``forgetful`` flag.

    Subclasses customise the per-input work by overriding ``_process`` and
    ``_process_stream``; the defaults pass the input through unchanged.
    """

    _generic_arg_to_instance_attr_map: ClassVar[dict[int, str]] = {
        0: "_in_type",
        1: "_out_type",
    }

    async def _process(
        self,
        chat_inputs: Any | None = None,
        *,
        in_args: InT | None = None,
        memory: MemT,
        call_id: str,
        ctx: RunContext[CtxT] | None = None,
    ) -> OutT:
        """Process a single input; the default returns ``in_args`` unchanged."""
        return cast(OutT, in_args)

    async def _process_stream(
        self,
        chat_inputs: Any | None = None,
        *,
        in_args: InT | None = None,
        memory: MemT,
        call_id: str,
        ctx: RunContext[CtxT] | None = None,
    ) -> AsyncIterator[Event[Any]]:
        """Stream-process a single input; the default yields ``in_args`` as the payload."""
        output = cast(OutT, in_args)
        yield ProcPayloadOutputEvent(data=output, proc_name=self.name, call_id=call_id)

    def _validate_parallel_recipients(
        self, out_packets: Sequence[Packet[OutT]], call_id: str
    ) -> None:
        """
        Ensure every parallel run selected the same set of recipients.

        Recipient order is ignored and ``None`` is treated as "no recipients".
        An empty sequence is trivially consistent.

        Raises:
            PacketRoutingError: If two packets have different recipient sets.

        """
        if not out_packets:
            return

        first_recipients = set(out_packets[0].recipients or [])
        if any(set(p.recipients or []) != first_recipients for p in out_packets[1:]):
            raise PacketRoutingError(
                proc_name=self.name,
                call_id=call_id,
                message="Parallel runs must return the same recipients "
                f"[proc_name={self.name}; call_id={call_id}]",
            )

    @with_retry
    async def _run_single(
        self,
        chat_inputs: Any | None = None,
        *,
        in_args: InT | None = None,
        forgetful: bool = False,
        call_id: str,
        ctx: RunContext[CtxT] | None = None,
    ) -> Packet[OutT]:
        """
        Process one input, then validate and route its output.

        When ``forgetful`` is true, ``_process`` works on a deep copy of the
        memory, so ``self.memory`` is left untouched.

        Returns:
            A packet holding exactly one validated payload.

        """
        memory = self.memory.model_copy(deep=True) if forgetful else self.memory

        output = await self._process(
            chat_inputs=chat_inputs,
            in_args=in_args,
            memory=memory,
            call_id=call_id,
            ctx=ctx,
        )
        val_output = self._validate_output(output, call_id=call_id)

        recipients = self._select_recipients(output=val_output, ctx=ctx)
        self._validate_recipients(recipients, call_id=call_id)

        return Packet(payloads=[val_output], sender=self.name, recipients=recipients)

    async def _run_parallel(
        self, in_args: list[InT], call_id: str, ctx: RunContext[CtxT] | None = None
    ) -> Packet[OutT]:
        """
        Run every input concurrently and merge the payloads in input order.

        Sub-runs get the call IDs ``"{call_id}/{idx}"``. If any sub-run
        raises, the remaining sub-runs are cancelled and the exception is
        re-raised.

        Raises:
            PacketRoutingError: If the sub-runs selected different recipients.

        """
        tasks = [
            asyncio.ensure_future(
                self._run_single(
                    in_args=inp, forgetful=True, call_id=f"{call_id}/{idx}", ctx=ctx
                )
            )
            for idx, inp in enumerate(in_args)
        ]
        try:
            # gather() returns results in the order of ``tasks``, i.e. input order.
            out_packets = await asyncio.gather(*tasks)
        except BaseException:
            # gather() does not cancel sibling awaitables when one of them
            # raises, so cancel them explicitly. Otherwise a failed fan-out
            # would leave orphaned runs behind.
            for task in tasks:
                task.cancel()
            raise

        self._validate_parallel_recipients(out_packets, call_id=call_id)

        return Packet(
            payloads=[out_packet.payloads[0] for out_packet in out_packets],
            sender=self.name,
            recipients=out_packets[0].recipients,
        )

    async def run(
        self,
        chat_inputs: Any | None = None,
        *,
        in_packet: Packet[InT] | None = None,
        in_args: InT | list[InT] | None = None,
        forgetful: bool = False,
        call_id: str | None = None,
        ctx: RunContext[CtxT] | None = None,
    ) -> Packet[OutT]:
        """
        Validate the inputs and dispatch to a parallel or a single run.

        More than one validated input triggers ``_run_parallel``, which is
        always forgetful. Otherwise ``_run_single`` is called with the first
        input (or ``None``) and the caller's ``forgetful`` flag.
        """
        call_id = self._generate_call_id(call_id)

        val_in_args = self._validate_inputs(
            call_id=call_id,
            chat_inputs=chat_inputs,
            in_packet=in_packet,
            in_args=in_args,
        )

        if val_in_args and len(val_in_args) > 1:
            return await self._run_parallel(in_args=val_in_args, call_id=call_id, ctx=ctx)

        return await self._run_single(
            chat_inputs=chat_inputs,
            in_args=val_in_args[0] if val_in_args else None,
            forgetful=forgetful,
            call_id=call_id,
            ctx=ctx,
        )

    @with_retry_stream
    async def _run_single_stream(
        self,
        chat_inputs: Any | None = None,
        *,
        in_args: InT | None = None,
        forgetful: bool = False,
        call_id: str,
        ctx: RunContext[CtxT] | None = None,
    ) -> AsyncIterator[Event[Any]]:
        """
        Stream one input, then emit the routed output packet.

        Every event from ``_process_stream`` is forwarded. The data of the
        last ``ProcPayloadOutputEvent`` is validated and routed, and is then
        emitted in a final ``ProcPacketOutputEvent``.
        """
        memory = self.memory.model_copy(deep=True) if forgetful else self.memory

        output: OutT | None = None
        async for event in self._process_stream(
            chat_inputs=chat_inputs,
            in_args=in_args,
            memory=memory,
            call_id=call_id,
            ctx=ctx,
        ):
            if isinstance(event, ProcPayloadOutputEvent):
                output = event.data
            yield event

        assert output is not None

        val_output = self._validate_output(output, call_id=call_id)

        recipients = self._select_recipients(output=val_output, ctx=ctx)
        self._validate_recipients(recipients, call_id=call_id)

        out_packet = Packet[OutT](
            payloads=[val_output], sender=self.name, recipients=recipients
        )

        yield ProcPacketOutputEvent(
            data=out_packet, proc_name=self.name, call_id=call_id
        )

    async def _run_parallel_stream(
        self,
        in_args: list[InT],
        call_id: str,
        ctx: RunContext[CtxT] | None = None,
    ) -> AsyncIterator[Event[Any]]:
        """
        Stream every input concurrently and emit one merged output packet.

        Non-packet events from the sub-runs are forwarded as
        ``stream_concurrent`` delivers them. Per-run packets are collected,
        re-ordered by input index and merged into a single
        ``ProcPacketOutputEvent``.
        """
        streams = [
            self._run_single_stream(
                in_args=inp, forgetful=True, call_id=f"{call_id}/{idx}", ctx=ctx
            )
            for idx, inp in enumerate(in_args)
        ]

        out_packets_map: dict[int, Packet[OutT]] = {}
        async for idx, event in stream_concurrent(streams):
            if isinstance(event, ProcPacketOutputEvent):
                out_packets_map[idx] = event.data
            else:
                yield event

        # Packets arrive in completion order; restore input order.
        out_packets = [out_packets_map[idx] for idx in sorted(out_packets_map)]

        num_missing = len(in_args) - len(out_packets)
        if num_missing:
            # Previously these runs were dropped silently (or caused a KeyError
            # when run 0 was among them); keep going, but make it visible.
            logger.warning(
                "%d of %d parallel runs produced no output packet "
                "[proc_name=%s; call_id=%s]",
                num_missing,
                len(in_args),
                self.name,
                call_id,
            )

        self._validate_parallel_recipients(out_packets=out_packets, call_id=call_id)

        out_packet = Packet(
            payloads=[packet.payloads[0] for packet in out_packets],
            sender=self.name,
            # Use the first packet actually received instead of assuming that
            # run 0 produced one.
            recipients=out_packets[0].recipients if out_packets else None,
        )

        yield ProcPacketOutputEvent(
            data=out_packet, proc_name=self.name, call_id=call_id
        )

    async def run_stream(
        self,
        chat_inputs: Any | None = None,
        *,
        in_packet: Packet[InT] | None = None,
        in_args: InT | list[InT] | None = None,
        forgetful: bool = False,
        call_id: str | None = None,
        ctx: RunContext[CtxT] | None = None,
    ) -> AsyncIterator[Event[Any]]:
        """
        Streaming counterpart of ``run``.

        Dispatching follows the same rules as ``run``; every event of the
        chosen run is forwarded.
        """
        call_id = self._generate_call_id(call_id)

        val_in_args = self._validate_inputs(
            call_id=call_id,
            chat_inputs=chat_inputs,
            in_packet=in_packet,
            in_args=in_args,
        )

        if val_in_args and len(val_in_args) > 1:
            stream = self._run_parallel_stream(in_args=val_in_args, call_id=call_id, ctx=ctx)
        else:
            stream = self._run_single_stream(
                chat_inputs=chat_inputs,
                in_args=val_in_args[0] if val_in_args else None,
                forgetful=forgetful,
                call_id=call_id,
                ctx=ctx,
            )
        async for event in stream:
            yield event
@@ -0,0 +1,161 @@
1
+ import logging
2
+ from collections.abc import AsyncIterator
3
+ from typing import Any, ClassVar, Generic, cast
4
+
5
+ from ..memory import MemT
6
+ from ..packet import Packet
7
+ from ..run_context import CtxT, RunContext
8
+ from ..typing.events import Event, ProcPacketOutputEvent, ProcPayloadOutputEvent
9
+ from ..typing.io import InT, OutT, ProcName
10
+ from .base_processor import BaseProcessor, with_retry, with_retry_stream
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
class Processor(BaseProcessor[InT, OutT, MemT, CtxT], Generic[InT, OutT, MemT, CtxT]):
    """
    Processor that handles all of its validated inputs in one batched call.

    ``_process`` receives the whole list of validated inputs and returns a
    list of outputs. ``_postprocess`` validates and routes each output
    individually. If every output selects the same recipient set, the packet
    carries that recipient list. Otherwise it carries a mapping from output
    index to that output's recipients. With no outputs at all, recipients is
    ``None``.
    """

    _generic_arg_to_instance_attr_map: ClassVar[dict[int, str]] = {
        0: "_in_type",
        1: "_out_type",
    }

    async def _process(
        self,
        chat_inputs: Any | None = None,
        *,
        in_args: list[InT] | None = None,
        memory: MemT,
        call_id: str,
        ctx: RunContext[CtxT] | None = None,
    ) -> list[OutT]:
        """Process the batch of inputs; the default returns ``in_args`` unchanged."""
        return cast("list[OutT]", in_args)

    async def _process_stream(
        self,
        chat_inputs: Any | None = None,
        *,
        in_args: list[InT] | None = None,
        memory: MemT,
        call_id: str,
        ctx: RunContext[CtxT] | None = None,
    ) -> AsyncIterator[Event[Any]]:
        """Run ``_process`` and yield one ``ProcPayloadOutputEvent`` per output."""
        outputs = await self._process(
            chat_inputs=chat_inputs,
            in_args=in_args,
            memory=memory,
            call_id=call_id,
            ctx=ctx,
        )
        for output in outputs:
            yield ProcPayloadOutputEvent(
                data=output, proc_name=self.name, call_id=call_id
            )

    def _preprocess(
        self,
        chat_inputs: Any | None = None,
        *,
        in_packet: Packet[InT] | None = None,
        in_args: InT | list[InT] | None = None,
        forgetful: bool = False,
        call_id: str | None = None,
        ctx: RunContext[CtxT] | None = None,
    ) -> tuple[list[InT] | None, MemT, str]:
        """
        Resolve the call ID, validate the inputs and pick the memory to use.

        ``ctx`` is accepted for signature symmetry with ``run`` but is not used.

        Returns:
            ``(validated_inputs, memory, call_id)``, where ``memory`` is a deep
            copy of ``self.memory`` when ``forgetful`` is true.

        """
        call_id = self._generate_call_id(call_id)

        val_in_args = self._validate_inputs(
            call_id=call_id,
            chat_inputs=chat_inputs,
            in_packet=in_packet,
            in_args=in_args,
        )

        memory = self.memory.model_copy(deep=True) if forgetful else self.memory

        return val_in_args, memory, call_id

    def _postprocess(
        self, outputs: list[OutT], call_id: str, ctx: RunContext[CtxT] | None = None
    ) -> Packet[OutT]:
        """
        Validate and route every output, then pack them into one packet.

        Recipients are collapsed to a single list when all outputs agree (as
        sets). Otherwise they are kept as an ``{output_index: recipients}``
        mapping. An empty ``outputs`` list yields a packet with no payloads
        and ``None`` recipients, instead of raising an ``IndexError`` or
        ``KeyError``.
        """
        payloads: list[OutT] = []
        routing: dict[int, list[ProcName] | None] = {}
        for idx, output in enumerate(outputs):
            val_output = self._validate_output(output, call_id=call_id)
            recipients = self._select_recipients(output=val_output, ctx=ctx)
            self._validate_recipients(recipients, call_id=call_id)

            payloads.append(val_output)
            routing[idx] = recipients

        if not routing:
            # Nothing was produced, so no routing decision exists.
            packet_recipients: Any = None
        else:
            first_recipients = set(routing[0] or [])
            if all(set(r or []) == first_recipients for r in routing.values()):
                packet_recipients = routing[0]
            else:
                packet_recipients = routing

        return Packet(payloads=payloads, sender=self.name, recipients=packet_recipients)  # type: ignore[return-value]

    @with_retry
    async def run(
        self,
        chat_inputs: Any | None = None,
        *,
        in_packet: Packet[InT] | None = None,
        in_args: InT | list[InT] | None = None,
        forgetful: bool = False,
        call_id: str | None = None,
        ctx: RunContext[CtxT] | None = None,
    ) -> Packet[OutT]:
        """Preprocess the inputs, process them in one batch and return the routed packet."""
        val_in_args, memory, call_id = self._preprocess(
            chat_inputs=chat_inputs,
            in_packet=in_packet,
            in_args=in_args,
            forgetful=forgetful,
            call_id=call_id,
            ctx=ctx,
        )
        outputs = await self._process(
            chat_inputs=chat_inputs,
            in_args=val_in_args,
            memory=memory,
            call_id=call_id,
            ctx=ctx,
        )

        return self._postprocess(outputs=outputs, call_id=call_id, ctx=ctx)

    @with_retry_stream
    async def run_stream(
        self,
        chat_inputs: Any | None = None,
        *,
        in_packet: Packet[InT] | None = None,
        in_args: InT | list[InT] | None = None,
        forgetful: bool = False,
        call_id: str | None = None,
        ctx: RunContext[CtxT] | None = None,
    ) -> AsyncIterator[Event[Any]]:
        """
        Streaming counterpart of ``run``.

        Every event from ``_process_stream`` is forwarded. The data of all
        ``ProcPayloadOutputEvent`` events is collected, and a final
        ``ProcPacketOutputEvent`` carries the postprocessed packet.
        """
        val_in_args, memory, call_id = self._preprocess(
            chat_inputs=chat_inputs,
            in_packet=in_packet,
            in_args=in_args,
            forgetful=forgetful,
            call_id=call_id,
            ctx=ctx,
        )
        outputs: list[OutT] = []
        async for event in self._process_stream(
            chat_inputs=chat_inputs,
            in_args=val_in_args,
            memory=memory,
            call_id=call_id,
            ctx=ctx,
        ):
            if isinstance(event, ProcPayloadOutputEvent):
                outputs.append(event.data)
            yield event

        out_packet = self._postprocess(outputs=outputs, call_id=call_id, ctx=ctx)

        yield ProcPacketOutputEvent(
            data=out_packet, proc_name=self.name, call_id=call_id
        )
grasp_agents/runner.py CHANGED
@@ -1,3 +1,4 @@
1
+ import logging
1
2
  from collections.abc import AsyncIterator, Sequence
2
3
  from functools import partial
3
4
  from typing import Any, Generic
@@ -5,11 +6,13 @@ from typing import Any, Generic
5
6
  from .errors import RunnerError
6
7
  from .packet import Packet, StartPacket
7
8
  from .packet_pool import END_PROC_NAME, PacketPool
8
- from .processor import Processor
9
+ from .processors.processor import Processor
9
10
  from .run_context import CtxT, RunContext
10
11
  from .typing.events import Event, ProcPacketOutputEvent, RunResultEvent
11
12
  from .typing.io import OutT
12
13
 
14
+ logger = logging.getLogger(__name__)
15
+
13
16
 
14
17
  class Runner(Generic[OutT, CtxT]):
15
18
  def __init__(
@@ -53,9 +56,18 @@ class Runner(Generic[OutT, CtxT]):
53
56
  **run_kwargs: Any,
54
57
  ) -> None:
55
58
  _in_packet, _chat_inputs = self._unpack_packet(packet)
59
+
60
+ logger.info(f"\n[Running processor {proc.name}]\n")
61
+
56
62
  out_packet = await proc.run(
57
63
  chat_inputs=_chat_inputs, in_packet=_in_packet, ctx=ctx, **run_kwargs
58
64
  )
65
+
66
+ logger.info(
67
+ f"\n[Finished running processor {proc.name}]\n"
68
+ f"Posting output packet to recipients {out_packet.recipients}\n"
69
+ )
70
+
59
71
  await pool.post(out_packet)
60
72
 
61
73
  async def _packet_handler_stream(
@@ -68,6 +80,8 @@ class Runner(Generic[OutT, CtxT]):
68
80
  ) -> None:
69
81
  _in_packet, _chat_inputs = self._unpack_packet(packet)
70
82
 
83
+ logger.info(f"\n[Running processor {proc.name}]\n")
84
+
71
85
  out_packet: Packet[Any] | None = None
72
86
  async for event in proc.run_stream(
73
87
  chat_inputs=_chat_inputs, in_packet=_in_packet, ctx=ctx, **run_kwargs
@@ -78,6 +92,11 @@ class Runner(Generic[OutT, CtxT]):
78
92
 
79
93
  assert out_packet is not None
80
94
 
95
+ logger.info(
96
+ f"\n[Finished running processor {proc.name}]\n"
97
+ f"Posting output packet to recipients {out_packet.recipients}\n"
98
+ )
99
+
81
100
  await pool.post(out_packet)
82
101
 
83
102
  async def run(
@@ -57,6 +57,10 @@ class Event(BaseModel, Generic[_T], frozen=True):
57
57
  call_id: str | None = None
58
58
  data: _T
59
59
 
60
class DummyEvent(Event[Any], frozen=True):
    # Frozen event fixed to the PAYLOAD_OUT type and PROC source; its data
    # defaults to None.
    # NOTE(review): the intended use (placeholder / no-op event?) is not
    # visible in this diff; confirm against call sites before relying on it.
    type: Literal[EventType.PAYLOAD_OUT] = EventType.PAYLOAD_OUT
    source: Literal[EventSourceType.PROC] = EventSourceType.PROC
    data: Any = None
60
64
 
61
65
  class CompletionEvent(Event[Completion], frozen=True):
62
66
  type: Literal[EventType.COMP] = EventType.COMP
@@ -5,7 +5,7 @@ from typing import Any, Generic, Protocol, TypeVar, cast, final
5
5
 
6
6
  from ..errors import WorkflowConstructionError
7
7
  from ..packet_pool import Packet
8
- from ..processor import Processor
8
+ from ..processors.base_processor import BaseProcessor
9
9
  from ..run_context import CtxT, RunContext
10
10
  from ..typing.events import Event, ProcPacketOutputEvent, WorkflowResultEvent
11
11
  from ..typing.io import InT, OutT, ProcName
@@ -16,7 +16,7 @@ logger = getLogger(__name__)
16
16
  _OutT_contra = TypeVar("_OutT_contra", contravariant=True)
17
17
 
18
18
 
19
- class ExitWorkflowLoopHandler(Protocol[_OutT_contra, CtxT]):
19
+ class WorkflowLoopTerminator(Protocol[_OutT_contra, CtxT]):
20
20
  def __call__(
21
21
  self,
22
22
  out_packet: Packet[_OutT_contra],
@@ -29,8 +29,8 @@ class LoopedWorkflow(WorkflowProcessor[InT, OutT, CtxT], Generic[InT, OutT, CtxT
29
29
  def __init__(
30
30
  self,
31
31
  name: ProcName,
32
- subprocs: Sequence[Processor[Any, Any, Any, CtxT]],
33
- exit_proc: Processor[Any, OutT, Any, CtxT],
32
+ subprocs: Sequence[BaseProcessor[Any, Any, Any, CtxT]],
33
+ exit_proc: BaseProcessor[Any, OutT, Any, CtxT],
34
34
  recipients: list[ProcName] | None = None,
35
35
  max_retries: int = 0,
36
36
  max_iterations: int = 10,
@@ -60,28 +60,30 @@ class LoopedWorkflow(WorkflowProcessor[InT, OutT, CtxT], Generic[InT, OutT, CtxT
60
60
 
61
61
  self._max_iterations = max_iterations
62
62
 
63
- self._exit_workflow_loop_impl: ExitWorkflowLoopHandler[OutT, CtxT] | None = None
63
+ self.workflow_loop_terminator: WorkflowLoopTerminator[OutT, CtxT] | None
64
+ if not hasattr(type(self), "workflow_loop_terminator"):
65
+ self.workflow_loop_terminator = None
64
66
 
65
67
  @property
66
68
  def max_iterations(self) -> int:
67
69
  return self._max_iterations
68
70
 
69
- def exit_workflow_loop(
70
- self, func: ExitWorkflowLoopHandler[OutT, CtxT]
71
- ) -> ExitWorkflowLoopHandler[OutT, CtxT]:
72
- self._exit_workflow_loop_impl = func
71
+ def add_workflow_loop_terminator(
72
+ self, func: WorkflowLoopTerminator[OutT, CtxT]
73
+ ) -> WorkflowLoopTerminator[OutT, CtxT]:
74
+ self.workflow_loop_terminator = func
73
75
 
74
76
  return func
75
77
 
76
- def _exit_workflow_loop(
78
+ def _terminate_workflow_loop(
77
79
  self,
78
80
  out_packet: Packet[OutT],
79
81
  *,
80
82
  ctx: RunContext[CtxT] | None = None,
81
83
  **kwargs: Any,
82
84
  ) -> bool:
83
- if self._exit_workflow_loop_impl:
84
- return self._exit_workflow_loop_impl(out_packet, ctx=ctx, **kwargs)
85
+ if self.workflow_loop_terminator:
86
+ return self.workflow_loop_terminator(out_packet, ctx=ctx, **kwargs)
85
87
 
86
88
  return False
87
89
 
@@ -96,14 +98,15 @@ class LoopedWorkflow(WorkflowProcessor[InT, OutT, CtxT], Generic[InT, OutT, CtxT
96
98
  forgetful: bool = False,
97
99
  ctx: RunContext[CtxT] | None = None,
98
100
  ) -> Packet[OutT]:
99
- call_id = self._generate_call_id(call_id)
100
-
101
101
  packet = in_packet
102
- num_iterations = 0
103
102
  exit_packet: Packet[OutT] | None = None
104
103
 
105
- while True:
104
+ for num_iterations in range(1, self._max_iterations + 1):
105
+ call_id = self._generate_call_id(call_id)
106
+
106
107
  for subproc in self.subprocs:
108
+ logger.info(f"\n[Running subprocessor {subproc.name}]\n")
109
+
107
110
  packet = await subproc.run(
108
111
  chat_inputs=chat_inputs,
109
112
  in_packet=packet,
@@ -113,12 +116,13 @@ class LoopedWorkflow(WorkflowProcessor[InT, OutT, CtxT], Generic[InT, OutT, CtxT
113
116
  ctx=ctx,
114
117
  )
115
118
 
119
+ logger.info(f"\n[Finished running subprocessor {subproc.name}]\n")
120
+
116
121
  if subproc is self._end_proc:
117
- num_iterations += 1
118
122
  exit_packet = cast("Packet[OutT]", packet)
119
- if self._exit_workflow_loop(exit_packet, ctx=ctx):
123
+ if self._terminate_workflow_loop(exit_packet, ctx=ctx):
120
124
  return exit_packet
121
- if num_iterations >= self._max_iterations:
125
+ if num_iterations == self._max_iterations:
122
126
  logger.info(
123
127
  f"Max iterations reached ({self._max_iterations}). "
124
128
  "Exiting loop."
@@ -128,8 +132,10 @@ class LoopedWorkflow(WorkflowProcessor[InT, OutT, CtxT], Generic[InT, OutT, CtxT
128
132
  chat_inputs = None
129
133
  in_args = None
130
134
 
135
+ raise RuntimeError("Looped workflow did not exit after max iterations.")
136
+
131
137
  @final
132
- async def run_stream( # type: ignore[override]
138
+ async def run_stream(
133
139
  self,
134
140
  chat_inputs: Any | None = None,
135
141
  *,
@@ -139,14 +145,15 @@ class LoopedWorkflow(WorkflowProcessor[InT, OutT, CtxT], Generic[InT, OutT, CtxT
139
145
  forgetful: bool = False,
140
146
  ctx: RunContext[CtxT] | None = None,
141
147
  ) -> AsyncIterator[Event[Any]]:
142
- call_id = self._generate_call_id(call_id)
143
-
144
148
  packet = in_packet
145
- num_iterations = 0
146
149
  exit_packet: Packet[OutT] | None = None
147
150
 
148
- while True:
151
+ for num_iterations in range(1, self._max_iterations + 1):
152
+ call_id = self._generate_call_id(call_id)
153
+
149
154
  for subproc in self.subprocs:
155
+ logger.info(f"\n[Running subprocessor {subproc.name}]\n")
156
+
150
157
  async for event in subproc.run_stream(
151
158
  chat_inputs=chat_inputs,
152
159
  in_packet=packet,
@@ -159,15 +166,16 @@ class LoopedWorkflow(WorkflowProcessor[InT, OutT, CtxT], Generic[InT, OutT, CtxT
159
166
  packet = event.data
160
167
  yield event
161
168
 
169
+ logger.info(f"\n[Finished running subprocessor {subproc.name}]\n")
170
+
162
171
  if subproc is self._end_proc:
163
- num_iterations += 1
164
172
  exit_packet = cast("Packet[OutT]", packet)
165
- if self._exit_workflow_loop(exit_packet, ctx=ctx):
173
+ if self._terminate_workflow_loop(exit_packet, ctx=ctx):
166
174
  yield WorkflowResultEvent(
167
175
  data=exit_packet, proc_name=self.name, call_id=call_id
168
176
  )
169
177
  return
170
- if num_iterations >= self._max_iterations:
178
+ if num_iterations == self._max_iterations:
171
179
  logger.info(
172
180
  f"Max iterations reached ({self._max_iterations}). "
173
181
  "Exiting loop."
@@ -1,21 +1,24 @@
1
+ import logging
1
2
  from collections.abc import AsyncIterator, Sequence
2
3
  from itertools import pairwise
3
4
  from typing import Any, Generic, cast, final
4
5
 
5
6
  from ..errors import WorkflowConstructionError
6
7
  from ..packet_pool import Packet
7
- from ..processor import Processor
8
+ from ..processors.base_processor import BaseProcessor
8
9
  from ..run_context import CtxT, RunContext
9
10
  from ..typing.events import Event, ProcPacketOutputEvent, WorkflowResultEvent
10
11
  from ..typing.io import InT, OutT, ProcName
11
12
  from .workflow_processor import WorkflowProcessor
12
13
 
14
+ logger = logging.getLogger(__name__)
15
+
13
16
 
14
17
  class SequentialWorkflow(WorkflowProcessor[InT, OutT, CtxT], Generic[InT, OutT, CtxT]):
15
18
  def __init__(
16
19
  self,
17
20
  name: ProcName,
18
- subprocs: Sequence[Processor[Any, Any, Any, CtxT]],
21
+ subprocs: Sequence[BaseProcessor[Any, Any, Any, CtxT]],
19
22
  recipients: list[ProcName] | None = None,
20
23
  max_retries: int = 0,
21
24
  ) -> None:
@@ -51,6 +54,8 @@ class SequentialWorkflow(WorkflowProcessor[InT, OutT, CtxT], Generic[InT, OutT,
51
54
 
52
55
  packet = in_packet
53
56
  for subproc in self.subprocs:
57
+ logger.info(f"\n[Running subprocessor {subproc.name}]\n")
58
+
54
59
  packet = await subproc.run(
55
60
  chat_inputs=chat_inputs,
56
61
  in_packet=packet,
@@ -62,10 +67,12 @@ class SequentialWorkflow(WorkflowProcessor[InT, OutT, CtxT], Generic[InT, OutT,
62
67
  chat_inputs = None
63
68
  in_args = None
64
69
 
70
+ logger.info(f"\n[Finished running subprocessor {subproc.name}]\n")
71
+
65
72
  return cast("Packet[OutT]", packet)
66
73
 
67
74
  @final
68
- async def run_stream( # type: ignore[override]
75
+ async def run_stream(
69
76
  self,
70
77
  chat_inputs: Any | None = None,
71
78
  *,
@@ -79,6 +86,8 @@ class SequentialWorkflow(WorkflowProcessor[InT, OutT, CtxT], Generic[InT, OutT,
79
86
 
80
87
  packet = in_packet
81
88
  for subproc in self.subprocs:
89
+ logger.info(f"\n[Running subprocessor {subproc.name}]\n")
90
+
82
91
  async for event in subproc.run_stream(
83
92
  chat_inputs=chat_inputs,
84
93
  in_packet=packet,
@@ -94,6 +103,8 @@ class SequentialWorkflow(WorkflowProcessor[InT, OutT, CtxT], Generic[InT, OutT,
94
103
  chat_inputs = None
95
104
  in_args = None
96
105
 
106
+ logger.info(f"\n[Finished running subprocessor {subproc.name}]\n")
107
+
97
108
  yield WorkflowResultEvent(
98
109
  data=cast("Packet[OutT]", packet), proc_name=self.name, call_id=call_id
99
110
  )