grasp_agents 0.5.3__py3-none-any.whl → 0.5.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- grasp_agents/__init__.py +4 -6
- grasp_agents/errors.py +80 -18
- grasp_agents/llm_agent.py +106 -146
- grasp_agents/llm_agent_memory.py +1 -1
- grasp_agents/llm_policy_executor.py +17 -15
- grasp_agents/packet.py +23 -4
- grasp_agents/packet_pool.py +117 -50
- grasp_agents/printer.py +9 -5
- grasp_agents/processor.py +217 -166
- grasp_agents/prompt_builder.py +75 -138
- grasp_agents/run_context.py +3 -16
- grasp_agents/runner.py +110 -21
- grasp_agents/typing/events.py +8 -4
- grasp_agents/typing/io.py +1 -8
- grasp_agents/workflow/looped_workflow.py +13 -19
- grasp_agents/workflow/sequential_workflow.py +6 -10
- grasp_agents/workflow/workflow_processor.py +23 -16
- {grasp_agents-0.5.3.dist-info → grasp_agents-0.5.5.dist-info}/METADATA +1 -1
- {grasp_agents-0.5.3.dist-info → grasp_agents-0.5.5.dist-info}/RECORD +21 -22
- grasp_agents/comm_processor.py +0 -214
- {grasp_agents-0.5.3.dist-info → grasp_agents-0.5.5.dist-info}/WHEEL +0 -0
- {grasp_agents-0.5.3.dist-info → grasp_agents-0.5.5.dist-info}/licenses/LICENSE.md +0 -0
grasp_agents/prompt_builder.py
CHANGED
@@ -1,35 +1,26 @@
|
|
1
1
|
import json
|
2
|
-
from collections.abc import
|
3
|
-
from typing import ClassVar, Generic, Protocol, TypeAlias, TypeVar
|
2
|
+
from collections.abc import Sequence
|
3
|
+
from typing import ClassVar, Generic, Protocol, TypeAlias, TypeVar, final
|
4
4
|
|
5
5
|
from pydantic import BaseModel, TypeAdapter
|
6
6
|
|
7
|
-
from .errors import InputPromptBuilderError
|
7
|
+
from .errors import InputPromptBuilderError
|
8
8
|
from .generics_utils import AutoInstanceAttributesMixin
|
9
9
|
from .run_context import CtxT, RunContext
|
10
10
|
from .typing.content import Content, ImageData
|
11
|
-
from .typing.io import InT, LLMPrompt
|
11
|
+
from .typing.io import InT, LLMPrompt
|
12
12
|
from .typing.message import UserMessage
|
13
13
|
|
14
14
|
_InT_contra = TypeVar("_InT_contra", contravariant=True)
|
15
15
|
|
16
16
|
|
17
|
-
class
|
18
|
-
def __call__(
|
19
|
-
self,
|
20
|
-
sys_args: LLMPromptArgs | None,
|
21
|
-
*,
|
22
|
-
ctx: RunContext[CtxT] | None,
|
23
|
-
) -> str | None: ...
|
17
|
+
class SystemPromptBuilder(Protocol[CtxT]):
|
18
|
+
def __call__(self, ctx: RunContext[CtxT] | None) -> str | None: ...
|
24
19
|
|
25
20
|
|
26
|
-
class
|
21
|
+
class InputContentBuilder(Protocol[_InT_contra, CtxT]):
|
27
22
|
def __call__(
|
28
|
-
self,
|
29
|
-
*,
|
30
|
-
in_args: _InT_contra | None,
|
31
|
-
usr_args: LLMPromptArgs | None,
|
32
|
-
ctx: RunContext[CtxT] | None,
|
23
|
+
self, in_args: _InT_contra | None, *, ctx: RunContext[CtxT] | None
|
33
24
|
) -> Content: ...
|
34
25
|
|
35
26
|
|
@@ -40,134 +31,101 @@ class PromptBuilder(AutoInstanceAttributesMixin, Generic[InT, CtxT]):
|
|
40
31
|
_generic_arg_to_instance_attr_map: ClassVar[dict[int, str]] = {0: "_in_type"}
|
41
32
|
|
42
33
|
def __init__(
|
43
|
-
self,
|
44
|
-
agent_name: str,
|
45
|
-
sys_prompt_template: LLMPrompt | None,
|
46
|
-
in_prompt_template: LLMPrompt | None,
|
47
|
-
sys_args_schema: type[LLMPromptArgs] | None = None,
|
48
|
-
usr_args_schema: type[LLMPromptArgs] | None = None,
|
34
|
+
self, agent_name: str, sys_prompt: LLMPrompt | None, in_prompt: LLMPrompt | None
|
49
35
|
):
|
50
36
|
self._in_type: type[InT]
|
51
37
|
super().__init__()
|
52
38
|
|
53
39
|
self._agent_name = agent_name
|
54
|
-
self.
|
55
|
-
self.
|
56
|
-
self.
|
57
|
-
self.
|
58
|
-
self.make_system_prompt_impl: MakeSystemPromptHandler[CtxT] | None = None
|
59
|
-
self.make_input_content_impl: MakeInputContentHandler[InT, CtxT] | None = None
|
40
|
+
self.sys_prompt = sys_prompt
|
41
|
+
self.in_prompt = in_prompt
|
42
|
+
self.system_prompt_builder: SystemPromptBuilder[CtxT] | None = None
|
43
|
+
self.input_content_builder: InputContentBuilder[InT, CtxT] | None = None
|
60
44
|
|
61
45
|
self._in_args_type_adapter: TypeAdapter[InT] = TypeAdapter(self._in_type)
|
62
46
|
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
return None
|
68
|
-
|
69
|
-
val_sys_args = sys_args
|
70
|
-
if sys_args is not None:
|
71
|
-
if self.sys_args_schema is not None:
|
72
|
-
val_sys_args = self.sys_args_schema.model_validate(sys_args)
|
73
|
-
else:
|
74
|
-
raise SystemPromptBuilderError(
|
75
|
-
"System prompt template and arguments is set, but system arguments "
|
76
|
-
"schema is not provided."
|
77
|
-
)
|
47
|
+
@final
|
48
|
+
def build_system_prompt(self, ctx: RunContext[CtxT] | None = None) -> str | None:
|
49
|
+
if self.system_prompt_builder:
|
50
|
+
return self.system_prompt_builder(ctx=ctx)
|
78
51
|
|
79
|
-
|
80
|
-
return self.make_system_prompt_impl(sys_args=val_sys_args, ctx=ctx)
|
52
|
+
return self.sys_prompt
|
81
53
|
|
82
|
-
|
54
|
+
def _validate_input_args(self, in_args: InT) -> InT:
|
55
|
+
val_in_args = self._in_args_type_adapter.validate_python(in_args)
|
56
|
+
if isinstance(val_in_args, BaseModel):
|
57
|
+
has_image = self._has_image_data(val_in_args)
|
58
|
+
if has_image and self.in_prompt is None:
|
59
|
+
raise InputPromptBuilderError(
|
60
|
+
proc_name=self._agent_name,
|
61
|
+
message="BaseModel input arguments contain ImageData, "
|
62
|
+
"but input prompt template is not set "
|
63
|
+
f"[agent_name={self._agent_name}]. Cannot format input arguments.",
|
64
|
+
)
|
65
|
+
elif self.in_prompt is not None:
|
66
|
+
raise InputPromptBuilderError(
|
67
|
+
proc_name=self._agent_name,
|
68
|
+
message="Cannot use the input prompt template with "
|
69
|
+
f"non-BaseModel input arguments [agent_name={self._agent_name}]",
|
70
|
+
)
|
83
71
|
|
84
|
-
return
|
72
|
+
return val_in_args
|
85
73
|
|
86
|
-
|
74
|
+
@final
|
75
|
+
def _build_input_content(
|
87
76
|
self,
|
88
|
-
|
89
|
-
in_args: InT | None,
|
90
|
-
usr_args: LLMPromptArgs | None,
|
77
|
+
in_args: InT | None = None,
|
91
78
|
ctx: RunContext[CtxT] | None = None,
|
92
79
|
) -> Content:
|
93
|
-
val_in_args
|
94
|
-
|
95
|
-
|
96
|
-
|
97
|
-
if self.make_input_content_impl:
|
98
|
-
return self.make_input_content_impl(
|
99
|
-
in_args=val_in_args, usr_args=val_usr_args, ctx=ctx
|
100
|
-
)
|
80
|
+
val_in_args = in_args
|
81
|
+
if in_args is not None:
|
82
|
+
val_in_args = self._validate_input_args(in_args=in_args)
|
101
83
|
|
102
|
-
|
103
|
-
|
104
|
-
return Content.from_text(combined_args)
|
84
|
+
if self.input_content_builder:
|
85
|
+
return self.input_content_builder(in_args=val_in_args, ctx=ctx)
|
105
86
|
|
106
|
-
if
|
107
|
-
|
108
|
-
self.
|
87
|
+
if val_in_args is None:
|
88
|
+
raise InputPromptBuilderError(
|
89
|
+
proc_name=self._agent_name,
|
90
|
+
message="Input arguments are not provided, "
|
91
|
+
f"but input content is required [agent_name={self._agent_name}]",
|
109
92
|
)
|
110
93
|
|
111
|
-
|
94
|
+
if issubclass(self._in_type, BaseModel) and isinstance(val_in_args, BaseModel):
|
95
|
+
val_in_args_map = self._format_pydantic_prompt_args(val_in_args)
|
96
|
+
if self.in_prompt is not None:
|
97
|
+
return Content.from_formatted_prompt(self.in_prompt, **val_in_args_map)
|
98
|
+
return Content.from_text(json.dumps(val_in_args_map, indent=2))
|
99
|
+
|
100
|
+
fmt_in_args = self._in_args_type_adapter.dump_json(
|
101
|
+
val_in_args, indent=2, warnings="error"
|
102
|
+
).decode("utf-8")
|
103
|
+
return Content.from_text(fmt_in_args)
|
112
104
|
|
113
|
-
|
105
|
+
@final
|
106
|
+
def build_input_message(
|
114
107
|
self,
|
115
108
|
chat_inputs: LLMPrompt | Sequence[str | ImageData] | None = None,
|
116
109
|
in_args: InT | None = None,
|
117
|
-
usr_args: LLMPromptArgs | None = None,
|
118
110
|
ctx: RunContext[CtxT] | None = None,
|
119
111
|
) -> UserMessage | None:
|
120
|
-
if chat_inputs is None and in_args is
|
121
|
-
|
112
|
+
if chat_inputs is not None and in_args is not None:
|
113
|
+
raise InputPromptBuilderError(
|
114
|
+
proc_name=self._agent_name,
|
115
|
+
message="Cannot use both chat inputs and input arguments "
|
116
|
+
f"at the same time [agent_name={self._agent_name}]",
|
117
|
+
)
|
122
118
|
|
123
119
|
if chat_inputs:
|
124
120
|
if isinstance(chat_inputs, LLMPrompt):
|
125
121
|
return UserMessage.from_text(chat_inputs, name=self._agent_name)
|
126
122
|
return UserMessage.from_content_parts(chat_inputs, name=self._agent_name)
|
127
123
|
|
128
|
-
|
129
|
-
in_args=in_args,
|
124
|
+
return UserMessage(
|
125
|
+
content=self._build_input_content(in_args=in_args, ctx=ctx),
|
126
|
+
name=self._agent_name,
|
130
127
|
)
|
131
128
|
|
132
|
-
return UserMessage(content=in_content, name=self._agent_name)
|
133
|
-
|
134
|
-
def _validate_prompt_args(
|
135
|
-
self,
|
136
|
-
*,
|
137
|
-
in_args: InT | None,
|
138
|
-
usr_args: LLMPromptArgs | None,
|
139
|
-
) -> tuple[InT | None, LLMPromptArgs | None]:
|
140
|
-
val_usr_args = usr_args
|
141
|
-
if usr_args is not None:
|
142
|
-
if self.in_prompt_template is None:
|
143
|
-
raise InputPromptBuilderError(
|
144
|
-
"Input prompt template is not set, but user arguments are provided."
|
145
|
-
)
|
146
|
-
if self.usr_args_schema is None:
|
147
|
-
raise InputPromptBuilderError(
|
148
|
-
"User arguments schema is not provided, but user arguments are "
|
149
|
-
"given."
|
150
|
-
)
|
151
|
-
val_usr_args = self.usr_args_schema.model_validate(usr_args)
|
152
|
-
|
153
|
-
val_in_args = in_args
|
154
|
-
if in_args is not None:
|
155
|
-
val_in_args = self._in_args_type_adapter.validate_python(in_args)
|
156
|
-
if isinstance(val_in_args, BaseModel):
|
157
|
-
has_image = self._has_image_data(val_in_args)
|
158
|
-
if has_image and self.in_prompt_template is None:
|
159
|
-
raise InputPromptBuilderError(
|
160
|
-
"BaseModel input arguments contain ImageData, but input prompt "
|
161
|
-
"template is not set. Cannot format input arguments."
|
162
|
-
)
|
163
|
-
elif self.in_prompt_template is not None:
|
164
|
-
raise InputPromptBuilderError(
|
165
|
-
"Cannot use the input prompt template with "
|
166
|
-
"non-BaseModel input arguments."
|
167
|
-
)
|
168
|
-
|
169
|
-
return val_in_args, val_usr_args
|
170
|
-
|
171
129
|
@staticmethod
|
172
130
|
def _has_image_data(inp: BaseModel) -> bool:
|
173
131
|
contains_image_data = False
|
@@ -180,39 +138,18 @@ class PromptBuilder(AutoInstanceAttributesMixin, Generic[InT, CtxT]):
|
|
180
138
|
@staticmethod
|
181
139
|
def _format_pydantic_prompt_args(inp: BaseModel) -> dict[str, PromptArgumentType]:
|
182
140
|
formatted_args: dict[str, PromptArgumentType] = {}
|
183
|
-
for
|
184
|
-
if
|
141
|
+
for field_name, field_info in type(inp).model_fields.items():
|
142
|
+
if field_info.exclude:
|
185
143
|
continue
|
186
144
|
|
187
|
-
val = getattr(inp,
|
145
|
+
val = getattr(inp, field_name)
|
188
146
|
if isinstance(val, (int, str, bool, ImageData)):
|
189
|
-
formatted_args[
|
147
|
+
formatted_args[field_name] = val
|
190
148
|
else:
|
191
|
-
formatted_args[
|
149
|
+
formatted_args[field_name] = (
|
192
150
|
TypeAdapter(type(val)) # type: ignore[return-value]
|
193
151
|
.dump_json(val, indent=2, warnings="error")
|
194
152
|
.decode("utf-8")
|
195
153
|
)
|
196
154
|
|
197
155
|
return formatted_args
|
198
|
-
|
199
|
-
def _combine_args(
|
200
|
-
self, *, in_args: InT | None, usr_args: LLMPromptArgs | None
|
201
|
-
) -> Mapping[str, PromptArgumentType] | str:
|
202
|
-
fmt_usr_args = self._format_pydantic_prompt_args(usr_args) if usr_args else {}
|
203
|
-
|
204
|
-
if in_args is None:
|
205
|
-
return fmt_usr_args
|
206
|
-
|
207
|
-
if isinstance(in_args, BaseModel):
|
208
|
-
fmt_in_args = self._format_pydantic_prompt_args(in_args)
|
209
|
-
return fmt_in_args | fmt_usr_args
|
210
|
-
|
211
|
-
combined_args_str = self._in_args_type_adapter.dump_json(
|
212
|
-
in_args, indent=2, warnings="error"
|
213
|
-
).decode("utf-8")
|
214
|
-
if usr_args is not None:
|
215
|
-
fmt_usr_args_str = usr_args.model_dump_json(indent=2, warnings="error")
|
216
|
-
combined_args_str += "\n" + fmt_usr_args_str
|
217
|
-
|
218
|
-
return combined_args_str
|
grasp_agents/run_context.py
CHANGED
@@ -6,39 +6,26 @@ from pydantic import BaseModel, ConfigDict, Field
|
|
6
6
|
from grasp_agents.typing.completion import Completion
|
7
7
|
|
8
8
|
from .printer import ColoringMode, Printer
|
9
|
-
from .typing.io import
|
9
|
+
from .typing.io import ProcName
|
10
10
|
from .usage_tracker import UsageTracker
|
11
11
|
|
12
|
-
|
13
|
-
class RunArgs(BaseModel):
|
14
|
-
sys: LLMPromptArgs | None = None
|
15
|
-
usr: LLMPromptArgs | None = None
|
16
|
-
|
17
|
-
model_config = ConfigDict(extra="forbid")
|
18
|
-
|
19
|
-
|
20
12
|
CtxT = TypeVar("CtxT")
|
21
13
|
|
22
14
|
|
23
15
|
class RunContext(BaseModel, Generic[CtxT]):
|
24
16
|
state: CtxT | None = None
|
25
17
|
|
26
|
-
run_args: dict[ProcName, RunArgs] = Field(default_factory=dict)
|
27
|
-
|
28
|
-
is_streaming: bool = False
|
29
|
-
result: Any | None = None
|
30
|
-
|
31
18
|
completions: dict[ProcName, list[Completion]] = Field(
|
32
19
|
default_factory=lambda: defaultdict(list)
|
33
20
|
)
|
34
21
|
usage_tracker: UsageTracker = Field(default_factory=UsageTracker)
|
35
22
|
|
36
23
|
printer: Printer | None = None
|
37
|
-
|
24
|
+
log_messages: bool = False
|
38
25
|
color_messages_by: ColoringMode = "role"
|
39
26
|
|
40
27
|
def model_post_init(self, context: Any) -> None: # noqa: ARG002
|
41
|
-
if self.
|
28
|
+
if self.log_messages:
|
42
29
|
self.printer = Printer(color_by=self.color_messages_by)
|
43
30
|
|
44
31
|
model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
|
grasp_agents/runner.py
CHANGED
@@ -1,42 +1,131 @@
|
|
1
1
|
from collections.abc import AsyncIterator, Sequence
|
2
|
+
from functools import partial
|
2
3
|
from typing import Any, Generic
|
3
4
|
|
4
|
-
from .
|
5
|
+
from .errors import RunnerError
|
6
|
+
from .packet import Packet, StartPacket
|
7
|
+
from .packet_pool import END_PROC_NAME, PacketPool
|
8
|
+
from .processor import Processor
|
5
9
|
from .run_context import CtxT, RunContext
|
6
|
-
from .typing.events import Event
|
10
|
+
from .typing.events import Event, ProcPacketOutputEvent, RunResultEvent
|
11
|
+
from .typing.io import OutT
|
7
12
|
|
8
13
|
|
9
|
-
class Runner(Generic[CtxT]):
|
14
|
+
class Runner(Generic[OutT, CtxT]):
|
10
15
|
def __init__(
|
11
16
|
self,
|
12
|
-
|
13
|
-
procs: Sequence[
|
17
|
+
entry_proc: Processor[Any, Any, Any, CtxT],
|
18
|
+
procs: Sequence[Processor[Any, Any, Any, CtxT]],
|
14
19
|
ctx: RunContext[CtxT] | None = None,
|
15
20
|
) -> None:
|
16
|
-
if
|
17
|
-
raise
|
18
|
-
f"
|
21
|
+
if entry_proc not in procs:
|
22
|
+
raise RunnerError(
|
23
|
+
f"Entry processor {entry_proc.name} must be in the list of processors: "
|
19
24
|
f"{', '.join(proc.name for proc in procs)}"
|
20
25
|
)
|
21
|
-
|
26
|
+
if sum(1 for proc in procs if END_PROC_NAME in (proc.recipients or [])) != 1:
|
27
|
+
raise RunnerError(
|
28
|
+
"There must be exactly one processor with recipient 'END'."
|
29
|
+
)
|
30
|
+
|
31
|
+
self._entry_proc = entry_proc
|
22
32
|
self._procs = procs
|
23
33
|
self._ctx = ctx or RunContext[CtxT]()
|
34
|
+
self._packet_pool: PacketPool[CtxT] = PacketPool()
|
24
35
|
|
25
36
|
@property
|
26
37
|
def ctx(self) -> RunContext[CtxT]:
|
27
38
|
return self._ctx
|
28
39
|
|
29
|
-
|
30
|
-
self
|
31
|
-
|
32
|
-
|
33
|
-
|
40
|
+
def _unpack_packet(
|
41
|
+
self, packet: Packet[Any] | None
|
42
|
+
) -> tuple[Packet[Any] | None, Any | None]:
|
43
|
+
if isinstance(packet, StartPacket):
|
44
|
+
return None, packet.chat_inputs
|
45
|
+
return packet, None
|
46
|
+
|
47
|
+
async def _packet_handler(
|
48
|
+
self,
|
49
|
+
proc: Processor[Any, Any, Any, CtxT],
|
50
|
+
pool: PacketPool[CtxT],
|
51
|
+
packet: Packet[Any],
|
52
|
+
ctx: RunContext[CtxT],
|
53
|
+
**run_kwargs: Any,
|
54
|
+
) -> None:
|
55
|
+
_in_packet, _chat_inputs = self._unpack_packet(packet)
|
56
|
+
out_packet = await proc.run(
|
57
|
+
chat_inputs=_chat_inputs, in_packet=_in_packet, ctx=ctx, **run_kwargs
|
58
|
+
)
|
59
|
+
await pool.post(out_packet)
|
60
|
+
|
61
|
+
async def _packet_handler_stream(
|
62
|
+
self,
|
63
|
+
proc: Processor[Any, Any, Any, CtxT],
|
64
|
+
pool: PacketPool[CtxT],
|
65
|
+
packet: Packet[Any],
|
66
|
+
ctx: RunContext[CtxT],
|
67
|
+
**run_kwargs: Any,
|
68
|
+
) -> None:
|
69
|
+
_in_packet, _chat_inputs = self._unpack_packet(packet)
|
34
70
|
|
35
|
-
|
71
|
+
out_packet: Packet[Any] | None = None
|
72
|
+
async for event in proc.run_stream(
|
73
|
+
chat_inputs=_chat_inputs, in_packet=_in_packet, ctx=ctx, **run_kwargs
|
74
|
+
):
|
75
|
+
if isinstance(event, ProcPacketOutputEvent):
|
76
|
+
out_packet = event.data
|
77
|
+
await pool.push_event(event)
|
36
78
|
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
79
|
+
assert out_packet is not None
|
80
|
+
|
81
|
+
await pool.post(out_packet)
|
82
|
+
|
83
|
+
async def run(
|
84
|
+
self,
|
85
|
+
chat_input: Any = "start",
|
86
|
+
**run_args: Any,
|
87
|
+
) -> Packet[OutT]:
|
88
|
+
async with PacketPool[CtxT]() as pool:
|
89
|
+
for proc in self._procs:
|
90
|
+
pool.register_packet_handler(
|
91
|
+
proc_name=proc.name,
|
92
|
+
handler=partial(self._packet_handler, proc, pool),
|
93
|
+
ctx=self._ctx,
|
94
|
+
**run_args,
|
95
|
+
)
|
96
|
+
await pool.post(
|
97
|
+
StartPacket[Any](
|
98
|
+
recipients=[self._entry_proc.name], chat_inputs=chat_input
|
99
|
+
)
|
100
|
+
)
|
101
|
+
return await pool.final_result()
|
102
|
+
|
103
|
+
async def run_stream(
|
104
|
+
self,
|
105
|
+
chat_input: Any = "start",
|
106
|
+
**run_args: Any,
|
107
|
+
) -> AsyncIterator[Event[Any]]:
|
108
|
+
async with PacketPool[CtxT]() as pool:
|
109
|
+
for proc in self._procs:
|
110
|
+
pool.register_packet_handler(
|
111
|
+
proc_name=proc.name,
|
112
|
+
handler=partial(self._packet_handler_stream, proc, pool),
|
113
|
+
ctx=self._ctx,
|
114
|
+
**run_args,
|
115
|
+
)
|
116
|
+
await pool.post(
|
117
|
+
StartPacket[Any](
|
118
|
+
recipients=[self._entry_proc.name], chat_inputs=chat_input
|
119
|
+
)
|
120
|
+
)
|
121
|
+
async for event in pool.stream_events():
|
122
|
+
if isinstance(
|
123
|
+
event, ProcPacketOutputEvent
|
124
|
+
) and event.data.recipients == [END_PROC_NAME]:
|
125
|
+
yield RunResultEvent(
|
126
|
+
data=event.data,
|
127
|
+
proc_name=event.proc_name,
|
128
|
+
call_id=event.call_id,
|
129
|
+
)
|
130
|
+
else:
|
131
|
+
yield event
|
grasp_agents/typing/events.py
CHANGED
@@ -136,16 +136,20 @@ class ProcPayloadOutputEvent(Event[Any], frozen=True):
|
|
136
136
|
|
137
137
|
|
138
138
|
class ProcPacketOutputEvent(Event[Packet[Any]], frozen=True):
|
139
|
-
type: Literal[EventType.PACKET_OUT] =
|
140
|
-
|
139
|
+
type: Literal[EventType.PACKET_OUT, EventType.WORKFLOW_RES, EventType.RUN_RES] = (
|
140
|
+
EventType.PACKET_OUT
|
141
|
+
)
|
142
|
+
source: Literal[
|
143
|
+
EventSourceType.PROC, EventSourceType.WORKFLOW, EventSourceType.RUN
|
144
|
+
] = EventSourceType.PROC
|
141
145
|
|
142
146
|
|
143
|
-
class WorkflowResultEvent(
|
147
|
+
class WorkflowResultEvent(ProcPacketOutputEvent, frozen=True):
|
144
148
|
type: Literal[EventType.WORKFLOW_RES] = EventType.WORKFLOW_RES
|
145
149
|
source: Literal[EventSourceType.WORKFLOW] = EventSourceType.WORKFLOW
|
146
150
|
|
147
151
|
|
148
|
-
class RunResultEvent(
|
152
|
+
class RunResultEvent(ProcPacketOutputEvent, frozen=True):
|
149
153
|
type: Literal[EventType.RUN_RES] = EventType.RUN_RES
|
150
154
|
source: Literal[EventSourceType.RUN] = EventSourceType.RUN
|
151
155
|
|
grasp_agents/typing/io.py
CHANGED
@@ -1,16 +1,9 @@
|
|
1
1
|
from typing import TypeAlias, TypeVar
|
2
2
|
|
3
|
-
from pydantic import BaseModel
|
4
|
-
|
5
3
|
ProcName: TypeAlias = str
|
6
4
|
|
7
5
|
|
8
6
|
InT = TypeVar("InT")
|
9
|
-
|
10
|
-
|
11
|
-
|
12
|
-
class LLMPromptArgs(BaseModel):
|
13
|
-
pass
|
14
|
-
|
7
|
+
OutT = TypeVar("OutT")
|
15
8
|
|
16
9
|
LLMPrompt: TypeAlias = str
|
@@ -4,11 +4,11 @@ from logging import getLogger
|
|
4
4
|
from typing import Any, Generic, Protocol, TypeVar, cast, final
|
5
5
|
|
6
6
|
from ..errors import WorkflowConstructionError
|
7
|
-
from ..packet_pool import Packet
|
7
|
+
from ..packet_pool import Packet
|
8
8
|
from ..processor import Processor
|
9
9
|
from ..run_context import CtxT, RunContext
|
10
10
|
from ..typing.events import Event, ProcPacketOutputEvent, WorkflowResultEvent
|
11
|
-
from ..typing.io import InT,
|
11
|
+
from ..typing.io import InT, OutT, ProcName
|
12
12
|
from .workflow_processor import WorkflowProcessor
|
13
13
|
|
14
14
|
logger = getLogger(__name__)
|
@@ -25,15 +25,12 @@ class ExitWorkflowLoopHandler(Protocol[_OutT_contra, CtxT]):
|
|
25
25
|
) -> bool: ...
|
26
26
|
|
27
27
|
|
28
|
-
class LoopedWorkflow(
|
29
|
-
WorkflowProcessor[InT, OutT_co, CtxT], Generic[InT, OutT_co, CtxT]
|
30
|
-
):
|
28
|
+
class LoopedWorkflow(WorkflowProcessor[InT, OutT, CtxT], Generic[InT, OutT, CtxT]):
|
31
29
|
def __init__(
|
32
30
|
self,
|
33
31
|
name: ProcName,
|
34
32
|
subprocs: Sequence[Processor[Any, Any, Any, CtxT]],
|
35
|
-
exit_proc: Processor[Any,
|
36
|
-
packet_pool: PacketPool[CtxT] | None = None,
|
33
|
+
exit_proc: Processor[Any, OutT, Any, CtxT],
|
37
34
|
recipients: list[ProcName] | None = None,
|
38
35
|
max_retries: int = 0,
|
39
36
|
max_iterations: int = 10,
|
@@ -43,7 +40,6 @@ class LoopedWorkflow(
|
|
43
40
|
name=name,
|
44
41
|
start_proc=subprocs[0],
|
45
42
|
end_proc=exit_proc,
|
46
|
-
packet_pool=packet_pool,
|
47
43
|
recipients=recipients,
|
48
44
|
max_retries=max_retries,
|
49
45
|
)
|
@@ -64,24 +60,22 @@ class LoopedWorkflow(
|
|
64
60
|
|
65
61
|
self._max_iterations = max_iterations
|
66
62
|
|
67
|
-
self._exit_workflow_loop_impl: ExitWorkflowLoopHandler[
|
68
|
-
None
|
69
|
-
)
|
63
|
+
self._exit_workflow_loop_impl: ExitWorkflowLoopHandler[OutT, CtxT] | None = None
|
70
64
|
|
71
65
|
@property
|
72
66
|
def max_iterations(self) -> int:
|
73
67
|
return self._max_iterations
|
74
68
|
|
75
69
|
def exit_workflow_loop(
|
76
|
-
self, func: ExitWorkflowLoopHandler[
|
77
|
-
) -> ExitWorkflowLoopHandler[
|
70
|
+
self, func: ExitWorkflowLoopHandler[OutT, CtxT]
|
71
|
+
) -> ExitWorkflowLoopHandler[OutT, CtxT]:
|
78
72
|
self._exit_workflow_loop_impl = func
|
79
73
|
|
80
74
|
return func
|
81
75
|
|
82
76
|
def _exit_workflow_loop(
|
83
77
|
self,
|
84
|
-
out_packet: Packet[
|
78
|
+
out_packet: Packet[OutT],
|
85
79
|
*,
|
86
80
|
ctx: RunContext[CtxT] | None = None,
|
87
81
|
**kwargs: Any,
|
@@ -101,12 +95,12 @@ class LoopedWorkflow(
|
|
101
95
|
call_id: str | None = None,
|
102
96
|
forgetful: bool = False,
|
103
97
|
ctx: RunContext[CtxT] | None = None,
|
104
|
-
) -> Packet[
|
98
|
+
) -> Packet[OutT]:
|
105
99
|
call_id = self._generate_call_id(call_id)
|
106
100
|
|
107
101
|
packet = in_packet
|
108
102
|
num_iterations = 0
|
109
|
-
exit_packet: Packet[
|
103
|
+
exit_packet: Packet[OutT] | None = None
|
110
104
|
|
111
105
|
while True:
|
112
106
|
for subproc in self.subprocs:
|
@@ -121,7 +115,7 @@ class LoopedWorkflow(
|
|
121
115
|
|
122
116
|
if subproc is self._end_proc:
|
123
117
|
num_iterations += 1
|
124
|
-
exit_packet = cast("Packet[
|
118
|
+
exit_packet = cast("Packet[OutT]", packet)
|
125
119
|
if self._exit_workflow_loop(exit_packet, ctx=ctx):
|
126
120
|
return exit_packet
|
127
121
|
if num_iterations >= self._max_iterations:
|
@@ -149,7 +143,7 @@ class LoopedWorkflow(
|
|
149
143
|
|
150
144
|
packet = in_packet
|
151
145
|
num_iterations = 0
|
152
|
-
exit_packet: Packet[
|
146
|
+
exit_packet: Packet[OutT] | None = None
|
153
147
|
|
154
148
|
while True:
|
155
149
|
for subproc in self.subprocs:
|
@@ -167,7 +161,7 @@ class LoopedWorkflow(
|
|
167
161
|
|
168
162
|
if subproc is self._end_proc:
|
169
163
|
num_iterations += 1
|
170
|
-
exit_packet = cast("Packet[
|
164
|
+
exit_packet = cast("Packet[OutT]", packet)
|
171
165
|
if self._exit_workflow_loop(exit_packet, ctx=ctx):
|
172
166
|
yield WorkflowResultEvent(
|
173
167
|
data=exit_packet, proc_name=self.name, call_id=call_id
|