grasp_agents 0.1.18__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- grasp_agents/agent_message.py +2 -2
- grasp_agents/agent_message_pool.py +6 -8
- grasp_agents/base_agent.py +15 -36
- grasp_agents/cloud_llm.py +10 -6
- grasp_agents/comm_agent.py +39 -43
- grasp_agents/generics_utils.py +159 -0
- grasp_agents/llm.py +4 -0
- grasp_agents/llm_agent.py +126 -46
- grasp_agents/llm_agent_state.py +18 -12
- grasp_agents/prompt_builder.py +55 -28
- grasp_agents/rate_limiting/rate_limiter_chunked.py +49 -48
- grasp_agents/rate_limiting/types.py +19 -40
- grasp_agents/rate_limiting/utils.py +24 -27
- grasp_agents/run_context.py +2 -15
- grasp_agents/tool_orchestrator.py +34 -12
- grasp_agents/typing/content.py +2 -2
- grasp_agents/typing/converters.py +3 -1
- grasp_agents/typing/io.py +7 -11
- grasp_agents/typing/message.py +2 -2
- grasp_agents/typing/tool.py +26 -14
- grasp_agents/utils.py +90 -96
- grasp_agents/workflow/looped_agent.py +12 -9
- grasp_agents/workflow/sequential_agent.py +9 -6
- grasp_agents/workflow/workflow_agent.py +16 -11
- {grasp_agents-0.1.18.dist-info → grasp_agents-0.2.1.dist-info}/METADATA +37 -33
- grasp_agents-0.2.1.dist-info/RECORD +45 -0
- {grasp_agents-0.1.18.dist-info → grasp_agents-0.2.1.dist-info}/licenses/LICENSE.md +1 -1
- grasp_agents-0.1.18.dist-info/RECORD +0 -44
- {grasp_agents-0.1.18.dist-info → grasp_agents-0.2.1.dist-info}/WHEEL +0 -0
grasp_agents/llm_agent.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
from collections.abc import Sequence
|
2
2
|
from pathlib import Path
|
3
|
-
from typing import Any, Generic, cast, final
|
3
|
+
from typing import Any, ClassVar, Generic, Protocol, cast, final
|
4
4
|
|
5
5
|
from pydantic import BaseModel
|
6
6
|
|
@@ -10,7 +10,7 @@ from .comm_agent import CommunicatingAgent
|
|
10
10
|
from .llm import LLM, LLMSettings
|
11
11
|
from .llm_agent_state import (
|
12
12
|
LLMAgentState,
|
13
|
-
|
13
|
+
SetAgentState,
|
14
14
|
SetAgentStateStrategy,
|
15
15
|
)
|
16
16
|
from .prompt_builder import (
|
@@ -25,12 +25,15 @@ from .run_context import (
|
|
25
25
|
SystemRunArgs,
|
26
26
|
UserRunArgs,
|
27
27
|
)
|
28
|
-
from .tool_orchestrator import
|
28
|
+
from .tool_orchestrator import (
|
29
|
+
ExitToolCallLoopHandler,
|
30
|
+
ManageAgentStateHandler,
|
31
|
+
ToolOrchestrator,
|
32
|
+
)
|
29
33
|
from .typing.content import ImageData
|
30
34
|
from .typing.converters import Converters
|
31
35
|
from .typing.io import (
|
32
36
|
AgentID,
|
33
|
-
AgentPayload,
|
34
37
|
AgentState,
|
35
38
|
InT,
|
36
39
|
LLMFormattedArgs,
|
@@ -41,13 +44,29 @@ from .typing.io import (
|
|
41
44
|
)
|
42
45
|
from .typing.message import Conversation, Message, SystemMessage
|
43
46
|
from .typing.tool import BaseTool
|
44
|
-
from .utils import get_prompt
|
47
|
+
from .utils import get_prompt, validate_obj_from_json_or_py_string
|
48
|
+
|
49
|
+
|
50
|
+
class ParseOutputHandler(Protocol[InT, OutT, CtxT]):
|
51
|
+
def __call__(
|
52
|
+
self,
|
53
|
+
conversation: Conversation,
|
54
|
+
*,
|
55
|
+
rcv_args: InT | None,
|
56
|
+
batch_idx: int,
|
57
|
+
ctx: RunContextWrapper[CtxT] | None,
|
58
|
+
) -> OutT: ...
|
45
59
|
|
46
60
|
|
47
61
|
class LLMAgent(
|
48
62
|
CommunicatingAgent[InT, OutT, LLMAgentState, CtxT],
|
49
63
|
Generic[InT, OutT, CtxT],
|
50
64
|
):
|
65
|
+
_generic_arg_to_instance_attr_map: ClassVar[dict[int, str]] = {
|
66
|
+
0: "_in_type",
|
67
|
+
1: "_out_type",
|
68
|
+
}
|
69
|
+
|
51
70
|
def __init__(
|
52
71
|
self,
|
53
72
|
agent_id: AgentID,
|
@@ -64,10 +83,6 @@ class LLMAgent(
|
|
64
83
|
sys_args_schema: type[LLMPromptArgs] = LLMPromptArgs,
|
65
84
|
# User args (static args provided via RunContextWrapper)
|
66
85
|
usr_args_schema: type[LLMPromptArgs] = LLMPromptArgs,
|
67
|
-
# Received args (args from another agent)
|
68
|
-
rcv_args_schema: type[InT] = cast("type[InT]", AgentPayload),
|
69
|
-
# Output schema
|
70
|
-
out_schema: type[OutT] = cast("type[OutT]", AgentPayload),
|
71
86
|
# Tools
|
72
87
|
tools: list[BaseTool[Any, Any, CtxT]] | None = None,
|
73
88
|
max_turns: int = 1000,
|
@@ -79,19 +94,21 @@ class LLMAgent(
|
|
79
94
|
recipient_ids: list[AgentID] | None = None,
|
80
95
|
) -> None:
|
81
96
|
super().__init__(
|
82
|
-
agent_id=agent_id,
|
83
|
-
out_schema=out_schema,
|
84
|
-
rcv_args_schema=rcv_args_schema,
|
85
|
-
message_pool=message_pool,
|
86
|
-
recipient_ids=recipient_ids,
|
97
|
+
agent_id=agent_id, message_pool=message_pool, recipient_ids=recipient_ids
|
87
98
|
)
|
88
99
|
|
89
100
|
# Agent state
|
90
101
|
self._state: LLMAgentState = LLMAgentState()
|
91
102
|
self.set_state_strategy: SetAgentStateStrategy = set_state_strategy
|
92
|
-
self.
|
103
|
+
self._set_agent_state_impl: SetAgentState | None = None
|
93
104
|
|
94
105
|
# Tool orchestrator
|
106
|
+
|
107
|
+
self._using_default_llm_response_format: bool = False
|
108
|
+
if llm.response_format is None and tools is None:
|
109
|
+
llm.response_format = self.out_type
|
110
|
+
self._using_default_llm_response_format = True
|
111
|
+
|
95
112
|
self._tool_orchestrator: ToolOrchestrator[CtxT] = ToolOrchestrator[CtxT](
|
96
113
|
agent_id=self.agent_id,
|
97
114
|
llm=llm,
|
@@ -103,28 +120,19 @@ class LLMAgent(
|
|
103
120
|
# Prompt builder
|
104
121
|
sys_prompt = get_prompt(prompt_text=sys_prompt, prompt_path=sys_prompt_path)
|
105
122
|
inp_prompt = get_prompt(prompt_text=inp_prompt, prompt_path=inp_prompt_path)
|
106
|
-
self._prompt_builder: PromptBuilder[InT, CtxT] = PromptBuilder[
|
123
|
+
self._prompt_builder: PromptBuilder[InT, CtxT] = PromptBuilder[
|
124
|
+
self.in_type, CtxT
|
125
|
+
](
|
107
126
|
agent_id=self._agent_id,
|
108
127
|
sys_prompt=sys_prompt,
|
109
128
|
inp_prompt=inp_prompt,
|
110
129
|
sys_args_schema=sys_args_schema,
|
111
130
|
usr_args_schema=usr_args_schema,
|
112
|
-
rcv_args_schema=rcv_args_schema,
|
113
131
|
)
|
114
132
|
|
115
133
|
self.no_tqdm = getattr(llm, "no_tqdm", False)
|
116
134
|
|
117
|
-
|
118
|
-
self._prompt_builder.format_sys_args_impl = self._format_sys_args
|
119
|
-
|
120
|
-
if type(self)._format_inp_args is not LLMAgent[Any, Any, Any]._format_inp_args: # noqa: SLF001
|
121
|
-
self._prompt_builder.format_inp_args_impl = self._format_inp_args
|
122
|
-
|
123
|
-
if (
|
124
|
-
type(self)._tool_call_loop_exit # noqa: SLF001
|
125
|
-
is not LLMAgent[Any, Any, Any]._tool_call_loop_exit # noqa: SLF001
|
126
|
-
):
|
127
|
-
self._tool_orchestrator.tool_call_loop_exit_impl = self._tool_call_loop_exit
|
135
|
+
self._register_overridden_handlers()
|
128
136
|
|
129
137
|
@property
|
130
138
|
def llm(self) -> LLM[LLMSettings, Converters]:
|
@@ -159,17 +167,29 @@ class LLMAgent(
|
|
159
167
|
conversation: Conversation,
|
160
168
|
*,
|
161
169
|
rcv_args: InT | None = None,
|
170
|
+
batch_idx: int = 0,
|
162
171
|
ctx: RunContextWrapper[CtxT] | None = None,
|
163
|
-
**kwargs: Any,
|
164
172
|
) -> OutT:
|
165
173
|
if self._parse_output_impl:
|
174
|
+
if self._using_default_llm_response_format:
|
175
|
+
# When using custom output parsing, the required LLM response format
|
176
|
+
# can differ from the final agent output type ->
|
177
|
+
# set it back to None unless it was specified explicitly at init.
|
178
|
+
self._tool_orchestrator.llm.response_format = None
|
179
|
+
self._using_default_llm_response_format = False
|
180
|
+
|
166
181
|
return self._parse_output_impl(
|
167
|
-
conversation=conversation,
|
182
|
+
conversation=conversation,
|
183
|
+
rcv_args=rcv_args,
|
184
|
+
batch_idx=batch_idx,
|
185
|
+
ctx=ctx,
|
168
186
|
)
|
169
|
-
|
170
|
-
|
171
|
-
|
172
|
-
|
187
|
+
|
188
|
+
return validate_obj_from_json_or_py_string(
|
189
|
+
str(conversation[-1].content),
|
190
|
+
adapter=self._out_type_adapter,
|
191
|
+
from_substring=True,
|
192
|
+
)
|
173
193
|
|
174
194
|
@final
|
175
195
|
async def run(
|
@@ -177,7 +197,7 @@ class LLMAgent(
|
|
177
197
|
inp_items: LLMPrompt | list[str | ImageData] | None = None,
|
178
198
|
*,
|
179
199
|
ctx: RunContextWrapper[CtxT] | None = None,
|
180
|
-
rcv_message: AgentMessage[InT,
|
200
|
+
rcv_message: AgentMessage[InT, AgentState] | None = None,
|
181
201
|
entry_point: bool = False,
|
182
202
|
forbid_state_change: bool = False,
|
183
203
|
**gen_kwargs: Any, # noqa: ARG002
|
@@ -217,7 +237,7 @@ class LLMAgent(
|
|
217
237
|
rcv_state=rcv_state,
|
218
238
|
sys_prompt=formatted_sys_prompt,
|
219
239
|
strategy=self.set_state_strategy,
|
220
|
-
|
240
|
+
set_agent_state_impl=self._set_agent_state_impl,
|
221
241
|
ctx=ctx,
|
222
242
|
)
|
223
243
|
|
@@ -241,13 +261,13 @@ class LLMAgent(
|
|
241
261
|
else:
|
242
262
|
# 4. Run tool call loop (new messages are added to the message
|
243
263
|
# history inside the loop)
|
244
|
-
await self._tool_orchestrator.run_loop(
|
264
|
+
await self._tool_orchestrator.run_loop(state=state, ctx=ctx)
|
245
265
|
|
246
266
|
# 5. Parse outputs
|
247
267
|
batch_size = state.message_history.batch_size
|
248
268
|
rcv_args_batch = rcv_message.payloads if rcv_message else batch_size * [None]
|
249
269
|
val_output_batch = [
|
250
|
-
self.
|
270
|
+
self._out_type_adapter.validate_python(
|
251
271
|
self._parse_output(conversation=conv, rcv_args=rcv_args, ctx=ctx)
|
252
272
|
)
|
253
273
|
for conv, rcv_args in zip(
|
@@ -276,7 +296,7 @@ class LLMAgent(
|
|
276
296
|
)
|
277
297
|
ctx.interaction_history.append(
|
278
298
|
cast(
|
279
|
-
"InteractionRecord[
|
299
|
+
"InteractionRecord[Any, Any, AgentState]",
|
280
300
|
interaction_record,
|
281
301
|
)
|
282
302
|
)
|
@@ -330,22 +350,57 @@ class LLMAgent(
|
|
330
350
|
|
331
351
|
return func
|
332
352
|
|
333
|
-
def
|
334
|
-
self, func:
|
335
|
-
) ->
|
353
|
+
def parse_output_handler(
|
354
|
+
self, func: ParseOutputHandler[InT, OutT, CtxT]
|
355
|
+
) -> ParseOutputHandler[InT, OutT, CtxT]:
|
356
|
+
self._parse_output_impl = func
|
357
|
+
|
358
|
+
return func
|
359
|
+
|
360
|
+
def set_agent_state_handler(self, func: SetAgentState) -> SetAgentState:
|
336
361
|
self._make_custom_agent_state_impl = func
|
337
362
|
|
338
363
|
return func
|
339
364
|
|
340
|
-
def
|
341
|
-
self, func:
|
342
|
-
) ->
|
343
|
-
self._tool_orchestrator.
|
365
|
+
def exit_tool_call_loop_handler(
|
366
|
+
self, func: ExitToolCallLoopHandler[CtxT]
|
367
|
+
) -> ExitToolCallLoopHandler[CtxT]:
|
368
|
+
self._tool_orchestrator.exit_tool_call_loop_impl = func
|
369
|
+
|
370
|
+
return func
|
371
|
+
|
372
|
+
def manage_agent_state_handler(
|
373
|
+
self, func: ManageAgentStateHandler[CtxT]
|
374
|
+
) -> ManageAgentStateHandler[CtxT]:
|
375
|
+
self._tool_orchestrator.manage_agent_state_impl = func
|
344
376
|
|
345
377
|
return func
|
346
378
|
|
347
379
|
# -- Override these methods in subclasses if needed --
|
348
380
|
|
381
|
+
def _register_overridden_handlers(self) -> None:
|
382
|
+
cur_cls = type(self)
|
383
|
+
base_cls = LLMAgent[Any, Any, Any]
|
384
|
+
|
385
|
+
if cur_cls._format_sys_args is not base_cls._format_sys_args: # noqa: SLF001
|
386
|
+
self._prompt_builder.format_sys_args_impl = self._format_sys_args
|
387
|
+
|
388
|
+
if cur_cls._format_inp_args is not base_cls._format_inp_args: # noqa: SLF001
|
389
|
+
self._prompt_builder.format_inp_args_impl = self._format_inp_args
|
390
|
+
|
391
|
+
if cur_cls._set_agent_state is not base_cls._set_agent_state: # noqa: SLF001
|
392
|
+
self._set_agent_state_impl = self._set_agent_state
|
393
|
+
|
394
|
+
if cur_cls._manage_agent_state is not base_cls._manage_agent_state: # noqa: SLF001
|
395
|
+
self._tool_orchestrator.manage_agent_state_impl = self._manage_agent_state
|
396
|
+
|
397
|
+
if (
|
398
|
+
cur_cls._tool_call_loop_exit is not base_cls._tool_call_loop_exit # noqa: SLF001
|
399
|
+
):
|
400
|
+
self._tool_orchestrator.exit_tool_call_loop_impl = self._tool_call_loop_exit
|
401
|
+
|
402
|
+
self._parse_output_impl: ParseOutputHandler[InT, OutT, CtxT] | None = None
|
403
|
+
|
349
404
|
def _format_sys_args(
|
350
405
|
self,
|
351
406
|
sys_args: LLMPromptArgs,
|
@@ -361,12 +416,26 @@ class LLMAgent(
|
|
361
416
|
self,
|
362
417
|
usr_args: LLMPromptArgs,
|
363
418
|
rcv_args: InT,
|
419
|
+
*,
|
420
|
+
batch_idx: int = 0,
|
364
421
|
ctx: RunContextWrapper[CtxT] | None = None,
|
365
422
|
) -> LLMFormattedArgs:
|
366
423
|
raise NotImplementedError(
|
367
424
|
"LLMAgent._format_inp_args must be overridden by a subclass"
|
368
425
|
)
|
369
426
|
|
427
|
+
def _set_agent_state(
|
428
|
+
self,
|
429
|
+
cur_state: LLMAgentState,
|
430
|
+
*,
|
431
|
+
rcv_state: AgentState | None,
|
432
|
+
sys_prompt: LLMPrompt | None,
|
433
|
+
ctx: RunContextWrapper[Any] | None,
|
434
|
+
) -> LLMAgentState:
|
435
|
+
raise NotImplementedError(
|
436
|
+
"LLMAgent._set_agent_state_handler must be overridden by a subclass"
|
437
|
+
)
|
438
|
+
|
370
439
|
def _tool_call_loop_exit(
|
371
440
|
self,
|
372
441
|
conversation: Conversation,
|
@@ -377,3 +446,14 @@ class LLMAgent(
|
|
377
446
|
raise NotImplementedError(
|
378
447
|
"LLMAgent._tool_call_loop_exit must be overridden by a subclass"
|
379
448
|
)
|
449
|
+
|
450
|
+
def _manage_agent_state(
|
451
|
+
self,
|
452
|
+
state: LLMAgentState,
|
453
|
+
*,
|
454
|
+
ctx: RunContextWrapper[CtxT] | None = None,
|
455
|
+
**kwargs: Any,
|
456
|
+
) -> None:
|
457
|
+
raise NotImplementedError(
|
458
|
+
"LLMAgent._manage_agent_state must be overridden by a subclass"
|
459
|
+
)
|
grasp_agents/llm_agent_state.py
CHANGED
@@ -10,11 +10,12 @@ from .typing.io import AgentState, LLMPrompt
|
|
10
10
|
SetAgentStateStrategy = Literal["keep", "reset", "from_sender", "custom"]
|
11
11
|
|
12
12
|
|
13
|
-
class
|
13
|
+
class SetAgentState(Protocol):
|
14
14
|
def __call__(
|
15
15
|
self,
|
16
|
-
cur_state:
|
17
|
-
|
16
|
+
cur_state: "LLMAgentState",
|
17
|
+
*,
|
18
|
+
rcv_state: AgentState | None,
|
18
19
|
sys_prompt: LLMPrompt | None,
|
19
20
|
ctx: RunContextWrapper[Any] | None,
|
20
21
|
) -> "LLMAgentState": ...
|
@@ -30,11 +31,12 @@ class LLMAgentState(AgentState):
|
|
30
31
|
@classmethod
|
31
32
|
def from_cur_and_rcv_states(
|
32
33
|
cls,
|
33
|
-
cur_state:
|
34
|
-
|
34
|
+
cur_state: "LLMAgentState",
|
35
|
+
*,
|
36
|
+
rcv_state: Optional["AgentState"] = None,
|
35
37
|
sys_prompt: LLMPrompt | None = None,
|
36
38
|
strategy: SetAgentStateStrategy = "from_sender",
|
37
|
-
|
39
|
+
set_agent_state_impl: SetAgentState | None = None,
|
38
40
|
ctx: RunContextWrapper[Any] | None = None,
|
39
41
|
) -> "LLMAgentState":
|
40
42
|
upd_mh = cur_state.message_history if cur_state else None
|
@@ -48,24 +50,28 @@ class LLMAgentState(AgentState):
|
|
48
50
|
upd_mh.reset(sys_prompt)
|
49
51
|
|
50
52
|
elif strategy == "from_sender":
|
51
|
-
rcv_mh =
|
53
|
+
rcv_mh = (
|
54
|
+
rcv_state.message_history
|
55
|
+
if rcv_state and isinstance(rcv_state, "LLMAgentState")
|
56
|
+
else None
|
57
|
+
)
|
52
58
|
if rcv_mh:
|
53
59
|
upd_mh = deepcopy(rcv_mh)
|
54
60
|
else:
|
55
61
|
upd_mh.reset(sys_prompt)
|
56
62
|
|
57
63
|
elif strategy == "custom":
|
58
|
-
assert
|
59
|
-
"
|
64
|
+
assert set_agent_state_impl is not None, (
|
65
|
+
"Agent state setter implementation is not provided."
|
60
66
|
)
|
61
|
-
return
|
67
|
+
return set_agent_state_impl(
|
62
68
|
cur_state=cur_state,
|
63
|
-
|
69
|
+
rcv_state=rcv_state,
|
64
70
|
sys_prompt=sys_prompt,
|
65
71
|
ctx=ctx,
|
66
72
|
)
|
67
73
|
|
68
|
-
return cls(message_history=upd_mh)
|
74
|
+
return cls.model_construct(message_history=upd_mh)
|
69
75
|
|
70
76
|
def __repr__(self) -> str:
|
71
77
|
return f"Message History: {len(self.message_history)}"
|
grasp_agents/prompt_builder.py
CHANGED
@@ -1,11 +1,13 @@
|
|
1
1
|
from collections.abc import Sequence
|
2
2
|
from copy import deepcopy
|
3
|
-
from typing import Generic, Protocol
|
3
|
+
from typing import ClassVar, Generic, Protocol
|
4
4
|
|
5
|
+
from pydantic import BaseModel, TypeAdapter
|
6
|
+
|
7
|
+
from .generics_utils import AutoInstanceAttributesMixin
|
5
8
|
from .run_context import CtxT, RunContextWrapper, UserRunArgs
|
6
9
|
from .typing.content import ImageData
|
7
10
|
from .typing.io import (
|
8
|
-
AgentPayload,
|
9
11
|
InT,
|
10
12
|
LLMFormattedArgs,
|
11
13
|
LLMFormattedSystemArgs,
|
@@ -15,6 +17,10 @@ from .typing.io import (
|
|
15
17
|
from .typing.message import UserMessage
|
16
18
|
|
17
19
|
|
20
|
+
class DummySchema(BaseModel):
|
21
|
+
pass
|
22
|
+
|
23
|
+
|
18
24
|
class FormatSystemArgsHandler(Protocol[CtxT]):
|
19
25
|
def __call__(
|
20
26
|
self,
|
@@ -29,11 +35,15 @@ class FormatInputArgsHandler(Protocol[InT, CtxT]):
|
|
29
35
|
self,
|
30
36
|
usr_args: LLMPromptArgs,
|
31
37
|
rcv_args: InT,
|
38
|
+
*,
|
39
|
+
batch_idx: int,
|
32
40
|
ctx: RunContextWrapper[CtxT] | None,
|
33
41
|
) -> LLMFormattedArgs: ...
|
34
42
|
|
35
43
|
|
36
|
-
class PromptBuilder(Generic[InT, CtxT]):
|
44
|
+
class PromptBuilder(AutoInstanceAttributesMixin, Generic[InT, CtxT]):
|
45
|
+
_generic_arg_to_instance_attr_map: ClassVar[dict[int, str]] = {0: "_in_type"}
|
46
|
+
|
37
47
|
def __init__(
|
38
48
|
self,
|
39
49
|
agent_id: str,
|
@@ -41,17 +51,20 @@ class PromptBuilder(Generic[InT, CtxT]):
|
|
41
51
|
inp_prompt: LLMPrompt | None,
|
42
52
|
sys_args_schema: type[LLMPromptArgs],
|
43
53
|
usr_args_schema: type[LLMPromptArgs],
|
44
|
-
rcv_args_schema: type[InT],
|
45
54
|
):
|
55
|
+
self._in_type: type[InT]
|
56
|
+
super().__init__()
|
57
|
+
|
46
58
|
self._agent_id = agent_id
|
47
59
|
self.sys_prompt = sys_prompt
|
48
60
|
self.inp_prompt = inp_prompt
|
49
61
|
self.sys_args_schema = sys_args_schema
|
50
62
|
self.usr_args_schema = usr_args_schema
|
51
|
-
self.rcv_args_schema = rcv_args_schema
|
52
63
|
self.format_sys_args_impl: FormatSystemArgsHandler[CtxT] | None = None
|
53
64
|
self.format_inp_args_impl: FormatInputArgsHandler[InT, CtxT] | None = None
|
54
65
|
|
66
|
+
self._rcv_args_type_adapter: TypeAdapter[InT] = TypeAdapter(self._in_type)
|
67
|
+
|
55
68
|
def _format_sys_args(
|
56
69
|
self,
|
57
70
|
sys_args: LLMPromptArgs,
|
@@ -66,20 +79,32 @@ class PromptBuilder(Generic[InT, CtxT]):
|
|
66
79
|
self,
|
67
80
|
usr_args: LLMPromptArgs,
|
68
81
|
rcv_args: InT,
|
82
|
+
*,
|
83
|
+
batch_idx: int = 0,
|
69
84
|
ctx: RunContextWrapper[CtxT] | None = None,
|
70
85
|
) -> LLMFormattedArgs:
|
71
86
|
if self.format_inp_args_impl:
|
72
87
|
return self.format_inp_args_impl(
|
73
|
-
usr_args=usr_args, rcv_args=rcv_args, ctx=ctx
|
88
|
+
usr_args=usr_args, rcv_args=rcv_args, batch_idx=batch_idx, ctx=ctx
|
74
89
|
)
|
75
90
|
|
76
|
-
|
77
|
-
|
78
|
-
|
91
|
+
if not isinstance(rcv_args, BaseModel) and rcv_args is not None:
|
92
|
+
raise TypeError(
|
93
|
+
"Cannot apply default formatting to non-BaseModel received arguments."
|
94
|
+
)
|
95
|
+
|
96
|
+
usr_args_ = usr_args
|
97
|
+
rcv_args_ = DummySchema() if rcv_args is None else rcv_args
|
98
|
+
|
99
|
+
usr_args_dump = usr_args_.model_dump(exclude_unset=True)
|
100
|
+
rcv_args_dump = rcv_args_.model_dump(exclude={"selected_recipient_ids"})
|
101
|
+
|
102
|
+
return usr_args_dump | rcv_args_dump
|
79
103
|
|
80
104
|
def make_sys_prompt(
|
81
105
|
self,
|
82
106
|
sys_args: LLMPromptArgs,
|
107
|
+
*,
|
83
108
|
ctx: RunContextWrapper[CtxT] | None,
|
84
109
|
) -> LLMPrompt | None:
|
85
110
|
if self.sys_prompt is None:
|
@@ -100,20 +125,18 @@ class PromptBuilder(Generic[InT, CtxT]):
|
|
100
125
|
def _usr_messages_from_rcv_args(
|
101
126
|
self, rcv_args_batch: Sequence[InT]
|
102
127
|
) -> list[UserMessage]:
|
103
|
-
val_rcv_args_batch = [
|
104
|
-
self.rcv_args_schema.model_validate(rcv) for rcv in rcv_args_batch
|
105
|
-
]
|
106
|
-
|
107
128
|
return [
|
108
129
|
UserMessage.from_text(
|
109
|
-
|
130
|
+
self._rcv_args_type_adapter.dump_json(
|
131
|
+
rcv,
|
110
132
|
exclude_unset=True,
|
111
133
|
indent=2,
|
112
134
|
exclude={"selected_recipient_ids"},
|
113
|
-
|
135
|
+
warnings="error",
|
136
|
+
).decode("utf-8"),
|
114
137
|
model_id=self._agent_id,
|
115
138
|
)
|
116
|
-
for rcv in
|
139
|
+
for rcv in rcv_args_batch
|
117
140
|
]
|
118
141
|
|
119
142
|
def _usr_messages_from_prompt_template(
|
@@ -123,17 +146,21 @@ class PromptBuilder(Generic[InT, CtxT]):
|
|
123
146
|
rcv_args_batch: Sequence[InT] | None = None,
|
124
147
|
ctx: RunContextWrapper[CtxT] | None = None,
|
125
148
|
) -> Sequence[UserMessage]:
|
126
|
-
|
127
|
-
|
128
|
-
|
149
|
+
usr_args_batch_, rcv_args_batch_ = self._make_batched(usr_args, rcv_args_batch)
|
150
|
+
|
151
|
+
val_usr_args_batch_ = [
|
152
|
+
self.usr_args_schema.model_validate(u) for u in usr_args_batch_
|
129
153
|
]
|
130
|
-
|
131
|
-
self.
|
154
|
+
val_rcv_args_batch_ = [
|
155
|
+
self._rcv_args_type_adapter.validate_python(rcv) for rcv in rcv_args_batch_
|
132
156
|
]
|
157
|
+
|
133
158
|
formatted_inp_args_batch = [
|
134
|
-
self._format_inp_args(
|
135
|
-
|
136
|
-
|
159
|
+
self._format_inp_args(
|
160
|
+
usr_args=val_usr_args, rcv_args=val_rcv_args, batch_idx=i, ctx=ctx
|
161
|
+
)
|
162
|
+
for i, (val_usr_args, val_rcv_args) in enumerate(
|
163
|
+
zip(val_usr_args_batch_, val_rcv_args_batch_, strict=False)
|
137
164
|
)
|
138
165
|
]
|
139
166
|
|
@@ -187,11 +214,11 @@ class PromptBuilder(Generic[InT, CtxT]):
|
|
187
214
|
self,
|
188
215
|
usr_args: UserRunArgs | None = None,
|
189
216
|
rcv_args_batch: Sequence[InT] | None = None,
|
190
|
-
) -> tuple[Sequence[LLMPromptArgs], Sequence[InT]]:
|
217
|
+
) -> tuple[Sequence[LLMPromptArgs | DummySchema], Sequence[InT | DummySchema]]:
|
191
218
|
usr_args_batch_ = (
|
192
|
-
usr_args if isinstance(usr_args, list) else [usr_args or
|
219
|
+
usr_args if isinstance(usr_args, list) else [usr_args or DummySchema()]
|
193
220
|
)
|
194
|
-
rcv_args_batch_ = rcv_args_batch or [
|
221
|
+
rcv_args_batch_ = rcv_args_batch or [DummySchema()]
|
195
222
|
|
196
223
|
# Broadcast singleton → match lengths
|
197
224
|
if len(usr_args_batch_) == 1 and len(rcv_args_batch_) > 1:
|
@@ -201,4 +228,4 @@ class PromptBuilder(Generic[InT, CtxT]):
|
|
201
228
|
if len(usr_args_batch_) != len(rcv_args_batch_):
|
202
229
|
raise ValueError("User args and received args must have the same length")
|
203
230
|
|
204
|
-
return usr_args_batch_, rcv_args_batch_
|
231
|
+
return usr_args_batch_, rcv_args_batch_
|