grasp_agents 0.3.10__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- grasp_agents/cloud_llm.py +70 -77
- grasp_agents/comm_processor.py +21 -11
- grasp_agents/errors.py +34 -0
- grasp_agents/http_client.py +7 -5
- grasp_agents/llm.py +3 -9
- grasp_agents/llm_agent.py +92 -103
- grasp_agents/llm_agent_memory.py +36 -27
- grasp_agents/llm_policy_executor.py +66 -63
- grasp_agents/memory.py +3 -1
- grasp_agents/openai/completion_chunk_converters.py +4 -3
- grasp_agents/openai/openai_llm.py +14 -20
- grasp_agents/openai/tool_converters.py +0 -1
- grasp_agents/packet_pool.py +1 -1
- grasp_agents/printer.py +6 -6
- grasp_agents/processor.py +182 -48
- grasp_agents/prompt_builder.py +41 -55
- grasp_agents/run_context.py +1 -5
- grasp_agents/typing/completion_chunk.py +10 -5
- grasp_agents/typing/content.py +2 -2
- grasp_agents/typing/io.py +4 -4
- grasp_agents/typing/message.py +3 -6
- grasp_agents/typing/tool.py +5 -23
- grasp_agents/usage_tracker.py +2 -4
- grasp_agents/utils.py +37 -15
- grasp_agents/workflow/looped_workflow.py +14 -9
- grasp_agents/workflow/sequential_workflow.py +11 -6
- grasp_agents/workflow/workflow_processor.py +30 -13
- {grasp_agents-0.3.10.dist-info → grasp_agents-0.4.0.dist-info}/METADATA +2 -1
- grasp_agents-0.4.0.dist-info/RECORD +50 -0
- grasp_agents/message_history.py +0 -140
- grasp_agents/workflow/parallel_processor.py +0 -95
- grasp_agents-0.3.10.dist-info/RECORD +0 -51
- {grasp_agents-0.3.10.dist-info → grasp_agents-0.4.0.dist-info}/WHEEL +0 -0
- {grasp_agents-0.3.10.dist-info → grasp_agents-0.4.0.dist-info}/licenses/LICENSE.md +0 -0
grasp_agents/processor.py
CHANGED
@@ -1,40 +1,64 @@
|
|
1
|
+
import asyncio
|
2
|
+
import logging
|
1
3
|
from abc import ABC
|
2
4
|
from collections.abc import AsyncIterator, Sequence
|
3
5
|
from typing import Any, ClassVar, Generic, cast, final
|
6
|
+
from uuid import uuid4
|
4
7
|
|
5
8
|
from pydantic import BaseModel, TypeAdapter
|
9
|
+
from tenacity import RetryCallState, retry, stop_after_attempt, wait_random_exponential
|
6
10
|
|
11
|
+
from .errors import InputValidationError
|
7
12
|
from .generics_utils import AutoInstanceAttributesMixin
|
13
|
+
from .memory import MemT
|
8
14
|
from .packet import Packet
|
9
15
|
from .run_context import CtxT, RunContext
|
10
16
|
from .typing.events import Event, PacketEvent, ProcOutputEvent
|
11
|
-
from .typing.io import
|
17
|
+
from .typing.io import InT, OutT_co, ProcName
|
12
18
|
from .typing.tool import BaseTool
|
13
19
|
|
20
|
+
logger = logging.getLogger(__name__)
|
14
21
|
|
15
|
-
|
16
|
-
|
17
|
-
)
|
22
|
+
|
23
|
+
def retry_error_callback(retry_state: RetryCallState) -> None:
|
24
|
+
exception = retry_state.outcome.exception() if retry_state.outcome else None
|
25
|
+
if exception:
|
26
|
+
if retry_state.attempt_number == 1:
|
27
|
+
logger.warning(f"\nParallel run failed:\n{exception}")
|
28
|
+
if retry_state.attempt_number > 1:
|
29
|
+
logger.warning(f"\nParallel run failed after retrying:\n{exception}")
|
30
|
+
|
31
|
+
|
32
|
+
def retry_before_sleep_callback(retry_state: RetryCallState) -> None:
|
33
|
+
exception = retry_state.outcome.exception() if retry_state.outcome else None
|
34
|
+
logger.info(
|
35
|
+
f"\nRetrying parallel run (attempt {retry_state.attempt_number}):\n{exception}"
|
36
|
+
)
|
37
|
+
|
38
|
+
|
39
|
+
class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT_co, MemT, CtxT]):
|
18
40
|
_generic_arg_to_instance_attr_map: ClassVar[dict[int, str]] = {
|
19
41
|
0: "_in_type",
|
20
42
|
1: "_out_type",
|
21
43
|
}
|
22
44
|
|
23
|
-
def __init__(
|
24
|
-
self
|
45
|
+
def __init__(
|
46
|
+
self, name: ProcName, num_par_run_retries: int = 0, **kwargs: Any
|
47
|
+
) -> None:
|
48
|
+
self._in_type: type[InT]
|
25
49
|
self._out_type: type[OutT_co]
|
26
50
|
|
27
51
|
super().__init__()
|
28
52
|
|
29
|
-
self._in_type_adapter: TypeAdapter[
|
53
|
+
self._in_type_adapter: TypeAdapter[InT] = TypeAdapter(self._in_type)
|
30
54
|
self._out_type_adapter: TypeAdapter[OutT_co] = TypeAdapter(self._out_type)
|
31
55
|
|
32
56
|
self._name: ProcName = name
|
33
|
-
self._memory:
|
57
|
+
self._memory: MemT
|
58
|
+
self._num_par_run_retries: int = num_par_run_retries
|
34
59
|
|
35
60
|
@property
|
36
|
-
def in_type(self) -> type[
|
37
|
-
# Exposing the type of a contravariant variable only, should be type safe
|
61
|
+
def in_type(self) -> type[InT]:
|
38
62
|
return self._in_type
|
39
63
|
|
40
64
|
@property
|
@@ -46,48 +70,73 @@ class Processor(
|
|
46
70
|
return self._name
|
47
71
|
|
48
72
|
@property
|
49
|
-
def memory(self) ->
|
73
|
+
def memory(self) -> MemT:
|
50
74
|
return self._memory
|
51
75
|
|
52
|
-
|
76
|
+
@property
|
77
|
+
def num_par_run_retries(self) -> int:
|
78
|
+
return self._num_par_run_retries
|
79
|
+
|
80
|
+
def _validate_and_resolve_single_input(
|
53
81
|
self,
|
54
82
|
chat_inputs: Any | None = None,
|
55
|
-
in_packet: Packet[
|
56
|
-
in_args:
|
57
|
-
) ->
|
83
|
+
in_packet: Packet[InT] | None = None,
|
84
|
+
in_args: InT | None = None,
|
85
|
+
) -> InT | None:
|
58
86
|
multiple_inputs_err_message = (
|
59
87
|
"Only one of chat_inputs, in_args, or in_message must be provided."
|
60
88
|
)
|
61
89
|
if chat_inputs is not None and in_args is not None:
|
62
|
-
raise
|
90
|
+
raise InputValidationError(multiple_inputs_err_message)
|
63
91
|
if chat_inputs is not None and in_packet is not None:
|
64
|
-
raise
|
92
|
+
raise InputValidationError(multiple_inputs_err_message)
|
65
93
|
if in_args is not None and in_packet is not None:
|
66
|
-
raise
|
94
|
+
raise InputValidationError(multiple_inputs_err_message)
|
67
95
|
|
68
|
-
resolved_in_args: Sequence[InT_contra] | None = None
|
69
96
|
if in_packet is not None:
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
resolved_in_args = cast("Sequence[InT_contra]", in_args)
|
97
|
+
if len(in_packet.payloads) != 1:
|
98
|
+
raise InputValidationError(
|
99
|
+
"Single input runs require exactly one payload in in_packet."
|
100
|
+
)
|
101
|
+
return in_packet.payloads[0]
|
102
|
+
return in_args
|
77
103
|
|
78
|
-
|
104
|
+
def _validate_and_resolve_parallel_inputs(
|
105
|
+
self,
|
106
|
+
chat_inputs: Any | None,
|
107
|
+
in_packet: Packet[InT] | None,
|
108
|
+
in_args: Sequence[InT] | None,
|
109
|
+
) -> Sequence[InT]:
|
110
|
+
if chat_inputs is not None:
|
111
|
+
raise InputValidationError(
|
112
|
+
"chat_inputs are not supported in parallel runs. "
|
113
|
+
"Use in_packet or in_args."
|
114
|
+
)
|
115
|
+
if in_packet is not None:
|
116
|
+
if not in_packet.payloads:
|
117
|
+
raise InputValidationError(
|
118
|
+
"Parallel runs require at least one input payload in in_packet."
|
119
|
+
)
|
120
|
+
return in_packet.payloads
|
121
|
+
if in_args is not None:
|
122
|
+
return in_args
|
123
|
+
raise InputValidationError(
|
124
|
+
"Parallel runs require either in_packet or in_args to be provided."
|
125
|
+
)
|
79
126
|
|
80
127
|
async def _process(
|
81
128
|
self,
|
82
129
|
chat_inputs: Any | None = None,
|
83
130
|
*,
|
84
|
-
in_args:
|
85
|
-
|
131
|
+
in_args: InT | None = None,
|
132
|
+
memory: MemT,
|
133
|
+
run_id: str,
|
86
134
|
ctx: RunContext[CtxT] | None = None,
|
87
135
|
) -> Sequence[OutT_co]:
|
88
|
-
|
89
|
-
|
90
|
-
|
136
|
+
if in_args is None:
|
137
|
+
raise InputValidationError(
|
138
|
+
"Default implementation of _process requires in_args"
|
139
|
+
)
|
91
140
|
|
92
141
|
return cast("Sequence[OutT_co]", in_args)
|
93
142
|
|
@@ -95,13 +144,15 @@ class Processor(
|
|
95
144
|
self,
|
96
145
|
chat_inputs: Any | None = None,
|
97
146
|
*,
|
98
|
-
in_args:
|
99
|
-
|
147
|
+
in_args: InT | None = None,
|
148
|
+
memory: MemT,
|
149
|
+
run_id: str,
|
100
150
|
ctx: RunContext[CtxT] | None = None,
|
101
151
|
) -> AsyncIterator[Event[Any]]:
|
102
|
-
|
103
|
-
|
104
|
-
|
152
|
+
if in_args is None:
|
153
|
+
raise InputValidationError(
|
154
|
+
"Default implementation of _process requires in_args"
|
155
|
+
)
|
105
156
|
outputs = cast("Sequence[OutT_co]", in_args)
|
106
157
|
for out in outputs:
|
107
158
|
yield ProcOutputEvent(data=out, name=self.name)
|
@@ -111,46 +162,129 @@ class Processor(
|
|
111
162
|
self._out_type_adapter.validate_python(payload) for payload in out_payloads
|
112
163
|
]
|
113
164
|
|
114
|
-
|
165
|
+
def _generate_run_id(self, run_id: str | None) -> str:
|
166
|
+
if run_id is None:
|
167
|
+
return str(uuid4())[:6] + "_" + self.name
|
168
|
+
return run_id
|
169
|
+
|
170
|
+
async def _run_single(
|
115
171
|
self,
|
116
172
|
chat_inputs: Any | None = None,
|
117
173
|
*,
|
118
|
-
in_packet: Packet[
|
119
|
-
in_args:
|
174
|
+
in_packet: Packet[InT] | None = None,
|
175
|
+
in_args: InT | None = None,
|
120
176
|
forgetful: bool = False,
|
177
|
+
run_id: str | None = None,
|
121
178
|
ctx: RunContext[CtxT] | None = None,
|
122
179
|
) -> Packet[OutT_co]:
|
123
|
-
resolved_in_args = self.
|
180
|
+
resolved_in_args = self._validate_and_resolve_single_input(
|
124
181
|
chat_inputs=chat_inputs, in_packet=in_packet, in_args=in_args
|
125
182
|
)
|
183
|
+
_memory = self.memory.model_copy(deep=True) if forgetful else self.memory
|
126
184
|
outputs = await self._process(
|
127
185
|
chat_inputs=chat_inputs,
|
128
186
|
in_args=resolved_in_args,
|
129
|
-
|
187
|
+
memory=_memory,
|
188
|
+
run_id=self._generate_run_id(run_id),
|
130
189
|
ctx=ctx,
|
131
190
|
)
|
132
191
|
val_outputs = self._validate_outputs(outputs)
|
133
192
|
|
134
193
|
return Packet(payloads=val_outputs, sender=self.name)
|
135
194
|
|
195
|
+
def _generate_par_run_id(self, run_id: str | None, idx: int) -> str:
|
196
|
+
return f"{self._generate_run_id(run_id)}/{idx}"
|
197
|
+
|
198
|
+
async def _run_par(
|
199
|
+
self,
|
200
|
+
chat_inputs: Any | None = None,
|
201
|
+
*,
|
202
|
+
in_packet: Packet[InT] | None = None,
|
203
|
+
in_args: Sequence[InT] | None = None,
|
204
|
+
run_id: str | None = None,
|
205
|
+
forgetful: bool = False,
|
206
|
+
ctx: RunContext[CtxT] | None = None,
|
207
|
+
) -> Packet[OutT_co]:
|
208
|
+
par_inputs = self._validate_and_resolve_parallel_inputs(
|
209
|
+
chat_inputs=chat_inputs, in_packet=in_packet, in_args=in_args
|
210
|
+
)
|
211
|
+
|
212
|
+
wrapped_func = retry(
|
213
|
+
wait=wait_random_exponential(min=1, max=8),
|
214
|
+
stop=stop_after_attempt(self._num_par_run_retries + 1),
|
215
|
+
before_sleep=retry_before_sleep_callback,
|
216
|
+
retry_error_callback=retry_error_callback,
|
217
|
+
)(self._run_single)
|
218
|
+
|
219
|
+
tasks = [
|
220
|
+
wrapped_func(
|
221
|
+
in_args=inp,
|
222
|
+
forgetful=True,
|
223
|
+
run_id=self._generate_par_run_id(run_id, idx),
|
224
|
+
ctx=ctx,
|
225
|
+
)
|
226
|
+
for idx, inp in enumerate(par_inputs)
|
227
|
+
]
|
228
|
+
out_packets = await asyncio.gather(*tasks)
|
229
|
+
|
230
|
+
return Packet( # type: ignore[return]
|
231
|
+
payloads=[
|
232
|
+
(out_packet.payloads[0] if out_packet else None)
|
233
|
+
for out_packet in out_packets
|
234
|
+
],
|
235
|
+
sender=self.name,
|
236
|
+
)
|
237
|
+
|
238
|
+
async def run(
|
239
|
+
self,
|
240
|
+
chat_inputs: Any | None = None,
|
241
|
+
*,
|
242
|
+
in_packet: Packet[InT] | None = None,
|
243
|
+
in_args: InT | Sequence[InT] | None = None,
|
244
|
+
forgetful: bool = False,
|
245
|
+
run_id: str | None = None,
|
246
|
+
ctx: RunContext[CtxT] | None = None,
|
247
|
+
) -> Packet[OutT_co]:
|
248
|
+
if isinstance(in_args, Sequence):
|
249
|
+
return await self._run_par(
|
250
|
+
chat_inputs=chat_inputs,
|
251
|
+
in_packet=in_packet,
|
252
|
+
in_args=cast("Sequence[InT]", in_args),
|
253
|
+
run_id=run_id,
|
254
|
+
forgetful=forgetful,
|
255
|
+
ctx=ctx,
|
256
|
+
)
|
257
|
+
return await self._run_single(
|
258
|
+
chat_inputs=chat_inputs,
|
259
|
+
in_packet=in_packet,
|
260
|
+
in_args=in_args,
|
261
|
+
forgetful=forgetful,
|
262
|
+
run_id=run_id,
|
263
|
+
ctx=ctx,
|
264
|
+
)
|
265
|
+
|
136
266
|
async def run_stream(
|
137
267
|
self,
|
138
268
|
chat_inputs: Any | None = None,
|
139
269
|
*,
|
140
|
-
in_packet: Packet[
|
141
|
-
in_args:
|
270
|
+
in_packet: Packet[InT] | None = None,
|
271
|
+
in_args: InT | None = None,
|
142
272
|
forgetful: bool = False,
|
273
|
+
run_id: str | None = None,
|
143
274
|
ctx: RunContext[CtxT] | None = None,
|
144
275
|
) -> AsyncIterator[Event[Any]]:
|
145
|
-
resolved_in_args = self.
|
276
|
+
resolved_in_args = self._validate_and_resolve_single_input(
|
146
277
|
chat_inputs=chat_inputs, in_packet=in_packet, in_args=in_args
|
147
278
|
)
|
148
279
|
|
280
|
+
_memory = self.memory.model_copy(deep=True) if forgetful else self.memory
|
281
|
+
|
149
282
|
outputs: Sequence[OutT_co] = []
|
150
283
|
async for output_event in self._process_stream(
|
151
284
|
chat_inputs=chat_inputs,
|
152
285
|
in_args=resolved_in_args,
|
153
|
-
|
286
|
+
memory=_memory,
|
287
|
+
run_id=self._generate_run_id(run_id),
|
154
288
|
ctx=ctx,
|
155
289
|
):
|
156
290
|
if isinstance(output_event, ProcOutputEvent):
|
@@ -166,7 +300,7 @@ class Processor(
|
|
166
300
|
@final
|
167
301
|
def as_tool(
|
168
302
|
self, tool_name: str, tool_description: str
|
169
|
-
) -> BaseTool[
|
303
|
+
) -> BaseTool[InT, OutT_co, Any]: # type: ignore[override]
|
170
304
|
# TODO: stream tools
|
171
305
|
processor_instance = self
|
172
306
|
in_type = processor_instance.in_type
|
@@ -182,7 +316,7 @@ class Processor(
|
|
182
316
|
description: str = tool_description
|
183
317
|
|
184
318
|
async def run(
|
185
|
-
self, inp:
|
319
|
+
self, inp: InT, ctx: RunContext[CtxT] | None = None
|
186
320
|
) -> OutT_co:
|
187
321
|
result = await processor_instance.run(
|
188
322
|
in_args=in_type.model_validate(inp), forgetful=True, ctx=ctx
|
grasp_agents/prompt_builder.py
CHANGED
@@ -1,15 +1,18 @@
|
|
1
1
|
import json
|
2
2
|
from collections.abc import Mapping, Sequence
|
3
|
-
from typing import ClassVar, Generic, Protocol, TypeAlias
|
3
|
+
from typing import ClassVar, Generic, Protocol, TypeAlias, TypeVar
|
4
4
|
|
5
5
|
from pydantic import BaseModel, TypeAdapter
|
6
6
|
|
7
|
+
from .errors import InputPromptBuilderError, SystemPromptBuilderError
|
7
8
|
from .generics_utils import AutoInstanceAttributesMixin
|
8
9
|
from .run_context import CtxT, RunContext
|
9
10
|
from .typing.content import Content, ImageData
|
10
|
-
from .typing.io import
|
11
|
+
from .typing.io import InT, LLMPrompt, LLMPromptArgs
|
11
12
|
from .typing.message import UserMessage
|
12
13
|
|
14
|
+
_InT_contra = TypeVar("_InT_contra", contravariant=True)
|
15
|
+
|
13
16
|
|
14
17
|
class MakeSystemPromptHandler(Protocol[CtxT]):
|
15
18
|
def __call__(
|
@@ -20,13 +23,12 @@ class MakeSystemPromptHandler(Protocol[CtxT]):
|
|
20
23
|
) -> str: ...
|
21
24
|
|
22
25
|
|
23
|
-
class MakeInputContentHandler(Protocol[
|
26
|
+
class MakeInputContentHandler(Protocol[_InT_contra, CtxT]):
|
24
27
|
def __call__(
|
25
28
|
self,
|
26
29
|
*,
|
27
|
-
in_args:
|
30
|
+
in_args: _InT_contra | None,
|
28
31
|
usr_args: LLMPromptArgs | None,
|
29
|
-
batch_idx: int,
|
30
32
|
ctx: RunContext[CtxT] | None,
|
31
33
|
) -> Content: ...
|
32
34
|
|
@@ -34,7 +36,7 @@ class MakeInputContentHandler(Protocol[InT_contra, CtxT]):
|
|
34
36
|
PromptArgumentType: TypeAlias = str | bool | int | ImageData
|
35
37
|
|
36
38
|
|
37
|
-
class PromptBuilder(AutoInstanceAttributesMixin, Generic[
|
39
|
+
class PromptBuilder(AutoInstanceAttributesMixin, Generic[InT, CtxT]):
|
38
40
|
_generic_arg_to_instance_attr_map: ClassVar[dict[int, str]] = {0: "_in_type"}
|
39
41
|
|
40
42
|
def __init__(
|
@@ -45,7 +47,7 @@ class PromptBuilder(AutoInstanceAttributesMixin, Generic[InT_contra, CtxT]):
|
|
45
47
|
sys_args_schema: type[LLMPromptArgs] | None = None,
|
46
48
|
usr_args_schema: type[LLMPromptArgs] | None = None,
|
47
49
|
):
|
48
|
-
self._in_type: type[
|
50
|
+
self._in_type: type[InT]
|
49
51
|
super().__init__()
|
50
52
|
|
51
53
|
self._agent_name = agent_name
|
@@ -54,11 +56,9 @@ class PromptBuilder(AutoInstanceAttributesMixin, Generic[InT_contra, CtxT]):
|
|
54
56
|
self.sys_args_schema = sys_args_schema
|
55
57
|
self.usr_args_schema = usr_args_schema
|
56
58
|
self.make_system_prompt_impl: MakeSystemPromptHandler[CtxT] | None = None
|
57
|
-
self.make_input_content_impl:
|
58
|
-
MakeInputContentHandler[InT_contra, CtxT] | None
|
59
|
-
) = None
|
59
|
+
self.make_input_content_impl: MakeInputContentHandler[InT, CtxT] | None = None
|
60
60
|
|
61
|
-
self._in_args_type_adapter: TypeAdapter[
|
61
|
+
self._in_args_type_adapter: TypeAdapter[InT] = TypeAdapter(self._in_type)
|
62
62
|
|
63
63
|
def make_system_prompt(
|
64
64
|
self, sys_args: LLMPromptArgs | None = None, ctx: RunContext[CtxT] | None = None
|
@@ -71,9 +71,9 @@ class PromptBuilder(AutoInstanceAttributesMixin, Generic[InT_contra, CtxT]):
|
|
71
71
|
if self.sys_args_schema is not None:
|
72
72
|
val_sys_args = self.sys_args_schema.model_validate(sys_args)
|
73
73
|
else:
|
74
|
-
raise
|
75
|
-
"System prompt template is set, but system arguments
|
76
|
-
"provided."
|
74
|
+
raise SystemPromptBuilderError(
|
75
|
+
"System prompt template and arguments is set, but system arguments "
|
76
|
+
"schema is not provided."
|
77
77
|
)
|
78
78
|
|
79
79
|
if self.make_system_prompt_impl:
|
@@ -86,9 +86,8 @@ class PromptBuilder(AutoInstanceAttributesMixin, Generic[InT_contra, CtxT]):
|
|
86
86
|
def make_input_content(
|
87
87
|
self,
|
88
88
|
*,
|
89
|
-
in_args:
|
89
|
+
in_args: InT | None,
|
90
90
|
usr_args: LLMPromptArgs | None,
|
91
|
-
batch_idx: int = 0,
|
92
91
|
ctx: RunContext[CtxT] | None = None,
|
93
92
|
) -> Content:
|
94
93
|
val_in_args, val_usr_args = self._validate_prompt_args(
|
@@ -97,7 +96,7 @@ class PromptBuilder(AutoInstanceAttributesMixin, Generic[InT_contra, CtxT]):
|
|
97
96
|
|
98
97
|
if self.make_input_content_impl:
|
99
98
|
return self.make_input_content_impl(
|
100
|
-
in_args=val_in_args, usr_args=val_usr_args,
|
99
|
+
in_args=val_in_args, usr_args=val_usr_args, ctx=ctx
|
101
100
|
)
|
102
101
|
|
103
102
|
combined_args = self._combine_args(in_args=val_in_args, usr_args=val_usr_args)
|
@@ -106,56 +105,46 @@ class PromptBuilder(AutoInstanceAttributesMixin, Generic[InT_contra, CtxT]):
|
|
106
105
|
|
107
106
|
if self.in_prompt_template is not None:
|
108
107
|
return Content.from_formatted_prompt(
|
109
|
-
self.in_prompt_template,
|
108
|
+
self.in_prompt_template, **combined_args
|
110
109
|
)
|
111
110
|
|
112
111
|
return Content.from_text(json.dumps(combined_args, indent=2))
|
113
112
|
|
114
|
-
def
|
113
|
+
def make_user_message(
|
115
114
|
self,
|
116
115
|
chat_inputs: LLMPrompt | Sequence[str | ImageData] | None = None,
|
117
|
-
|
116
|
+
in_args: InT | None = None,
|
118
117
|
usr_args: LLMPromptArgs | None = None,
|
119
118
|
ctx: RunContext[CtxT] | None = None,
|
120
|
-
) ->
|
119
|
+
) -> UserMessage | None:
|
120
|
+
if chat_inputs is None and in_args is None and usr_args is None:
|
121
|
+
return None
|
122
|
+
|
121
123
|
if chat_inputs:
|
122
124
|
if isinstance(chat_inputs, LLMPrompt):
|
123
|
-
return
|
124
|
-
return
|
125
|
+
return UserMessage.from_text(chat_inputs, name=self._agent_name)
|
126
|
+
return UserMessage.from_content_parts(chat_inputs, name=self._agent_name)
|
125
127
|
|
126
|
-
|
127
|
-
|
128
|
-
|
129
|
-
)
|
130
|
-
for i, in_args in enumerate(in_args_batch or [None])
|
131
|
-
]
|
132
|
-
return [
|
133
|
-
UserMessage(content=in_content, name=self._agent_name)
|
134
|
-
for in_content in in_content_batch
|
135
|
-
]
|
136
|
-
|
137
|
-
def _usr_messages_from_text(self, text: str) -> list[UserMessage]:
|
138
|
-
return [UserMessage.from_text(text, name=self._agent_name)]
|
128
|
+
in_content = self.make_input_content(
|
129
|
+
in_args=in_args, usr_args=usr_args, ctx=ctx
|
130
|
+
)
|
139
131
|
|
140
|
-
|
141
|
-
self, content_parts: Sequence[str | ImageData]
|
142
|
-
) -> list[UserMessage]:
|
143
|
-
return [UserMessage.from_content_parts(content_parts, name=self._agent_name)]
|
132
|
+
return UserMessage(content=in_content, name=self._agent_name)
|
144
133
|
|
145
134
|
def _validate_prompt_args(
|
146
135
|
self,
|
147
136
|
*,
|
148
|
-
in_args:
|
137
|
+
in_args: InT | None,
|
149
138
|
usr_args: LLMPromptArgs | None,
|
150
|
-
) -> tuple[
|
139
|
+
) -> tuple[InT | None, LLMPromptArgs | None]:
|
151
140
|
val_usr_args = usr_args
|
152
141
|
if usr_args is not None:
|
153
142
|
if self.in_prompt_template is None:
|
154
|
-
raise
|
143
|
+
raise InputPromptBuilderError(
|
155
144
|
"Input prompt template is not set, but user arguments are provided."
|
156
145
|
)
|
157
146
|
if self.usr_args_schema is None:
|
158
|
-
raise
|
147
|
+
raise InputPromptBuilderError(
|
159
148
|
"User arguments schema is not provided, but user arguments are "
|
160
149
|
"given."
|
161
150
|
)
|
@@ -167,12 +156,12 @@ class PromptBuilder(AutoInstanceAttributesMixin, Generic[InT_contra, CtxT]):
|
|
167
156
|
if isinstance(val_in_args, BaseModel):
|
168
157
|
has_image = self._has_image_data(val_in_args)
|
169
158
|
if has_image and self.in_prompt_template is None:
|
170
|
-
raise
|
159
|
+
raise InputPromptBuilderError(
|
171
160
|
"BaseModel input arguments contain ImageData, but input prompt "
|
172
161
|
"template is not set. Cannot format input arguments."
|
173
162
|
)
|
174
163
|
elif self.in_prompt_template is not None:
|
175
|
-
raise
|
164
|
+
raise InputPromptBuilderError(
|
176
165
|
"Cannot use the input prompt template with "
|
177
166
|
"non-BaseModel input arguments."
|
178
167
|
)
|
@@ -189,9 +178,7 @@ class PromptBuilder(AutoInstanceAttributesMixin, Generic[InT_contra, CtxT]):
|
|
189
178
|
return contains_image_data
|
190
179
|
|
191
180
|
@staticmethod
|
192
|
-
def _format_pydantic_prompt_args(
|
193
|
-
inp: BaseModel,
|
194
|
-
) -> dict[str, PromptArgumentType]:
|
181
|
+
def _format_pydantic_prompt_args(inp: BaseModel) -> dict[str, PromptArgumentType]:
|
195
182
|
formatted_args: dict[str, PromptArgumentType] = {}
|
196
183
|
for field in type(inp).model_fields:
|
197
184
|
if field == "selected_recipients":
|
@@ -200,18 +187,17 @@ class PromptBuilder(AutoInstanceAttributesMixin, Generic[InT_contra, CtxT]):
|
|
200
187
|
val = getattr(inp, field)
|
201
188
|
if isinstance(val, (int, str, bool, ImageData)):
|
202
189
|
formatted_args[field] = val
|
203
|
-
elif isinstance(val, BaseModel):
|
204
|
-
formatted_args[field] = val.model_dump_json(indent=2, warnings="error")
|
205
190
|
else:
|
206
|
-
|
207
|
-
|
208
|
-
|
191
|
+
formatted_args[field] = (
|
192
|
+
TypeAdapter(type(val)) # type: ignore[return-value]
|
193
|
+
.dump_json(val, indent=2, warnings="error")
|
194
|
+
.decode("utf-8")
|
209
195
|
)
|
210
196
|
|
211
197
|
return formatted_args
|
212
198
|
|
213
199
|
def _combine_args(
|
214
|
-
self, *, in_args:
|
200
|
+
self, *, in_args: InT | None, usr_args: LLMPromptArgs | None
|
215
201
|
) -> Mapping[str, PromptArgumentType] | str:
|
216
202
|
fmt_usr_args = self._format_pydantic_prompt_args(usr_args) if usr_args else {}
|
217
203
|
|
grasp_agents/run_context.py
CHANGED
@@ -1,7 +1,6 @@
|
|
1
1
|
from collections import defaultdict
|
2
2
|
from collections.abc import Mapping
|
3
3
|
from typing import Any, Generic, TypeVar
|
4
|
-
from uuid import uuid4
|
5
4
|
|
6
5
|
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
|
7
6
|
|
@@ -23,8 +22,6 @@ CtxT = TypeVar("CtxT")
|
|
23
22
|
|
24
23
|
|
25
24
|
class RunContext(BaseModel, Generic[CtxT]):
|
26
|
-
run_id: str = Field(default_factory=lambda: str(uuid4())[:8], frozen=True)
|
27
|
-
|
28
25
|
state: CtxT | None = None
|
29
26
|
|
30
27
|
run_args: dict[ProcName, RunArgs] = Field(default_factory=dict)
|
@@ -39,9 +36,8 @@ class RunContext(BaseModel, Generic[CtxT]):
|
|
39
36
|
_printer: Printer = PrivateAttr()
|
40
37
|
|
41
38
|
def model_post_init(self, context: Any) -> None: # noqa: ARG002
|
42
|
-
self._usage_tracker = UsageTracker(
|
39
|
+
self._usage_tracker = UsageTracker()
|
43
40
|
self._printer = Printer(
|
44
|
-
source_id=self.run_id,
|
45
41
|
print_messages=self.print_messages,
|
46
42
|
color_by=self.color_messages_by,
|
47
43
|
)
|
@@ -11,6 +11,7 @@ from openai.types.chat.chat_completion_token_logprob import (
|
|
11
11
|
)
|
12
12
|
from pydantic import BaseModel, Field
|
13
13
|
|
14
|
+
from ..errors import CombineCompletionChunksError
|
14
15
|
from .completion import (
|
15
16
|
Completion,
|
16
17
|
CompletionChoice,
|
@@ -54,21 +55,25 @@ class CompletionChunk(BaseModel):
|
|
54
55
|
|
55
56
|
def combine_completion_chunks(chunks: list[CompletionChunk]) -> Completion:
|
56
57
|
if not chunks:
|
57
|
-
raise
|
58
|
+
raise CombineCompletionChunksError(
|
59
|
+
"Cannot combine an empty list of completion chunks."
|
60
|
+
)
|
58
61
|
|
59
62
|
model_list = {chunk.model for chunk in chunks}
|
60
63
|
if len(model_list) > 1:
|
61
|
-
raise
|
64
|
+
raise CombineCompletionChunksError("All chunks must have the same model.")
|
62
65
|
model = model_list.pop()
|
63
66
|
|
64
67
|
name_list = {chunk.name for chunk in chunks}
|
65
68
|
if len(name_list) > 1:
|
66
|
-
raise
|
69
|
+
raise CombineCompletionChunksError("All chunks must have the same name.")
|
67
70
|
name = name_list.pop()
|
68
71
|
|
69
72
|
system_fingerprints_list = {chunk.system_fingerprint for chunk in chunks}
|
70
73
|
if len(system_fingerprints_list) > 1:
|
71
|
-
raise
|
74
|
+
raise CombineCompletionChunksError(
|
75
|
+
"All chunks must have the same system fingerprint."
|
76
|
+
)
|
72
77
|
system_fingerprint = system_fingerprints_list.pop()
|
73
78
|
|
74
79
|
created_list = [chunk.created for chunk in chunks]
|
@@ -128,7 +133,7 @@ def combine_completion_chunks(chunks: list[CompletionChunk]) -> Completion:
|
|
128
133
|
or _tool_call.tool_name is None
|
129
134
|
or _tool_call.tool_arguments is None
|
130
135
|
):
|
131
|
-
raise
|
136
|
+
raise CombineCompletionChunksError(
|
132
137
|
"Completion chunk tool calls must have id, tool_name, "
|
133
138
|
"and tool_arguments set."
|
134
139
|
)
|
grasp_agents/typing/content.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
import base64
|
2
2
|
import re
|
3
|
-
from collections.abc import Iterable
|
3
|
+
from collections.abc import Iterable
|
4
4
|
from enum import StrEnum
|
5
5
|
from pathlib import Path
|
6
6
|
from typing import Annotated, Any, Literal, TypeAlias
|
@@ -66,7 +66,7 @@ class Content(BaseModel):
|
|
66
66
|
def from_formatted_prompt(
|
67
67
|
cls,
|
68
68
|
prompt_template: str,
|
69
|
-
prompt_args:
|
69
|
+
**prompt_args: str | int | bool | ImageData | None,
|
70
70
|
) -> "Content":
|
71
71
|
prompt_args = prompt_args or {}
|
72
72
|
image_args = {
|
grasp_agents/typing/io.py
CHANGED
@@ -5,12 +5,12 @@ from pydantic import BaseModel
|
|
5
5
|
ProcName: TypeAlias = str
|
6
6
|
|
7
7
|
|
8
|
+
InT = TypeVar("InT")
|
9
|
+
OutT_co = TypeVar("OutT_co", covariant=True)
|
10
|
+
|
11
|
+
|
8
12
|
class LLMPromptArgs(BaseModel):
|
9
13
|
pass
|
10
14
|
|
11
15
|
|
12
|
-
InT_contra = TypeVar("InT_contra", contravariant=True)
|
13
|
-
OutT_co = TypeVar("OutT_co", covariant=True)
|
14
|
-
MemT_co = TypeVar("MemT_co", covariant=True)
|
15
|
-
|
16
16
|
LLMPrompt: TypeAlias = str
|