grasp_agents 0.5.5__py3-none-any.whl → 0.5.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- grasp_agents/__init__.py +5 -1
- grasp_agents/llm.py +5 -1
- grasp_agents/llm_agent.py +18 -7
- grasp_agents/packet_pool.py +6 -1
- grasp_agents/printer.py +7 -4
- grasp_agents/{processor.py → processors/base_processor.py} +89 -287
- grasp_agents/processors/parallel_processor.py +244 -0
- grasp_agents/processors/processor.py +161 -0
- grasp_agents/runner.py +20 -1
- grasp_agents/typing/events.py +4 -0
- grasp_agents/workflow/looped_workflow.py +35 -27
- grasp_agents/workflow/sequential_workflow.py +14 -3
- grasp_agents/workflow/workflow_processor.py +21 -15
- {grasp_agents-0.5.5.dist-info → grasp_agents-0.5.6.dist-info}/METADATA +1 -1
- {grasp_agents-0.5.5.dist-info → grasp_agents-0.5.6.dist-info}/RECORD +17 -15
- {grasp_agents-0.5.5.dist-info → grasp_agents-0.5.6.dist-info}/WHEEL +0 -0
- {grasp_agents-0.5.5.dist-info → grasp_agents-0.5.6.dist-info}/licenses/LICENSE.md +0 -0
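The most visible change for downstream code is the split of grasp_agents/processor.py into a processors package (base_processor.py, processor.py, parallel_processor.py), shown in the diffs below. A hedged sketch of the resulting import-path migration follows; it assumes code imported Processor by its module path (imports re-exported from the package root may be unaffected, since grasp_agents/__init__.py also changed in this release):

# grasp_agents 0.5.5
# from grasp_agents.processor import Processor

# grasp_agents 0.5.6
from grasp_agents.processors.processor import Processor
from grasp_agents.processors.parallel_processor import ParallelProcessor
from grasp_agents.processors.base_processor import BaseProcessor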
grasp_agents/processors/parallel_processor.py
ADDED
@@ -0,0 +1,244 @@
+import asyncio
+import logging
+from collections.abc import AsyncIterator, Sequence
+from typing import Any, ClassVar, Generic, cast
+
+
+from ..errors import PacketRoutingError
+from ..memory import MemT
+from ..packet import Packet
+from ..run_context import CtxT, RunContext
+from ..typing.events import (
+    Event, ProcPacketOutputEvent, ProcPayloadOutputEvent
+)
+from ..typing.io import InT, OutT
+from ..utils import stream_concurrent
+from .base_processor import BaseProcessor, with_retry, with_retry_stream
+
+logger = logging.getLogger(__name__)
+
+
+class ParallelProcessor(BaseProcessor[InT, OutT, MemT, CtxT], Generic[InT, OutT, MemT, CtxT]):
+    _generic_arg_to_instance_attr_map: ClassVar[dict[int, str]] = {
+        0: "_in_type",
+        1: "_out_type",
+    }
+
+    async def _process(
+        self,
+        chat_inputs: Any | None = None,
+        *,
+        in_args: InT | None = None,
+        memory: MemT,
+        call_id: str,
+        ctx: RunContext[CtxT] | None = None,
+    ) -> OutT:
+        return cast(OutT, in_args)
+
+    async def _process_stream(
+        self,
+        chat_inputs: Any | None = None,
+        *,
+        in_args: InT | None = None,
+        memory: MemT,
+        call_id: str,
+        ctx: RunContext[CtxT] | None = None,
+    ) -> AsyncIterator[Event[Any]]:
+        output = cast(OutT, in_args)
+        yield ProcPayloadOutputEvent(data=output, proc_name=self.name, call_id=call_id)
+
+    def _validate_parallel_recipients(
+        self, out_packets: Sequence[Packet[OutT]], call_id: str
+    ) -> None:
+        recipient_sets = [set(p.recipients or []) for p in out_packets]
+        same_recipients = all(rs == recipient_sets[0] for rs in recipient_sets)
+        if not same_recipients:
+            raise PacketRoutingError(
+                proc_name=self.name,
+                call_id=call_id,
+                message="Parallel runs must return the same recipients "
+                f"[proc_name={self.name}; call_id={call_id}]",
+            )
+
+    @with_retry
+    async def _run_single(
+        self,
+        chat_inputs: Any | None = None,
+        *,
+        in_args: InT | None = None,
+        forgetful: bool = False,
+        call_id: str,
+        ctx: RunContext[CtxT] | None = None,
+    ) -> Packet[OutT]:
+        memory = self.memory.model_copy(deep=True) if forgetful else self.memory
+
+        output = await self._process(
+            chat_inputs=chat_inputs,
+            in_args=in_args,
+            memory=memory,
+            call_id=call_id,
+            ctx=ctx,
+        )
+        val_output = self._validate_output(output, call_id=call_id)
+
+        recipients = self._select_recipients(output=val_output, ctx=ctx)
+        self._validate_recipients(recipients, call_id=call_id)
+
+        return Packet(payloads=[val_output], sender=self.name, recipients=recipients)
+
+
+    async def _run_parallel(
+        self, in_args: list[InT], call_id: str, ctx: RunContext[CtxT] | None = None
+    ) -> Packet[OutT]:
+        tasks = [
+            self._run_single(
+                in_args=inp, forgetful=True, call_id=f"{call_id}/{idx}", ctx=ctx
+            )
+            for idx, inp in enumerate(in_args)
+        ]
+        out_packets = await asyncio.gather(*tasks)
+        self._validate_parallel_recipients(out_packets, call_id=call_id)
+
+        return Packet(
+            payloads=[out_packet.payloads[0] for out_packet in out_packets],
+            sender=self.name,
+            recipients=out_packets[0].recipients,
+        )
+
+    async def run(
+        self,
+        chat_inputs: Any | None = None,
+        *,
+        in_packet: Packet[InT] | None = None,
+        in_args: InT | list[InT] | None = None,
+        forgetful: bool = False,
+        call_id: str | None = None,
+        ctx: RunContext[CtxT] | None = None,
+    ) -> Packet[OutT]:
+        call_id = self._generate_call_id(call_id)
+
+        val_in_args = self._validate_inputs(
+            call_id=call_id,
+            chat_inputs=chat_inputs,
+            in_packet=in_packet,
+            in_args=in_args,
+        )
+
+        if val_in_args and len(val_in_args) > 1:
+            return await self._run_parallel(in_args=val_in_args, call_id=call_id, ctx=ctx)
+
+        return await self._run_single(
+            chat_inputs=chat_inputs,
+            in_args=val_in_args[0] if val_in_args else None,
+            forgetful=forgetful,
+            call_id=call_id,
+            ctx=ctx,
+        )
+
+    @with_retry_stream
+    async def _run_single_stream(
+        self,
+        chat_inputs: Any | None = None,
+        *,
+        in_args: InT | None = None,
+        forgetful: bool = False,
+        call_id: str,
+        ctx: RunContext[CtxT] | None = None,
+    ) -> AsyncIterator[Event[Any]]:
+        memory = self.memory.model_copy(deep=True) if forgetful else self.memory
+
+        output: OutT | None = None
+        async for event in self._process_stream(
+            chat_inputs=chat_inputs,
+            in_args=in_args,
+            memory=memory,
+            call_id=call_id,
+            ctx=ctx,
+        ):
+            if isinstance(event, ProcPayloadOutputEvent):
+                output = event.data
+            yield event
+
+        assert output is not None
+
+        val_output = self._validate_output(output, call_id=call_id)
+
+        recipients = self._select_recipients(output=val_output, ctx=ctx)
+        self._validate_recipients(recipients, call_id=call_id)
+
+        out_packet = Packet[OutT](
+            payloads=[val_output], sender=self.name, recipients=recipients
+        )
+
+        yield ProcPacketOutputEvent(
+            data=out_packet, proc_name=self.name, call_id=call_id
+        )
+
+    async def _run_parallel_stream(
+        self,
+        in_args: list[InT],
+        call_id: str,
+        ctx: RunContext[CtxT] | None = None,
+    ) -> AsyncIterator[Event[Any]]:
+        streams = [
+            self._run_single_stream(
+                in_args=inp, forgetful=True, call_id=f"{call_id}/{idx}", ctx=ctx
+            )
+            for idx, inp in enumerate(in_args)
+        ]
+
+        out_packets_map: dict[int, Packet[OutT]] = {}
+        async for idx, event in stream_concurrent(streams):
+            if isinstance(event, ProcPacketOutputEvent):
+                out_packets_map[idx] = event.data
+            else:
+                yield event
+
+        self._validate_parallel_recipients(
+            out_packets=list(out_packets_map.values()), call_id=call_id
+        )
+
+        out_packet = Packet(
+            payloads=[
+                out_packet.payloads[0]
+                for _, out_packet in sorted(out_packets_map.items())
+            ],
+            sender=self.name,
+            recipients=out_packets_map[0].recipients,
+        )
+
+        yield ProcPacketOutputEvent(
+            data=out_packet, proc_name=self.name, call_id=call_id
+        )
+
+    async def run_stream(
+        self,
+        chat_inputs: Any | None = None,
+        *,
+        in_packet: Packet[InT] | None = None,
+        in_args: InT | list[InT] | None = None,
+        forgetful: bool = False,
+        call_id: str | None = None,
+        ctx: RunContext[CtxT] | None = None,
+    ) -> AsyncIterator[Event[Any]]:
+        call_id = self._generate_call_id(call_id)
+
+        val_in_args = self._validate_inputs(
+            call_id=call_id,
+            chat_inputs=chat_inputs,
+            in_packet=in_packet,
+            in_args=in_args,
+        )
+
+        if val_in_args and len(val_in_args) > 1:
+            stream = self._run_parallel_stream(in_args=val_in_args, call_id=call_id, ctx=ctx)
+        else:
+            stream = self._run_single_stream(
+                chat_inputs=chat_inputs,
+                in_args=val_in_args[0] if val_in_args else None,
+                forgetful=forgetful,
+                call_id=call_id,
+                ctx=ctx,
+            )
+        async for event in stream:
+            yield event
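ParallelProcessor.run dispatches on the validated input list: a single input goes through _run_single against the shared memory, while multiple inputs fan out into forgetful per-input runs with indexed call IDs (call_id/0, call_id/1, ...) that are gathered and merged into one packet. A standalone sketch of that fan-out pattern, not using the grasp_agents API itself (process_one is a hypothetical stand-in for one forgetful run):

import asyncio

async def process_one(item: int, call_id: str) -> int:
    # Stand-in for a single forgetful processor run on a deep-copied memory.
    return item * 2

async def run_parallel(items: list[int], call_id: str) -> list[int]:
    # One task per input, each tagged with an indexed call ID such as "demo/0".
    tasks = [
        process_one(item, call_id=f"{call_id}/{idx}")
        for idx, item in enumerate(items)
    ]
    # gather preserves input order, mirroring how payloads are merged above.
    return list(await asyncio.gather(*tasks))

print(asyncio.run(run_parallel([1, 2, 3], call_id="demo")))  # [2, 4, 6]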
grasp_agents/processors/processor.py
ADDED
@@ -0,0 +1,161 @@
+import logging
+from collections.abc import AsyncIterator
+from typing import Any, ClassVar, Generic, cast
+
+from ..memory import MemT
+from ..packet import Packet
+from ..run_context import CtxT, RunContext
+from ..typing.events import Event, ProcPacketOutputEvent, ProcPayloadOutputEvent
+from ..typing.io import InT, OutT, ProcName
+from .base_processor import BaseProcessor, with_retry, with_retry_stream
+
+logger = logging.getLogger(__name__)
+
+
+class Processor(BaseProcessor[InT, OutT, MemT, CtxT], Generic[InT, OutT, MemT, CtxT]):
+    _generic_arg_to_instance_attr_map: ClassVar[dict[int, str]] = {
+        0: "_in_type",
+        1: "_out_type",
+    }
+
+    async def _process(
+        self,
+        chat_inputs: Any | None = None,
+        *,
+        in_args: list[InT] | None = None,
+        memory: MemT,
+        call_id: str,
+        ctx: RunContext[CtxT] | None = None,
+    ) -> list[OutT]:
+        return cast("list[OutT]", in_args)
+
+    async def _process_stream(
+        self,
+        chat_inputs: Any | None = None,
+        *,
+        in_args: list[InT] | None = None,
+        memory: MemT,
+        call_id: str,
+        ctx: RunContext[CtxT] | None = None,
+    ) -> AsyncIterator[Event[Any]]:
+        outputs = await self._process(
+            chat_inputs=chat_inputs,
+            in_args=in_args,
+            memory=memory,
+            call_id=call_id,
+            ctx=ctx,
+        )
+        for output in outputs:
+            yield ProcPayloadOutputEvent(
+                data=output, proc_name=self.name, call_id=call_id
+            )
+
+    def _preprocess(
+        self,
+        chat_inputs: Any | None = None,
+        *,
+        in_packet: Packet[InT] | None = None,
+        in_args: InT | list[InT] | None = None,
+        forgetful: bool = False,
+        call_id: str | None = None,
+        ctx: RunContext[CtxT] | None = None,
+    ) -> tuple[list[InT] | None, MemT, str]:
+        call_id = self._generate_call_id(call_id)
+
+        val_in_args = self._validate_inputs(
+            call_id=call_id,
+            chat_inputs=chat_inputs,
+            in_packet=in_packet,
+            in_args=in_args,
+        )
+
+        memory = self.memory.model_copy(deep=True) if forgetful else self.memory
+
+        return val_in_args, memory, call_id
+
+    def _postprocess(
+        self, outputs: list[OutT], call_id: str, ctx: RunContext[CtxT] | None = None
+    ) -> Packet[OutT]:
+        payloads: list[OutT] = []
+        routing: dict[int, list[ProcName] | None] = {}
+        for idx, output in enumerate(outputs):
+            val_output = self._validate_output(output, call_id=call_id)
+            recipients = self._select_recipients(output=val_output, ctx=ctx)
+            self._validate_recipients(recipients, call_id=call_id)
+
+            payloads.append(val_output)
+            routing[idx] = recipients
+
+        recipient_sets = [set(r or []) for r in routing.values()]
+        if all(r == recipient_sets[0] for r in recipient_sets):
+            recipients = routing[0]
+        else:
+            recipients = routing
+
+        return Packet(payloads=payloads, sender=self.name, recipients=recipients)  # type: ignore[return-value]
+
+    @with_retry
+    async def run(
+        self,
+        chat_inputs: Any | None = None,
+        *,
+        in_packet: Packet[InT] | None = None,
+        in_args: InT | list[InT] | None = None,
+        forgetful: bool = False,
+        call_id: str | None = None,
+        ctx: RunContext[CtxT] | None = None,
+    ) -> Packet[OutT]:
+        val_in_args, memory, call_id = self._preprocess(
+            chat_inputs=chat_inputs,
+            in_packet=in_packet,
+            in_args=in_args,
+            forgetful=forgetful,
+            call_id=call_id,
+            ctx=ctx,
+        )
+        outputs = await self._process(
+            chat_inputs=chat_inputs,
+            in_args=val_in_args,
+            memory=memory,
+            call_id=call_id,
+            ctx=ctx,
+        )
+
+        return self._postprocess(outputs=outputs, call_id=call_id, ctx=ctx)
+
+    @with_retry_stream
+    async def run_stream(
+        self,
+        chat_inputs: Any | None = None,
+        *,
+        in_packet: Packet[InT] | None = None,
+        in_args: InT | list[InT] | None = None,
+        forgetful: bool = False,
+        call_id: str | None = None,
+        ctx: RunContext[CtxT] | None = None,
+    ) -> AsyncIterator[Event[Any]]:
+        val_in_args, memory, call_id = self._preprocess(
+            chat_inputs=chat_inputs,
+            in_packet=in_packet,
+            in_args=in_args,
+            forgetful=forgetful,
+            call_id=call_id,
+            ctx=ctx,
+        )
+        outputs: list[OutT] = []
+        async for event in self._process_stream(
+            chat_inputs=chat_inputs,
+            in_args=val_in_args,
+            memory=memory,
+            call_id=call_id,
+            ctx=ctx,
+        ):
+            if isinstance(event, ProcPayloadOutputEvent):
+                outputs.append(event.data)
+            yield event
+
+        out_packet = self._postprocess(outputs=outputs, call_id=call_id, ctx=ctx)
+
+        yield ProcPacketOutputEvent(
+            data=out_packet, proc_name=self.name, call_id=call_id
+        )
grasp_agents/runner.py
CHANGED
@@ -1,3 +1,4 @@
+import logging
 from collections.abc import AsyncIterator, Sequence
 from functools import partial
 from typing import Any, Generic
@@ -5,11 +6,13 @@ from typing import Any, Generic
 from .errors import RunnerError
 from .packet import Packet, StartPacket
 from .packet_pool import END_PROC_NAME, PacketPool
-from .processor import Processor
+from .processors.processor import Processor
 from .run_context import CtxT, RunContext
 from .typing.events import Event, ProcPacketOutputEvent, RunResultEvent
 from .typing.io import OutT
 
+logger = logging.getLogger(__name__)
+
 
 class Runner(Generic[OutT, CtxT]):
     def __init__(
@@ -53,9 +56,18 @@ class Runner(Generic[OutT, CtxT]):
         **run_kwargs: Any,
     ) -> None:
        _in_packet, _chat_inputs = self._unpack_packet(packet)
+
+        logger.info(f"\n[Running processor {proc.name}]\n")
+
        out_packet = await proc.run(
            chat_inputs=_chat_inputs, in_packet=_in_packet, ctx=ctx, **run_kwargs
        )
+
+        logger.info(
+            f"\n[Finished running processor {proc.name}]\n"
+            f"Posting output packet to recipients {out_packet.recipients}\n"
+        )
+
        await pool.post(out_packet)
 
     async def _packet_handler_stream(
@@ -68,6 +80,8 @@ class Runner(Generic[OutT, CtxT]):
     ) -> None:
        _in_packet, _chat_inputs = self._unpack_packet(packet)
 
+        logger.info(f"\n[Running processor {proc.name}]\n")
+
        out_packet: Packet[Any] | None = None
        async for event in proc.run_stream(
            chat_inputs=_chat_inputs, in_packet=_in_packet, ctx=ctx, **run_kwargs
@@ -78,6 +92,11 @@ class Runner(Generic[OutT, CtxT]):
 
        assert out_packet is not None
 
+        logger.info(
+            f"\n[Finished running processor {proc.name}]\n"
+            f"Posting output packet to recipients {out_packet.recipients}\n"
+        )
+
        await pool.post(out_packet)
 
     async def run(
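The new Runner messages are ordinary INFO-level records emitted through the standard logging module (module loggers under the grasp_agents namespace), so they only appear if the host application configures logging. A minimal sketch:

import logging

# Root handler at WARNING; INFO enabled only for the grasp_agents logger tree.
logging.basicConfig(level=logging.WARNING)
logging.getLogger("grasp_agents").setLevel(logging.INFO)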
grasp_agents/typing/events.py
CHANGED
@@ -57,6 +57,10 @@ class Event(BaseModel, Generic[_T], frozen=True):
     call_id: str | None = None
     data: _T
 
+class DummyEvent(Event[Any], frozen=True):
+    type: Literal[EventType.PAYLOAD_OUT] = EventType.PAYLOAD_OUT
+    source: Literal[EventSourceType.PROC] = EventSourceType.PROC
+    data: Any = None
 
 class CompletionEvent(Event[Completion], frozen=True):
     type: Literal[EventType.COMP] = EventType.COMP
grasp_agents/workflow/looped_workflow.py
CHANGED
@@ -5,7 +5,7 @@ from typing import Any, Generic, Protocol, TypeVar, cast, final
 
 from ..errors import WorkflowConstructionError
 from ..packet_pool import Packet
-from ..
+from ..processors.base_processor import BaseProcessor
 from ..run_context import CtxT, RunContext
 from ..typing.events import Event, ProcPacketOutputEvent, WorkflowResultEvent
 from ..typing.io import InT, OutT, ProcName
@@ -16,7 +16,7 @@ logger = getLogger(__name__)
 _OutT_contra = TypeVar("_OutT_contra", contravariant=True)
 
 
-class
+class WorkflowLoopTerminator(Protocol[_OutT_contra, CtxT]):
     def __call__(
         self,
         out_packet: Packet[_OutT_contra],
@@ -29,8 +29,8 @@ class LoopedWorkflow(WorkflowProcessor[InT, OutT, CtxT], Generic[InT, OutT, CtxT
     def __init__(
         self,
         name: ProcName,
-        subprocs: Sequence[
-        exit_proc:
+        subprocs: Sequence[BaseProcessor[Any, Any, Any, CtxT]],
+        exit_proc: BaseProcessor[Any, OutT, Any, CtxT],
         recipients: list[ProcName] | None = None,
         max_retries: int = 0,
         max_iterations: int = 10,
@@ -60,28 +60,30 @@ class LoopedWorkflow(WorkflowProcessor[InT, OutT, CtxT], Generic[InT, OutT, CtxT
 
         self._max_iterations = max_iterations
 
-        self.
+        self.workflow_loop_terminator: WorkflowLoopTerminator[OutT, CtxT] | None
+        if not hasattr(type(self), "workflow_loop_terminator"):
+            self.workflow_loop_terminator = None
 
     @property
     def max_iterations(self) -> int:
         return self._max_iterations
 
-    def
-        self, func:
-    ) ->
-        self.
+    def add_workflow_loop_terminator(
+        self, func: WorkflowLoopTerminator[OutT, CtxT]
+    ) -> WorkflowLoopTerminator[OutT, CtxT]:
+        self.workflow_loop_terminator = func
 
         return func
 
-    def
+    def _terminate_workflow_loop(
         self,
         out_packet: Packet[OutT],
         *,
         ctx: RunContext[CtxT] | None = None,
         **kwargs: Any,
     ) -> bool:
-        if self.
-        return self.
+        if self.workflow_loop_terminator:
+            return self.workflow_loop_terminator(out_packet, ctx=ctx, **kwargs)
 
         return False
 
@@ -96,14 +98,15 @@ class LoopedWorkflow(WorkflowProcessor[InT, OutT, CtxT], Generic[InT, OutT, CtxT
         forgetful: bool = False,
         ctx: RunContext[CtxT] | None = None,
     ) -> Packet[OutT]:
-        call_id = self._generate_call_id(call_id)
-
         packet = in_packet
-        num_iterations = 0
         exit_packet: Packet[OutT] | None = None
 
-
+        for num_iterations in range(1, self._max_iterations + 1):
+            call_id = self._generate_call_id(call_id)
+
             for subproc in self.subprocs:
+                logger.info(f"\n[Running subprocessor {subproc.name}]\n")
+
                 packet = await subproc.run(
                     chat_inputs=chat_inputs,
                     in_packet=packet,
@@ -113,12 +116,13 @@ class LoopedWorkflow(WorkflowProcessor[InT, OutT, CtxT], Generic[InT, OutT, CtxT
                     ctx=ctx,
                 )
 
+                logger.info(f"\n[Finished running subprocessor {subproc.name}]\n")
+
                 if subproc is self._end_proc:
-                    num_iterations += 1
                     exit_packet = cast("Packet[OutT]", packet)
-                    if self.
+                    if self._terminate_workflow_loop(exit_packet, ctx=ctx):
                         return exit_packet
-                    if num_iterations
+                    if num_iterations == self._max_iterations:
                         logger.info(
                             f"Max iterations reached ({self._max_iterations}). "
                             "Exiting loop."
@@ -128,8 +132,10 @@ class LoopedWorkflow(WorkflowProcessor[InT, OutT, CtxT], Generic[InT, OutT, CtxT
                 chat_inputs = None
                 in_args = None
 
+        raise RuntimeError("Looped workflow did not exit after max iterations.")
+
     @final
-    async def run_stream(
+    async def run_stream(
         self,
         chat_inputs: Any | None = None,
         *,
@@ -139,14 +145,15 @@ class LoopedWorkflow(WorkflowProcessor[InT, OutT, CtxT], Generic[InT, OutT, CtxT
         forgetful: bool = False,
         ctx: RunContext[CtxT] | None = None,
     ) -> AsyncIterator[Event[Any]]:
-        call_id = self._generate_call_id(call_id)
-
         packet = in_packet
-        num_iterations = 0
         exit_packet: Packet[OutT] | None = None
 
-
+        for num_iterations in range(1, self._max_iterations + 1):
+            call_id = self._generate_call_id(call_id)
+
             for subproc in self.subprocs:
+                logger.info(f"\n[Running subprocessor {subproc.name}]\n")
+
                 async for event in subproc.run_stream(
                     chat_inputs=chat_inputs,
                     in_packet=packet,
@@ -159,15 +166,16 @@ class LoopedWorkflow(WorkflowProcessor[InT, OutT, CtxT], Generic[InT, OutT, CtxT
                         packet = event.data
                     yield event
 
+                logger.info(f"\n[Finished running subprocessor {subproc.name}]\n")
+
                 if subproc is self._end_proc:
-                    num_iterations += 1
                     exit_packet = cast("Packet[OutT]", packet)
-                    if self.
+                    if self._terminate_workflow_loop(exit_packet, ctx=ctx):
                         yield WorkflowResultEvent(
                             data=exit_packet, proc_name=self.name, call_id=call_id
                         )
                         return
-                    if num_iterations
+                    if num_iterations == self._max_iterations:
                         logger.info(
                             f"Max iterations reached ({self._max_iterations}). "
                             "Exiting loop."
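In LoopedWorkflow, the loop-termination hook is now a public workflow_loop_terminator attribute registered through add_workflow_loop_terminator, and the iteration count comes from the new for-range loop rather than a manual counter. A hedged sketch of registering a terminator, based only on the signatures visible in this diff; draft and critic stand for existing BaseProcessor instances, and any constructor detail not shown here is an assumption:

from grasp_agents.workflow.looped_workflow import LoopedWorkflow

# draft and critic are hypothetical, pre-built subprocessors (BaseProcessor instances).
loop = LoopedWorkflow(
    name="revise_loop",
    subprocs=[draft, critic],
    exit_proc=critic,
    max_iterations=5,
)

@loop.add_workflow_loop_terminator
def stop_when_approved(out_packet, *, ctx=None, **kwargs) -> bool:
    # Returning True ends the loop early; otherwise it exits at max_iterations.
    return all(getattr(p, "approved", False) for p in out_packet.payloads)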
grasp_agents/workflow/sequential_workflow.py
CHANGED
@@ -1,21 +1,24 @@
+import logging
 from collections.abc import AsyncIterator, Sequence
 from itertools import pairwise
 from typing import Any, Generic, cast, final
 
 from ..errors import WorkflowConstructionError
 from ..packet_pool import Packet
-from ..
+from ..processors.base_processor import BaseProcessor
 from ..run_context import CtxT, RunContext
 from ..typing.events import Event, ProcPacketOutputEvent, WorkflowResultEvent
 from ..typing.io import InT, OutT, ProcName
 from .workflow_processor import WorkflowProcessor
 
+logger = logging.getLogger(__name__)
+
 
 class SequentialWorkflow(WorkflowProcessor[InT, OutT, CtxT], Generic[InT, OutT, CtxT]):
     def __init__(
         self,
         name: ProcName,
-        subprocs: Sequence[
+        subprocs: Sequence[BaseProcessor[Any, Any, Any, CtxT]],
         recipients: list[ProcName] | None = None,
         max_retries: int = 0,
     ) -> None:
@@ -51,6 +54,8 @@ class SequentialWorkflow(WorkflowProcessor[InT, OutT, CtxT], Generic[InT, OutT,
 
         packet = in_packet
         for subproc in self.subprocs:
+            logger.info(f"\n[Running subprocessor {subproc.name}]\n")
+
             packet = await subproc.run(
                 chat_inputs=chat_inputs,
                 in_packet=packet,
@@ -62,10 +67,12 @@ class SequentialWorkflow(WorkflowProcessor[InT, OutT, CtxT], Generic[InT, OutT,
             chat_inputs = None
             in_args = None
 
+            logger.info(f"\n[Finished running subprocessor {subproc.name}]\n")
+
         return cast("Packet[OutT]", packet)
 
     @final
-    async def run_stream(
+    async def run_stream(
         self,
         chat_inputs: Any | None = None,
         *,
@@ -79,6 +86,8 @@ class SequentialWorkflow(WorkflowProcessor[InT, OutT, CtxT], Generic[InT, OutT,
 
         packet = in_packet
         for subproc in self.subprocs:
+            logger.info(f"\n[Running subprocessor {subproc.name}]\n")
+
             async for event in subproc.run_stream(
                 chat_inputs=chat_inputs,
                 in_packet=packet,
@@ -94,6 +103,8 @@ class SequentialWorkflow(WorkflowProcessor[InT, OutT, CtxT], Generic[InT, OutT,
             chat_inputs = None
             in_args = None
 
+            logger.info(f"\n[Finished running subprocessor {subproc.name}]\n")
+
         yield WorkflowResultEvent(
             data=cast("Packet[OutT]", packet), proc_name=self.name, call_id=call_id
         )