grasp_agents 0.1.18__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
grasp_agents/llm_agent.py CHANGED
@@ -1,6 +1,6 @@
1
1
  from collections.abc import Sequence
2
2
  from pathlib import Path
3
- from typing import Any, Generic, cast, final
3
+ from typing import Any, ClassVar, Generic, Protocol, cast, final
4
4
 
5
5
  from pydantic import BaseModel
6
6
 
@@ -10,7 +10,7 @@ from .comm_agent import CommunicatingAgent
10
10
  from .llm import LLM, LLMSettings
11
11
  from .llm_agent_state import (
12
12
  LLMAgentState,
13
- MakeCustomAgentState,
13
+ SetAgentState,
14
14
  SetAgentStateStrategy,
15
15
  )
16
16
  from .prompt_builder import (
@@ -25,12 +25,15 @@ from .run_context import (
25
25
  SystemRunArgs,
26
26
  UserRunArgs,
27
27
  )
28
- from .tool_orchestrator import ToolCallLoopExitHandler, ToolOrchestrator
28
+ from .tool_orchestrator import (
29
+ ExitToolCallLoopHandler,
30
+ ManageAgentStateHandler,
31
+ ToolOrchestrator,
32
+ )
29
33
  from .typing.content import ImageData
30
34
  from .typing.converters import Converters
31
35
  from .typing.io import (
32
36
  AgentID,
33
- AgentPayload,
34
37
  AgentState,
35
38
  InT,
36
39
  LLMFormattedArgs,
@@ -41,13 +44,29 @@ from .typing.io import (
41
44
  )
42
45
  from .typing.message import Conversation, Message, SystemMessage
43
46
  from .typing.tool import BaseTool
44
- from .utils import get_prompt
47
+ from .utils import get_prompt, validate_obj_from_json_or_py_string
48
+
49
+
50
+ class ParseOutputHandler(Protocol[InT, OutT, CtxT]):
51
+ def __call__(
52
+ self,
53
+ conversation: Conversation,
54
+ *,
55
+ rcv_args: InT | None,
56
+ batch_idx: int,
57
+ ctx: RunContextWrapper[CtxT] | None,
58
+ ) -> OutT: ...
45
59
 
46
60
 
47
61
  class LLMAgent(
48
62
  CommunicatingAgent[InT, OutT, LLMAgentState, CtxT],
49
63
  Generic[InT, OutT, CtxT],
50
64
  ):
65
+ _generic_arg_to_instance_attr_map: ClassVar[dict[int, str]] = {
66
+ 0: "_in_type",
67
+ 1: "_out_type",
68
+ }
69
+
51
70
  def __init__(
52
71
  self,
53
72
  agent_id: AgentID,
@@ -64,10 +83,6 @@ class LLMAgent(
64
83
  sys_args_schema: type[LLMPromptArgs] = LLMPromptArgs,
65
84
  # User args (static args provided via RunContextWrapper)
66
85
  usr_args_schema: type[LLMPromptArgs] = LLMPromptArgs,
67
- # Received args (args from another agent)
68
- rcv_args_schema: type[InT] = cast("type[InT]", AgentPayload),
69
- # Output schema
70
- out_schema: type[OutT] = cast("type[OutT]", AgentPayload),
71
86
  # Tools
72
87
  tools: list[BaseTool[Any, Any, CtxT]] | None = None,
73
88
  max_turns: int = 1000,
@@ -79,19 +94,21 @@ class LLMAgent(
79
94
  recipient_ids: list[AgentID] | None = None,
80
95
  ) -> None:
81
96
  super().__init__(
82
- agent_id=agent_id,
83
- out_schema=out_schema,
84
- rcv_args_schema=rcv_args_schema,
85
- message_pool=message_pool,
86
- recipient_ids=recipient_ids,
97
+ agent_id=agent_id, message_pool=message_pool, recipient_ids=recipient_ids
87
98
  )
88
99
 
89
100
  # Agent state
90
101
  self._state: LLMAgentState = LLMAgentState()
91
102
  self.set_state_strategy: SetAgentStateStrategy = set_state_strategy
92
- self._make_custom_agent_state_impl: MakeCustomAgentState | None = None
103
+ self._set_agent_state_impl: SetAgentState | None = None
93
104
 
94
105
  # Tool orchestrator
106
+
107
+ self._using_default_llm_response_format: bool = False
108
+ if llm.response_format is None and tools is None:
109
+ llm.response_format = self.out_type
110
+ self._using_default_llm_response_format = True
111
+
95
112
  self._tool_orchestrator: ToolOrchestrator[CtxT] = ToolOrchestrator[CtxT](
96
113
  agent_id=self.agent_id,
97
114
  llm=llm,
@@ -103,28 +120,19 @@ class LLMAgent(
103
120
  # Prompt builder
104
121
  sys_prompt = get_prompt(prompt_text=sys_prompt, prompt_path=sys_prompt_path)
105
122
  inp_prompt = get_prompt(prompt_text=inp_prompt, prompt_path=inp_prompt_path)
106
- self._prompt_builder: PromptBuilder[InT, CtxT] = PromptBuilder[InT, CtxT](
123
+ self._prompt_builder: PromptBuilder[InT, CtxT] = PromptBuilder[
124
+ self.in_type, CtxT
125
+ ](
107
126
  agent_id=self._agent_id,
108
127
  sys_prompt=sys_prompt,
109
128
  inp_prompt=inp_prompt,
110
129
  sys_args_schema=sys_args_schema,
111
130
  usr_args_schema=usr_args_schema,
112
- rcv_args_schema=rcv_args_schema,
113
131
  )
114
132
 
115
133
  self.no_tqdm = getattr(llm, "no_tqdm", False)
116
134
 
117
- if type(self)._format_sys_args is not LLMAgent[Any, Any, Any]._format_sys_args: # noqa: SLF001
118
- self._prompt_builder.format_sys_args_impl = self._format_sys_args
119
-
120
- if type(self)._format_inp_args is not LLMAgent[Any, Any, Any]._format_inp_args: # noqa: SLF001
121
- self._prompt_builder.format_inp_args_impl = self._format_inp_args
122
-
123
- if (
124
- type(self)._tool_call_loop_exit # noqa: SLF001
125
- is not LLMAgent[Any, Any, Any]._tool_call_loop_exit # noqa: SLF001
126
- ):
127
- self._tool_orchestrator.tool_call_loop_exit_impl = self._tool_call_loop_exit
135
+ self._register_overridden_handlers()
128
136
 
129
137
  @property
130
138
  def llm(self) -> LLM[LLMSettings, Converters]:
@@ -159,17 +167,29 @@ class LLMAgent(
159
167
  conversation: Conversation,
160
168
  *,
161
169
  rcv_args: InT | None = None,
170
+ batch_idx: int = 0,
162
171
  ctx: RunContextWrapper[CtxT] | None = None,
163
- **kwargs: Any,
164
172
  ) -> OutT:
165
173
  if self._parse_output_impl:
174
+ if self._using_default_llm_response_format:
175
+ # When using custom output parsing, the required LLM response format
176
+ # can differ from the final agent output type ->
177
+ # set it back to None unless it was specified explicitly at init.
178
+ self._tool_orchestrator.llm.response_format = None
179
+ self._using_default_llm_response_format = False
180
+
166
181
  return self._parse_output_impl(
167
- conversation=conversation, rcv_args=rcv_args, ctx=ctx, **kwargs
182
+ conversation=conversation,
183
+ rcv_args=rcv_args,
184
+ batch_idx=batch_idx,
185
+ ctx=ctx,
168
186
  )
169
- try:
170
- return self._out_schema.model_validate_json(str(conversation[-1].content))
171
- except Exception:
172
- return self._out_schema()
187
+
188
+ return validate_obj_from_json_or_py_string(
189
+ str(conversation[-1].content),
190
+ adapter=self._out_type_adapter,
191
+ from_substring=True,
192
+ )
173
193
 
174
194
  @final
175
195
  async def run(
@@ -177,7 +197,7 @@ class LLMAgent(
177
197
  inp_items: LLMPrompt | list[str | ImageData] | None = None,
178
198
  *,
179
199
  ctx: RunContextWrapper[CtxT] | None = None,
180
- rcv_message: AgentMessage[InT, LLMAgentState] | None = None,
200
+ rcv_message: AgentMessage[InT, AgentState] | None = None,
181
201
  entry_point: bool = False,
182
202
  forbid_state_change: bool = False,
183
203
  **gen_kwargs: Any, # noqa: ARG002
@@ -217,7 +237,7 @@ class LLMAgent(
217
237
  rcv_state=rcv_state,
218
238
  sys_prompt=formatted_sys_prompt,
219
239
  strategy=self.set_state_strategy,
220
- make_custom_state_impl=self._make_custom_agent_state_impl,
240
+ set_agent_state_impl=self._set_agent_state_impl,
221
241
  ctx=ctx,
222
242
  )
223
243
 
@@ -241,13 +261,13 @@ class LLMAgent(
241
261
  else:
242
262
  # 4. Run tool call loop (new messages are added to the message
243
263
  # history inside the loop)
244
- await self._tool_orchestrator.run_loop(agent_state=state, ctx=ctx)
264
+ await self._tool_orchestrator.run_loop(state=state, ctx=ctx)
245
265
 
246
266
  # 5. Parse outputs
247
267
  batch_size = state.message_history.batch_size
248
268
  rcv_args_batch = rcv_message.payloads if rcv_message else batch_size * [None]
249
269
  val_output_batch = [
250
- self.out_schema.model_validate(
270
+ self._out_type_adapter.validate_python(
251
271
  self._parse_output(conversation=conv, rcv_args=rcv_args, ctx=ctx)
252
272
  )
253
273
  for conv, rcv_args in zip(
@@ -276,7 +296,7 @@ class LLMAgent(
276
296
  )
277
297
  ctx.interaction_history.append(
278
298
  cast(
279
- "InteractionRecord[AgentPayload, AgentPayload, AgentState]",
299
+ "InteractionRecord[Any, Any, AgentState]",
280
300
  interaction_record,
281
301
  )
282
302
  )
@@ -330,22 +350,57 @@ class LLMAgent(
330
350
 
331
351
  return func
332
352
 
333
- def make_custom_agent_state_handler(
334
- self, func: MakeCustomAgentState
335
- ) -> MakeCustomAgentState:
353
+ def parse_output_handler(
354
+ self, func: ParseOutputHandler[InT, OutT, CtxT]
355
+ ) -> ParseOutputHandler[InT, OutT, CtxT]:
356
+ self._parse_output_impl = func
357
+
358
+ return func
359
+
360
+ def set_agent_state_handler(self, func: SetAgentState) -> SetAgentState:
336
361
  self._make_custom_agent_state_impl = func
337
362
 
338
363
  return func
339
364
 
340
- def tool_call_loop_exit_handler(
341
- self, func: ToolCallLoopExitHandler[CtxT]
342
- ) -> ToolCallLoopExitHandler[CtxT]:
343
- self._tool_orchestrator.tool_call_loop_exit_impl = func
365
+ def exit_tool_call_loop_handler(
366
+ self, func: ExitToolCallLoopHandler[CtxT]
367
+ ) -> ExitToolCallLoopHandler[CtxT]:
368
+ self._tool_orchestrator.exit_tool_call_loop_impl = func
369
+
370
+ return func
371
+
372
+ def manage_agent_state_handler(
373
+ self, func: ManageAgentStateHandler[CtxT]
374
+ ) -> ManageAgentStateHandler[CtxT]:
375
+ self._tool_orchestrator.manage_agent_state_impl = func
344
376
 
345
377
  return func
346
378
 
347
379
  # -- Override these methods in subclasses if needed --
348
380
 
381
+ def _register_overridden_handlers(self) -> None:
382
+ cur_cls = type(self)
383
+ base_cls = LLMAgent[Any, Any, Any]
384
+
385
+ if cur_cls._format_sys_args is not base_cls._format_sys_args: # noqa: SLF001
386
+ self._prompt_builder.format_sys_args_impl = self._format_sys_args
387
+
388
+ if cur_cls._format_inp_args is not base_cls._format_inp_args: # noqa: SLF001
389
+ self._prompt_builder.format_inp_args_impl = self._format_inp_args
390
+
391
+ if cur_cls._set_agent_state is not base_cls._set_agent_state: # noqa: SLF001
392
+ self._set_agent_state_impl = self._set_agent_state
393
+
394
+ if cur_cls._manage_agent_state is not base_cls._manage_agent_state: # noqa: SLF001
395
+ self._tool_orchestrator.manage_agent_state_impl = self._manage_agent_state
396
+
397
+ if (
398
+ cur_cls._tool_call_loop_exit is not base_cls._tool_call_loop_exit # noqa: SLF001
399
+ ):
400
+ self._tool_orchestrator.exit_tool_call_loop_impl = self._tool_call_loop_exit
401
+
402
+ self._parse_output_impl: ParseOutputHandler[InT, OutT, CtxT] | None = None
403
+
349
404
  def _format_sys_args(
350
405
  self,
351
406
  sys_args: LLMPromptArgs,
@@ -361,12 +416,26 @@ class LLMAgent(
361
416
  self,
362
417
  usr_args: LLMPromptArgs,
363
418
  rcv_args: InT,
419
+ *,
420
+ batch_idx: int = 0,
364
421
  ctx: RunContextWrapper[CtxT] | None = None,
365
422
  ) -> LLMFormattedArgs:
366
423
  raise NotImplementedError(
367
424
  "LLMAgent._format_inp_args must be overridden by a subclass"
368
425
  )
369
426
 
427
+ def _set_agent_state(
428
+ self,
429
+ cur_state: LLMAgentState,
430
+ *,
431
+ rcv_state: AgentState | None,
432
+ sys_prompt: LLMPrompt | None,
433
+ ctx: RunContextWrapper[Any] | None,
434
+ ) -> LLMAgentState:
435
+ raise NotImplementedError(
436
+ "LLMAgent._set_agent_state_handler must be overridden by a subclass"
437
+ )
438
+
370
439
  def _tool_call_loop_exit(
371
440
  self,
372
441
  conversation: Conversation,
@@ -377,3 +446,14 @@ class LLMAgent(
377
446
  raise NotImplementedError(
378
447
  "LLMAgent._tool_call_loop_exit must be overridden by a subclass"
379
448
  )
449
+
450
+ def _manage_agent_state(
451
+ self,
452
+ state: LLMAgentState,
453
+ *,
454
+ ctx: RunContextWrapper[CtxT] | None = None,
455
+ **kwargs: Any,
456
+ ) -> None:
457
+ raise NotImplementedError(
458
+ "LLMAgent._manage_agent_state must be overridden by a subclass"
459
+ )
@@ -10,11 +10,12 @@ from .typing.io import AgentState, LLMPrompt
10
10
  SetAgentStateStrategy = Literal["keep", "reset", "from_sender", "custom"]
11
11
 
12
12
 
13
- class MakeCustomAgentState(Protocol):
13
+ class SetAgentState(Protocol):
14
14
  def __call__(
15
15
  self,
16
- cur_state: Optional["LLMAgentState"],
17
- rec_state: Optional["LLMAgentState"],
16
+ cur_state: "LLMAgentState",
17
+ *,
18
+ rcv_state: AgentState | None,
18
19
  sys_prompt: LLMPrompt | None,
19
20
  ctx: RunContextWrapper[Any] | None,
20
21
  ) -> "LLMAgentState": ...
@@ -30,11 +31,12 @@ class LLMAgentState(AgentState):
30
31
  @classmethod
31
32
  def from_cur_and_rcv_states(
32
33
  cls,
33
- cur_state: Optional["LLMAgentState"] = None,
34
- rcv_state: Optional["LLMAgentState"] = None,
34
+ cur_state: "LLMAgentState",
35
+ *,
36
+ rcv_state: Optional["AgentState"] = None,
35
37
  sys_prompt: LLMPrompt | None = None,
36
38
  strategy: SetAgentStateStrategy = "from_sender",
37
- make_custom_state_impl: MakeCustomAgentState | None = None,
39
+ set_agent_state_impl: SetAgentState | None = None,
38
40
  ctx: RunContextWrapper[Any] | None = None,
39
41
  ) -> "LLMAgentState":
40
42
  upd_mh = cur_state.message_history if cur_state else None
@@ -48,24 +50,28 @@ class LLMAgentState(AgentState):
48
50
  upd_mh.reset(sys_prompt)
49
51
 
50
52
  elif strategy == "from_sender":
51
- rcv_mh = rcv_state.message_history if rcv_state else None
53
+ rcv_mh = (
54
+ rcv_state.message_history
55
+ if rcv_state and isinstance(rcv_state, "LLMAgentState")
56
+ else None
57
+ )
52
58
  if rcv_mh:
53
59
  upd_mh = deepcopy(rcv_mh)
54
60
  else:
55
61
  upd_mh.reset(sys_prompt)
56
62
 
57
63
  elif strategy == "custom":
58
- assert make_custom_state_impl is not None, (
59
- "Custom message history handler implementation is not provided."
64
+ assert set_agent_state_impl is not None, (
65
+ "Agent state setter implementation is not provided."
60
66
  )
61
- return make_custom_state_impl(
67
+ return set_agent_state_impl(
62
68
  cur_state=cur_state,
63
- rec_state=rcv_state,
69
+ rcv_state=rcv_state,
64
70
  sys_prompt=sys_prompt,
65
71
  ctx=ctx,
66
72
  )
67
73
 
68
- return cls(message_history=upd_mh)
74
+ return cls.model_construct(message_history=upd_mh)
69
75
 
70
76
  def __repr__(self) -> str:
71
77
  return f"Message History: {len(self.message_history)}"
@@ -1,11 +1,13 @@
1
1
  from collections.abc import Sequence
2
2
  from copy import deepcopy
3
- from typing import Generic, Protocol
3
+ from typing import ClassVar, Generic, Protocol
4
4
 
5
+ from pydantic import BaseModel, TypeAdapter
6
+
7
+ from .generics_utils import AutoInstanceAttributesMixin
5
8
  from .run_context import CtxT, RunContextWrapper, UserRunArgs
6
9
  from .typing.content import ImageData
7
10
  from .typing.io import (
8
- AgentPayload,
9
11
  InT,
10
12
  LLMFormattedArgs,
11
13
  LLMFormattedSystemArgs,
@@ -15,6 +17,10 @@ from .typing.io import (
15
17
  from .typing.message import UserMessage
16
18
 
17
19
 
20
+ class DummySchema(BaseModel):
21
+ pass
22
+
23
+
18
24
  class FormatSystemArgsHandler(Protocol[CtxT]):
19
25
  def __call__(
20
26
  self,
@@ -29,11 +35,15 @@ class FormatInputArgsHandler(Protocol[InT, CtxT]):
29
35
  self,
30
36
  usr_args: LLMPromptArgs,
31
37
  rcv_args: InT,
38
+ *,
39
+ batch_idx: int,
32
40
  ctx: RunContextWrapper[CtxT] | None,
33
41
  ) -> LLMFormattedArgs: ...
34
42
 
35
43
 
36
- class PromptBuilder(Generic[InT, CtxT]):
44
+ class PromptBuilder(AutoInstanceAttributesMixin, Generic[InT, CtxT]):
45
+ _generic_arg_to_instance_attr_map: ClassVar[dict[int, str]] = {0: "_in_type"}
46
+
37
47
  def __init__(
38
48
  self,
39
49
  agent_id: str,
@@ -41,17 +51,20 @@ class PromptBuilder(Generic[InT, CtxT]):
41
51
  inp_prompt: LLMPrompt | None,
42
52
  sys_args_schema: type[LLMPromptArgs],
43
53
  usr_args_schema: type[LLMPromptArgs],
44
- rcv_args_schema: type[InT],
45
54
  ):
55
+ self._in_type: type[InT]
56
+ super().__init__()
57
+
46
58
  self._agent_id = agent_id
47
59
  self.sys_prompt = sys_prompt
48
60
  self.inp_prompt = inp_prompt
49
61
  self.sys_args_schema = sys_args_schema
50
62
  self.usr_args_schema = usr_args_schema
51
- self.rcv_args_schema = rcv_args_schema
52
63
  self.format_sys_args_impl: FormatSystemArgsHandler[CtxT] | None = None
53
64
  self.format_inp_args_impl: FormatInputArgsHandler[InT, CtxT] | None = None
54
65
 
66
+ self._rcv_args_type_adapter: TypeAdapter[InT] = TypeAdapter(self._in_type)
67
+
55
68
  def _format_sys_args(
56
69
  self,
57
70
  sys_args: LLMPromptArgs,
@@ -66,20 +79,32 @@ class PromptBuilder(Generic[InT, CtxT]):
66
79
  self,
67
80
  usr_args: LLMPromptArgs,
68
81
  rcv_args: InT,
82
+ *,
83
+ batch_idx: int = 0,
69
84
  ctx: RunContextWrapper[CtxT] | None = None,
70
85
  ) -> LLMFormattedArgs:
71
86
  if self.format_inp_args_impl:
72
87
  return self.format_inp_args_impl(
73
- usr_args=usr_args, rcv_args=rcv_args, ctx=ctx
88
+ usr_args=usr_args, rcv_args=rcv_args, batch_idx=batch_idx, ctx=ctx
74
89
  )
75
90
 
76
- return usr_args.model_dump(exclude_unset=True) | rcv_args.model_dump(
77
- exclude_unset=True, exclude={"selected_recipient_ids"}
78
- )
91
+ if not isinstance(rcv_args, BaseModel) and rcv_args is not None:
92
+ raise TypeError(
93
+ "Cannot apply default formatting to non-BaseModel received arguments."
94
+ )
95
+
96
+ usr_args_ = usr_args
97
+ rcv_args_ = DummySchema() if rcv_args is None else rcv_args
98
+
99
+ usr_args_dump = usr_args_.model_dump(exclude_unset=True)
100
+ rcv_args_dump = rcv_args_.model_dump(exclude={"selected_recipient_ids"})
101
+
102
+ return usr_args_dump | rcv_args_dump
79
103
 
80
104
  def make_sys_prompt(
81
105
  self,
82
106
  sys_args: LLMPromptArgs,
107
+ *,
83
108
  ctx: RunContextWrapper[CtxT] | None,
84
109
  ) -> LLMPrompt | None:
85
110
  if self.sys_prompt is None:
@@ -100,20 +125,18 @@ class PromptBuilder(Generic[InT, CtxT]):
100
125
  def _usr_messages_from_rcv_args(
101
126
  self, rcv_args_batch: Sequence[InT]
102
127
  ) -> list[UserMessage]:
103
- val_rcv_args_batch = [
104
- self.rcv_args_schema.model_validate(rcv) for rcv in rcv_args_batch
105
- ]
106
-
107
128
  return [
108
129
  UserMessage.from_text(
109
- rcv.model_dump_json(
130
+ self._rcv_args_type_adapter.dump_json(
131
+ rcv,
110
132
  exclude_unset=True,
111
133
  indent=2,
112
134
  exclude={"selected_recipient_ids"},
113
- ),
135
+ warnings="error",
136
+ ).decode("utf-8"),
114
137
  model_id=self._agent_id,
115
138
  )
116
- for rcv in val_rcv_args_batch
139
+ for rcv in rcv_args_batch
117
140
  ]
118
141
 
119
142
  def _usr_messages_from_prompt_template(
@@ -123,17 +146,21 @@ class PromptBuilder(Generic[InT, CtxT]):
123
146
  rcv_args_batch: Sequence[InT] | None = None,
124
147
  ctx: RunContextWrapper[CtxT] | None = None,
125
148
  ) -> Sequence[UserMessage]:
126
- usr_args_batch, rcv_args_batch = self._make_batched(usr_args, rcv_args_batch)
127
- val_usr_args_batch = [
128
- self.usr_args_schema.model_validate(u) for u in usr_args_batch
149
+ usr_args_batch_, rcv_args_batch_ = self._make_batched(usr_args, rcv_args_batch)
150
+
151
+ val_usr_args_batch_ = [
152
+ self.usr_args_schema.model_validate(u) for u in usr_args_batch_
129
153
  ]
130
- val_rcv_args_batch = [
131
- self.rcv_args_schema.model_validate(r) for r in rcv_args_batch
154
+ val_rcv_args_batch_ = [
155
+ self._rcv_args_type_adapter.validate_python(rcv) for rcv in rcv_args_batch_
132
156
  ]
157
+
133
158
  formatted_inp_args_batch = [
134
- self._format_inp_args(usr_args=val_usr_args, rcv_args=val_rcv_args, ctx=ctx)
135
- for val_usr_args, val_rcv_args in zip(
136
- val_usr_args_batch, val_rcv_args_batch, strict=False
159
+ self._format_inp_args(
160
+ usr_args=val_usr_args, rcv_args=val_rcv_args, batch_idx=i, ctx=ctx
161
+ )
162
+ for i, (val_usr_args, val_rcv_args) in enumerate(
163
+ zip(val_usr_args_batch_, val_rcv_args_batch_, strict=False)
137
164
  )
138
165
  ]
139
166
 
@@ -187,11 +214,11 @@ class PromptBuilder(Generic[InT, CtxT]):
187
214
  self,
188
215
  usr_args: UserRunArgs | None = None,
189
216
  rcv_args_batch: Sequence[InT] | None = None,
190
- ) -> tuple[Sequence[LLMPromptArgs], Sequence[InT]]:
217
+ ) -> tuple[Sequence[LLMPromptArgs | DummySchema], Sequence[InT | DummySchema]]:
191
218
  usr_args_batch_ = (
192
- usr_args if isinstance(usr_args, list) else [usr_args or LLMPromptArgs()]
219
+ usr_args if isinstance(usr_args, list) else [usr_args or DummySchema()]
193
220
  )
194
- rcv_args_batch_ = rcv_args_batch or [AgentPayload()]
221
+ rcv_args_batch_ = rcv_args_batch or [DummySchema()]
195
222
 
196
223
  # Broadcast singleton → match lengths
197
224
  if len(usr_args_batch_) == 1 and len(rcv_args_batch_) > 1:
@@ -201,4 +228,4 @@ class PromptBuilder(Generic[InT, CtxT]):
201
228
  if len(usr_args_batch_) != len(rcv_args_batch_):
202
229
  raise ValueError("User args and received args must have the same length")
203
230
 
204
- return usr_args_batch_, rcv_args_batch_ # type: ignore
231
+ return usr_args_batch_, rcv_args_batch_