grasp_agents 0.5.2__py3-none-any.whl → 0.5.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- grasp_agents/__init__.py +3 -4
- grasp_agents/errors.py +80 -18
- grasp_agents/llm_agent.py +36 -58
- grasp_agents/llm_policy_executor.py +2 -2
- grasp_agents/packet.py +23 -4
- grasp_agents/packet_pool.py +117 -50
- grasp_agents/printer.py +8 -4
- grasp_agents/processor.py +208 -163
- grasp_agents/prompt_builder.py +80 -105
- grasp_agents/run_context.py +1 -9
- grasp_agents/runner.py +110 -21
- grasp_agents/typing/events.py +8 -4
- grasp_agents/typing/io.py +1 -1
- grasp_agents/workflow/looped_workflow.py +47 -52
- grasp_agents/workflow/sequential_workflow.py +6 -10
- grasp_agents/workflow/workflow_processor.py +7 -15
- {grasp_agents-0.5.2.dist-info → grasp_agents-0.5.4.dist-info}/METADATA +1 -1
- {grasp_agents-0.5.2.dist-info → grasp_agents-0.5.4.dist-info}/RECORD +20 -21
- grasp_agents/comm_processor.py +0 -214
- {grasp_agents-0.5.2.dist-info → grasp_agents-0.5.4.dist-info}/WHEEL +0 -0
- {grasp_agents-0.5.2.dist-info → grasp_agents-0.5.4.dist-info}/licenses/LICENSE.md +0 -0
grasp_agents/prompt_builder.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1
1
|
import json
|
2
|
-
from collections.abc import
|
2
|
+
from collections.abc import Sequence
|
3
3
|
from typing import ClassVar, Generic, Protocol, TypeAlias, TypeVar
|
4
4
|
|
5
5
|
from pydantic import BaseModel, TypeAdapter
|
@@ -26,9 +26,8 @@ class MakeSystemPromptHandler(Protocol[CtxT]):
|
|
26
26
|
class MakeInputContentHandler(Protocol[_InT_contra, CtxT]):
|
27
27
|
def __call__(
|
28
28
|
self,
|
29
|
-
*,
|
30
29
|
in_args: _InT_contra | None,
|
31
|
-
|
30
|
+
*,
|
32
31
|
ctx: RunContext[CtxT] | None,
|
33
32
|
) -> Content: ...
|
34
33
|
|
@@ -42,132 +41,129 @@ class PromptBuilder(AutoInstanceAttributesMixin, Generic[InT, CtxT]):
|
|
42
41
|
def __init__(
|
43
42
|
self,
|
44
43
|
agent_name: str,
|
45
|
-
|
46
|
-
|
44
|
+
sys_prompt: LLMPrompt | None,
|
45
|
+
in_prompt: LLMPrompt | None,
|
47
46
|
sys_args_schema: type[LLMPromptArgs] | None = None,
|
48
|
-
usr_args_schema: type[LLMPromptArgs] | None = None,
|
49
47
|
):
|
50
48
|
self._in_type: type[InT]
|
51
49
|
super().__init__()
|
52
50
|
|
53
51
|
self._agent_name = agent_name
|
54
|
-
self.
|
55
|
-
self.
|
52
|
+
self.sys_prompt = sys_prompt
|
53
|
+
self.in_prompt = in_prompt
|
56
54
|
self.sys_args_schema = sys_args_schema
|
57
|
-
self.usr_args_schema = usr_args_schema
|
58
55
|
self.make_system_prompt_impl: MakeSystemPromptHandler[CtxT] | None = None
|
59
56
|
self.make_input_content_impl: MakeInputContentHandler[InT, CtxT] | None = None
|
60
57
|
|
61
58
|
self._in_args_type_adapter: TypeAdapter[InT] = TypeAdapter(self._in_type)
|
62
59
|
|
63
|
-
def
|
64
|
-
self, sys_args: LLMPromptArgs | None
|
65
|
-
) ->
|
66
|
-
if self.sys_prompt_template is None:
|
67
|
-
return None
|
68
|
-
|
60
|
+
def _validate_sys_args(
|
61
|
+
self, sys_args: LLMPromptArgs | None
|
62
|
+
) -> LLMPromptArgs | None:
|
69
63
|
val_sys_args = sys_args
|
70
64
|
if sys_args is not None:
|
71
65
|
if self.sys_args_schema is not None:
|
72
66
|
val_sys_args = self.sys_args_schema.model_validate(sys_args)
|
73
67
|
else:
|
74
68
|
raise SystemPromptBuilderError(
|
75
|
-
|
76
|
-
"
|
69
|
+
proc_name=self._agent_name,
|
70
|
+
message="System prompt template and arguments is set, "
|
71
|
+
"but system arguments schema is not provided "
|
72
|
+
f"[agent_name={self._agent_name}]",
|
77
73
|
)
|
78
74
|
|
75
|
+
return val_sys_args
|
76
|
+
|
77
|
+
def make_system_prompt(
|
78
|
+
self,
|
79
|
+
sys_args: LLMPromptArgs | None = None,
|
80
|
+
ctx: RunContext[CtxT] | None = None,
|
81
|
+
) -> str | None:
|
82
|
+
if self.sys_prompt is None:
|
83
|
+
return None
|
84
|
+
|
85
|
+
val_sys_args = self._validate_sys_args(sys_args=sys_args)
|
86
|
+
|
79
87
|
if self.make_system_prompt_impl:
|
80
88
|
return self.make_system_prompt_impl(sys_args=val_sys_args, ctx=ctx)
|
81
89
|
|
82
90
|
sys_args_dict = val_sys_args.model_dump() if val_sys_args else {}
|
83
91
|
|
84
|
-
return self.
|
92
|
+
return self.sys_prompt.format(**sys_args_dict)
|
93
|
+
|
94
|
+
def _validate_input_args(self, in_args: InT) -> InT:
|
95
|
+
val_in_args = self._in_args_type_adapter.validate_python(in_args)
|
96
|
+
if isinstance(val_in_args, BaseModel):
|
97
|
+
has_image = self._has_image_data(val_in_args)
|
98
|
+
if has_image and self.in_prompt is None:
|
99
|
+
raise InputPromptBuilderError(
|
100
|
+
proc_name=self._agent_name,
|
101
|
+
message="BaseModel input arguments contain ImageData, "
|
102
|
+
"but input prompt template is not set "
|
103
|
+
f"[agent_name={self._agent_name}]. Cannot format input arguments.",
|
104
|
+
)
|
105
|
+
elif self.in_prompt is not None:
|
106
|
+
raise InputPromptBuilderError(
|
107
|
+
proc_name=self._agent_name,
|
108
|
+
message="Cannot use the input prompt template with "
|
109
|
+
f"non-BaseModel input arguments [agent_name={self._agent_name}]",
|
110
|
+
)
|
111
|
+
|
112
|
+
return val_in_args
|
85
113
|
|
86
114
|
def make_input_content(
|
87
115
|
self,
|
88
|
-
|
89
|
-
in_args: InT | None,
|
90
|
-
usr_args: LLMPromptArgs | None,
|
116
|
+
in_args: InT | None = None,
|
91
117
|
ctx: RunContext[CtxT] | None = None,
|
92
118
|
) -> Content:
|
93
|
-
val_in_args
|
94
|
-
|
95
|
-
|
119
|
+
val_in_args = in_args
|
120
|
+
if in_args is not None:
|
121
|
+
val_in_args = self._validate_input_args(in_args=in_args)
|
96
122
|
|
97
123
|
if self.make_input_content_impl:
|
98
|
-
return self.make_input_content_impl(
|
99
|
-
in_args=val_in_args, usr_args=val_usr_args, ctx=ctx
|
100
|
-
)
|
101
|
-
|
102
|
-
combined_args = self._combine_args(in_args=val_in_args, usr_args=val_usr_args)
|
103
|
-
if isinstance(combined_args, str):
|
104
|
-
return Content.from_text(combined_args)
|
124
|
+
return self.make_input_content_impl(in_args=val_in_args, ctx=ctx)
|
105
125
|
|
106
|
-
if
|
107
|
-
|
108
|
-
self.
|
126
|
+
if val_in_args is None:
|
127
|
+
raise InputPromptBuilderError(
|
128
|
+
proc_name=self._agent_name,
|
129
|
+
message="Input arguments are not provided, "
|
130
|
+
f"but input content is required [agent_name={self._agent_name}]",
|
109
131
|
)
|
110
132
|
|
111
|
-
|
133
|
+
if issubclass(self._in_type, BaseModel) and isinstance(val_in_args, BaseModel):
|
134
|
+
val_in_args_map = self._format_pydantic_prompt_args(val_in_args)
|
135
|
+
if self.in_prompt is not None:
|
136
|
+
return Content.from_formatted_prompt(self.in_prompt, **val_in_args_map)
|
137
|
+
return Content.from_text(json.dumps(val_in_args_map, indent=2))
|
138
|
+
|
139
|
+
fmt_in_args = self._in_args_type_adapter.dump_json(
|
140
|
+
val_in_args, indent=2, warnings="error"
|
141
|
+
).decode("utf-8")
|
142
|
+
return Content.from_text(fmt_in_args)
|
112
143
|
|
113
144
|
def make_input_message(
|
114
145
|
self,
|
115
146
|
chat_inputs: LLMPrompt | Sequence[str | ImageData] | None = None,
|
116
147
|
in_args: InT | None = None,
|
117
|
-
usr_args: LLMPromptArgs | None = None,
|
118
148
|
ctx: RunContext[CtxT] | None = None,
|
119
149
|
) -> UserMessage | None:
|
120
|
-
if chat_inputs is None and in_args is
|
121
|
-
|
150
|
+
if chat_inputs is not None and in_args is not None:
|
151
|
+
raise InputPromptBuilderError(
|
152
|
+
proc_name=self._agent_name,
|
153
|
+
message="Cannot use both chat inputs and input arguments "
|
154
|
+
f"at the same time [agent_name={self._agent_name}]",
|
155
|
+
)
|
122
156
|
|
123
157
|
if chat_inputs:
|
124
158
|
if isinstance(chat_inputs, LLMPrompt):
|
125
159
|
return UserMessage.from_text(chat_inputs, name=self._agent_name)
|
126
160
|
return UserMessage.from_content_parts(chat_inputs, name=self._agent_name)
|
127
161
|
|
128
|
-
|
129
|
-
in_args=in_args,
|
162
|
+
return UserMessage(
|
163
|
+
content=self.make_input_content(in_args=in_args, ctx=ctx),
|
164
|
+
name=self._agent_name,
|
130
165
|
)
|
131
166
|
|
132
|
-
return UserMessage(content=in_content, name=self._agent_name)
|
133
|
-
|
134
|
-
def _validate_prompt_args(
|
135
|
-
self,
|
136
|
-
*,
|
137
|
-
in_args: InT | None,
|
138
|
-
usr_args: LLMPromptArgs | None,
|
139
|
-
) -> tuple[InT | None, LLMPromptArgs | None]:
|
140
|
-
val_usr_args = usr_args
|
141
|
-
if usr_args is not None:
|
142
|
-
if self.in_prompt_template is None:
|
143
|
-
raise InputPromptBuilderError(
|
144
|
-
"Input prompt template is not set, but user arguments are provided."
|
145
|
-
)
|
146
|
-
if self.usr_args_schema is None:
|
147
|
-
raise InputPromptBuilderError(
|
148
|
-
"User arguments schema is not provided, but user arguments are "
|
149
|
-
"given."
|
150
|
-
)
|
151
|
-
val_usr_args = self.usr_args_schema.model_validate(usr_args)
|
152
|
-
|
153
|
-
val_in_args = in_args
|
154
|
-
if in_args is not None:
|
155
|
-
val_in_args = self._in_args_type_adapter.validate_python(in_args)
|
156
|
-
if isinstance(val_in_args, BaseModel):
|
157
|
-
has_image = self._has_image_data(val_in_args)
|
158
|
-
if has_image and self.in_prompt_template is None:
|
159
|
-
raise InputPromptBuilderError(
|
160
|
-
"BaseModel input arguments contain ImageData, but input prompt "
|
161
|
-
"template is not set. Cannot format input arguments."
|
162
|
-
)
|
163
|
-
elif self.in_prompt_template is not None:
|
164
|
-
raise InputPromptBuilderError(
|
165
|
-
"Cannot use the input prompt template with "
|
166
|
-
"non-BaseModel input arguments."
|
167
|
-
)
|
168
|
-
|
169
|
-
return val_in_args, val_usr_args
|
170
|
-
|
171
167
|
@staticmethod
|
172
168
|
def _has_image_data(inp: BaseModel) -> bool:
|
173
169
|
contains_image_data = False
|
@@ -180,39 +176,18 @@ class PromptBuilder(AutoInstanceAttributesMixin, Generic[InT, CtxT]):
|
|
180
176
|
@staticmethod
|
181
177
|
def _format_pydantic_prompt_args(inp: BaseModel) -> dict[str, PromptArgumentType]:
|
182
178
|
formatted_args: dict[str, PromptArgumentType] = {}
|
183
|
-
for
|
184
|
-
if
|
179
|
+
for field_name, field_info in type(inp).model_fields.items():
|
180
|
+
if field_info.exclude:
|
185
181
|
continue
|
186
182
|
|
187
|
-
val = getattr(inp,
|
183
|
+
val = getattr(inp, field_name)
|
188
184
|
if isinstance(val, (int, str, bool, ImageData)):
|
189
|
-
formatted_args[
|
185
|
+
formatted_args[field_name] = val
|
190
186
|
else:
|
191
|
-
formatted_args[
|
187
|
+
formatted_args[field_name] = (
|
192
188
|
TypeAdapter(type(val)) # type: ignore[return-value]
|
193
189
|
.dump_json(val, indent=2, warnings="error")
|
194
190
|
.decode("utf-8")
|
195
191
|
)
|
196
192
|
|
197
193
|
return formatted_args
|
198
|
-
|
199
|
-
def _combine_args(
|
200
|
-
self, *, in_args: InT | None, usr_args: LLMPromptArgs | None
|
201
|
-
) -> Mapping[str, PromptArgumentType] | str:
|
202
|
-
fmt_usr_args = self._format_pydantic_prompt_args(usr_args) if usr_args else {}
|
203
|
-
|
204
|
-
if in_args is None:
|
205
|
-
return fmt_usr_args
|
206
|
-
|
207
|
-
if isinstance(in_args, BaseModel):
|
208
|
-
fmt_in_args = self._format_pydantic_prompt_args(in_args)
|
209
|
-
return fmt_in_args | fmt_usr_args
|
210
|
-
|
211
|
-
combined_args_str = self._in_args_type_adapter.dump_json(
|
212
|
-
in_args, indent=2, warnings="error"
|
213
|
-
).decode("utf-8")
|
214
|
-
if usr_args is not None:
|
215
|
-
fmt_usr_args_str = usr_args.model_dump_json(indent=2, warnings="error")
|
216
|
-
combined_args_str += "\n" + fmt_usr_args_str
|
217
|
-
|
218
|
-
return combined_args_str
|
grasp_agents/run_context.py
CHANGED
@@ -9,21 +9,13 @@ from .printer import ColoringMode, Printer
|
|
9
9
|
from .typing.io import LLMPromptArgs, ProcName
|
10
10
|
from .usage_tracker import UsageTracker
|
11
11
|
|
12
|
-
|
13
|
-
class RunArgs(BaseModel):
|
14
|
-
sys: LLMPromptArgs | None = None
|
15
|
-
usr: LLMPromptArgs | None = None
|
16
|
-
|
17
|
-
model_config = ConfigDict(extra="forbid")
|
18
|
-
|
19
|
-
|
20
12
|
CtxT = TypeVar("CtxT")
|
21
13
|
|
22
14
|
|
23
15
|
class RunContext(BaseModel, Generic[CtxT]):
|
24
16
|
state: CtxT | None = None
|
25
17
|
|
26
|
-
|
18
|
+
sys_args: dict[ProcName, LLMPromptArgs] = Field(default_factory=dict)
|
27
19
|
|
28
20
|
is_streaming: bool = False
|
29
21
|
result: Any | None = None
|
grasp_agents/runner.py
CHANGED
@@ -1,42 +1,131 @@
|
|
1
1
|
from collections.abc import AsyncIterator, Sequence
|
2
|
+
from functools import partial
|
2
3
|
from typing import Any, Generic
|
3
4
|
|
4
|
-
from .
|
5
|
+
from .errors import RunnerError
|
6
|
+
from .packet import Packet, StartPacket
|
7
|
+
from .packet_pool import END_PROC_NAME, PacketPool
|
8
|
+
from .processor import Processor
|
5
9
|
from .run_context import CtxT, RunContext
|
6
|
-
from .typing.events import Event
|
10
|
+
from .typing.events import Event, ProcPacketOutputEvent, RunResultEvent
|
11
|
+
from .typing.io import OutT
|
7
12
|
|
8
13
|
|
9
|
-
class Runner(Generic[CtxT]):
|
14
|
+
class Runner(Generic[OutT, CtxT]):
|
10
15
|
def __init__(
|
11
16
|
self,
|
12
|
-
|
13
|
-
procs: Sequence[
|
17
|
+
entry_proc: Processor[Any, Any, Any, CtxT],
|
18
|
+
procs: Sequence[Processor[Any, Any, Any, CtxT]],
|
14
19
|
ctx: RunContext[CtxT] | None = None,
|
15
20
|
) -> None:
|
16
|
-
if
|
17
|
-
raise
|
18
|
-
f"
|
21
|
+
if entry_proc not in procs:
|
22
|
+
raise RunnerError(
|
23
|
+
f"Entry processor {entry_proc.name} must be in the list of processors: "
|
19
24
|
f"{', '.join(proc.name for proc in procs)}"
|
20
25
|
)
|
21
|
-
|
26
|
+
if sum(1 for proc in procs if END_PROC_NAME in (proc.recipients or [])) != 1:
|
27
|
+
raise RunnerError(
|
28
|
+
"There must be exactly one processor with recipient 'END'."
|
29
|
+
)
|
30
|
+
|
31
|
+
self._entry_proc = entry_proc
|
22
32
|
self._procs = procs
|
23
33
|
self._ctx = ctx or RunContext[CtxT]()
|
34
|
+
self._packet_pool: PacketPool[CtxT] = PacketPool()
|
24
35
|
|
25
36
|
@property
|
26
37
|
def ctx(self) -> RunContext[CtxT]:
|
27
38
|
return self._ctx
|
28
39
|
|
29
|
-
|
30
|
-
self
|
31
|
-
|
32
|
-
|
33
|
-
|
40
|
+
def _unpack_packet(
|
41
|
+
self, packet: Packet[Any] | None
|
42
|
+
) -> tuple[Packet[Any] | None, Any | None]:
|
43
|
+
if isinstance(packet, StartPacket):
|
44
|
+
return None, packet.chat_inputs
|
45
|
+
return packet, None
|
46
|
+
|
47
|
+
async def _packet_handler(
|
48
|
+
self,
|
49
|
+
proc: Processor[Any, Any, Any, CtxT],
|
50
|
+
pool: PacketPool[CtxT],
|
51
|
+
packet: Packet[Any],
|
52
|
+
ctx: RunContext[CtxT],
|
53
|
+
**run_kwargs: Any,
|
54
|
+
) -> None:
|
55
|
+
_in_packet, _chat_inputs = self._unpack_packet(packet)
|
56
|
+
out_packet = await proc.run(
|
57
|
+
chat_inputs=_chat_inputs, in_packet=_in_packet, ctx=ctx, **run_kwargs
|
58
|
+
)
|
59
|
+
await pool.post(out_packet)
|
60
|
+
|
61
|
+
async def _packet_handler_stream(
|
62
|
+
self,
|
63
|
+
proc: Processor[Any, Any, Any, CtxT],
|
64
|
+
pool: PacketPool[CtxT],
|
65
|
+
packet: Packet[Any],
|
66
|
+
ctx: RunContext[CtxT],
|
67
|
+
**run_kwargs: Any,
|
68
|
+
) -> None:
|
69
|
+
_in_packet, _chat_inputs = self._unpack_packet(packet)
|
34
70
|
|
35
|
-
|
71
|
+
out_packet: Packet[Any] | None = None
|
72
|
+
async for event in proc.run_stream(
|
73
|
+
chat_inputs=_chat_inputs, in_packet=_in_packet, ctx=ctx, **run_kwargs
|
74
|
+
):
|
75
|
+
if isinstance(event, ProcPacketOutputEvent):
|
76
|
+
out_packet = event.data
|
77
|
+
await pool.push_event(event)
|
36
78
|
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
79
|
+
assert out_packet is not None
|
80
|
+
|
81
|
+
await pool.post(out_packet)
|
82
|
+
|
83
|
+
async def run(
|
84
|
+
self,
|
85
|
+
chat_input: Any = "start",
|
86
|
+
**run_args: Any,
|
87
|
+
) -> Packet[OutT]:
|
88
|
+
async with PacketPool[CtxT]() as pool:
|
89
|
+
for proc in self._procs:
|
90
|
+
pool.register_packet_handler(
|
91
|
+
proc_name=proc.name,
|
92
|
+
handler=partial(self._packet_handler, proc, pool),
|
93
|
+
ctx=self._ctx,
|
94
|
+
**run_args,
|
95
|
+
)
|
96
|
+
await pool.post(
|
97
|
+
StartPacket[Any](
|
98
|
+
recipients=[self._entry_proc.name], chat_inputs=chat_input
|
99
|
+
)
|
100
|
+
)
|
101
|
+
return await pool.final_result()
|
102
|
+
|
103
|
+
async def run_stream(
|
104
|
+
self,
|
105
|
+
chat_input: Any = "start",
|
106
|
+
**run_args: Any,
|
107
|
+
) -> AsyncIterator[Event[Any]]:
|
108
|
+
async with PacketPool[CtxT]() as pool:
|
109
|
+
for proc in self._procs:
|
110
|
+
pool.register_packet_handler(
|
111
|
+
proc_name=proc.name,
|
112
|
+
handler=partial(self._packet_handler_stream, proc, pool),
|
113
|
+
ctx=self._ctx,
|
114
|
+
**run_args,
|
115
|
+
)
|
116
|
+
await pool.post(
|
117
|
+
StartPacket[Any](
|
118
|
+
recipients=[self._entry_proc.name], chat_inputs=chat_input
|
119
|
+
)
|
120
|
+
)
|
121
|
+
async for event in pool.stream_events():
|
122
|
+
if isinstance(
|
123
|
+
event, ProcPacketOutputEvent
|
124
|
+
) and event.data.recipients == [END_PROC_NAME]:
|
125
|
+
yield RunResultEvent(
|
126
|
+
data=event.data,
|
127
|
+
proc_name=event.proc_name,
|
128
|
+
call_id=event.call_id,
|
129
|
+
)
|
130
|
+
else:
|
131
|
+
yield event
|
grasp_agents/typing/events.py
CHANGED
@@ -136,16 +136,20 @@ class ProcPayloadOutputEvent(Event[Any], frozen=True):
|
|
136
136
|
|
137
137
|
|
138
138
|
class ProcPacketOutputEvent(Event[Packet[Any]], frozen=True):
|
139
|
-
type: Literal[EventType.PACKET_OUT] =
|
140
|
-
|
139
|
+
type: Literal[EventType.PACKET_OUT, EventType.WORKFLOW_RES, EventType.RUN_RES] = (
|
140
|
+
EventType.PACKET_OUT
|
141
|
+
)
|
142
|
+
source: Literal[
|
143
|
+
EventSourceType.PROC, EventSourceType.WORKFLOW, EventSourceType.RUN
|
144
|
+
] = EventSourceType.PROC
|
141
145
|
|
142
146
|
|
143
|
-
class WorkflowResultEvent(
|
147
|
+
class WorkflowResultEvent(ProcPacketOutputEvent, frozen=True):
|
144
148
|
type: Literal[EventType.WORKFLOW_RES] = EventType.WORKFLOW_RES
|
145
149
|
source: Literal[EventSourceType.WORKFLOW] = EventSourceType.WORKFLOW
|
146
150
|
|
147
151
|
|
148
|
-
class RunResultEvent(
|
152
|
+
class RunResultEvent(ProcPacketOutputEvent, frozen=True):
|
149
153
|
type: Literal[EventType.RUN_RES] = EventType.RUN_RES
|
150
154
|
source: Literal[EventSourceType.RUN] = EventSourceType.RUN
|
151
155
|
|
grasp_agents/typing/io.py
CHANGED
@@ -4,11 +4,11 @@ from logging import getLogger
|
|
4
4
|
from typing import Any, Generic, Protocol, TypeVar, cast, final
|
5
5
|
|
6
6
|
from ..errors import WorkflowConstructionError
|
7
|
-
from ..packet_pool import Packet
|
7
|
+
from ..packet_pool import Packet
|
8
8
|
from ..processor import Processor
|
9
9
|
from ..run_context import CtxT, RunContext
|
10
10
|
from ..typing.events import Event, ProcPacketOutputEvent, WorkflowResultEvent
|
11
|
-
from ..typing.io import InT,
|
11
|
+
from ..typing.io import InT, OutT, ProcName
|
12
12
|
from .workflow_processor import WorkflowProcessor
|
13
13
|
|
14
14
|
logger = getLogger(__name__)
|
@@ -25,15 +25,12 @@ class ExitWorkflowLoopHandler(Protocol[_OutT_contra, CtxT]):
|
|
25
25
|
) -> bool: ...
|
26
26
|
|
27
27
|
|
28
|
-
class LoopedWorkflow(
|
29
|
-
WorkflowProcessor[InT, OutT_co, CtxT], Generic[InT, OutT_co, CtxT]
|
30
|
-
):
|
28
|
+
class LoopedWorkflow(WorkflowProcessor[InT, OutT, CtxT], Generic[InT, OutT, CtxT]):
|
31
29
|
def __init__(
|
32
30
|
self,
|
33
31
|
name: ProcName,
|
34
32
|
subprocs: Sequence[Processor[Any, Any, Any, CtxT]],
|
35
|
-
exit_proc: Processor[Any,
|
36
|
-
packet_pool: PacketPool[CtxT] | None = None,
|
33
|
+
exit_proc: Processor[Any, OutT, Any, CtxT],
|
37
34
|
recipients: list[ProcName] | None = None,
|
38
35
|
max_retries: int = 0,
|
39
36
|
max_iterations: int = 10,
|
@@ -43,7 +40,6 @@ class LoopedWorkflow(
|
|
43
40
|
name=name,
|
44
41
|
start_proc=subprocs[0],
|
45
42
|
end_proc=exit_proc,
|
46
|
-
packet_pool=packet_pool,
|
47
43
|
recipients=recipients,
|
48
44
|
max_retries=max_retries,
|
49
45
|
)
|
@@ -64,24 +60,22 @@ class LoopedWorkflow(
|
|
64
60
|
|
65
61
|
self._max_iterations = max_iterations
|
66
62
|
|
67
|
-
self._exit_workflow_loop_impl: ExitWorkflowLoopHandler[
|
68
|
-
None
|
69
|
-
)
|
63
|
+
self._exit_workflow_loop_impl: ExitWorkflowLoopHandler[OutT, CtxT] | None = None
|
70
64
|
|
71
65
|
@property
|
72
66
|
def max_iterations(self) -> int:
|
73
67
|
return self._max_iterations
|
74
68
|
|
75
69
|
def exit_workflow_loop(
|
76
|
-
self, func: ExitWorkflowLoopHandler[
|
77
|
-
) -> ExitWorkflowLoopHandler[
|
70
|
+
self, func: ExitWorkflowLoopHandler[OutT, CtxT]
|
71
|
+
) -> ExitWorkflowLoopHandler[OutT, CtxT]:
|
78
72
|
self._exit_workflow_loop_impl = func
|
79
73
|
|
80
74
|
return func
|
81
75
|
|
82
76
|
def _exit_workflow_loop(
|
83
77
|
self,
|
84
|
-
out_packet: Packet[
|
78
|
+
out_packet: Packet[OutT],
|
85
79
|
*,
|
86
80
|
ctx: RunContext[CtxT] | None = None,
|
87
81
|
**kwargs: Any,
|
@@ -101,12 +95,12 @@ class LoopedWorkflow(
|
|
101
95
|
call_id: str | None = None,
|
102
96
|
forgetful: bool = False,
|
103
97
|
ctx: RunContext[CtxT] | None = None,
|
104
|
-
) -> Packet[
|
98
|
+
) -> Packet[OutT]:
|
105
99
|
call_id = self._generate_call_id(call_id)
|
106
100
|
|
107
101
|
packet = in_packet
|
108
102
|
num_iterations = 0
|
109
|
-
exit_packet: Packet[
|
103
|
+
exit_packet: Packet[OutT] | None = None
|
110
104
|
|
111
105
|
while True:
|
112
106
|
for subproc in self.subprocs:
|
@@ -121,7 +115,7 @@ class LoopedWorkflow(
|
|
121
115
|
|
122
116
|
if subproc is self._end_proc:
|
123
117
|
num_iterations += 1
|
124
|
-
exit_packet = cast("Packet[
|
118
|
+
exit_packet = cast("Packet[OutT]", packet)
|
125
119
|
if self._exit_workflow_loop(exit_packet, ctx=ctx):
|
126
120
|
return exit_packet
|
127
121
|
if num_iterations >= self._max_iterations:
|
@@ -149,38 +143,39 @@ class LoopedWorkflow(
|
|
149
143
|
|
150
144
|
packet = in_packet
|
151
145
|
num_iterations = 0
|
152
|
-
exit_packet: Packet[
|
153
|
-
|
154
|
-
|
155
|
-
|
156
|
-
|
157
|
-
|
158
|
-
|
159
|
-
|
160
|
-
|
161
|
-
|
162
|
-
|
163
|
-
|
164
|
-
|
165
|
-
|
166
|
-
|
167
|
-
|
168
|
-
|
169
|
-
|
170
|
-
|
171
|
-
|
172
|
-
|
173
|
-
|
174
|
-
|
175
|
-
|
176
|
-
|
177
|
-
|
178
|
-
|
179
|
-
|
180
|
-
|
181
|
-
|
182
|
-
|
183
|
-
|
184
|
-
|
185
|
-
|
186
|
-
|
146
|
+
exit_packet: Packet[OutT] | None = None
|
147
|
+
|
148
|
+
while True:
|
149
|
+
for subproc in self.subprocs:
|
150
|
+
async for event in subproc.run_stream(
|
151
|
+
chat_inputs=chat_inputs,
|
152
|
+
in_packet=packet,
|
153
|
+
in_args=in_args,
|
154
|
+
forgetful=forgetful,
|
155
|
+
call_id=f"{call_id}/{subproc.name}",
|
156
|
+
ctx=ctx,
|
157
|
+
):
|
158
|
+
if isinstance(event, ProcPacketOutputEvent):
|
159
|
+
packet = event.data
|
160
|
+
yield event
|
161
|
+
|
162
|
+
if subproc is self._end_proc:
|
163
|
+
num_iterations += 1
|
164
|
+
exit_packet = cast("Packet[OutT]", packet)
|
165
|
+
if self._exit_workflow_loop(exit_packet, ctx=ctx):
|
166
|
+
yield WorkflowResultEvent(
|
167
|
+
data=exit_packet, proc_name=self.name, call_id=call_id
|
168
|
+
)
|
169
|
+
return
|
170
|
+
if num_iterations >= self._max_iterations:
|
171
|
+
logger.info(
|
172
|
+
f"Max iterations reached ({self._max_iterations}). "
|
173
|
+
"Exiting loop."
|
174
|
+
)
|
175
|
+
yield WorkflowResultEvent(
|
176
|
+
data=exit_packet, proc_name=self.name, call_id=call_id
|
177
|
+
)
|
178
|
+
return
|
179
|
+
|
180
|
+
chat_inputs = None
|
181
|
+
in_args = None
|