grasp_agents 0.2.10__py3-none-any.whl → 0.2.11__py3-none-any.whl

This diff shows the differences between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
grasp_agents/llm_agent.py CHANGED
@@ -18,13 +18,7 @@ from .prompt_builder import (
18
18
  FormatSystemArgsHandler,
19
19
  PromptBuilder,
20
20
  )
21
- from .run_context import (
22
- CtxT,
23
- InteractionRecord,
24
- RunContextWrapper,
25
- SystemRunArgs,
26
- UserRunArgs,
27
- )
21
+ from .run_context import CtxT, InteractionRecord, RunContextWrapper
28
22
  from .tool_orchestrator import (
29
23
  ExitToolCallLoopHandler,
30
24
  ManageAgentStateHandler,
@@ -226,8 +220,8 @@ class LLMAgent(
226
220
  **gen_kwargs: Any, # noqa: ARG002
227
221
  ) -> AgentMessage[OutT, LLMAgentState]:
228
222
  # Get run arguments
229
- sys_args: SystemRunArgs = LLMPromptArgs()
230
- usr_args: UserRunArgs = LLMPromptArgs()
223
+ sys_args: LLMPromptArgs = LLMPromptArgs()
224
+ usr_args: LLMPromptArgs | Sequence[LLMPromptArgs] = LLMPromptArgs()
231
225
  if ctx is not None:
232
226
  run_args = ctx.run_args.get(self.agent_id)
233
227
  if run_args is not None:
@@ -240,6 +234,7 @@ class LLMAgent(
240
234
  in_message=in_message,
241
235
  entry_point=entry_point,
242
236
  )
237
+ resolved_in_args = in_message.payloads if in_message else in_args
243
238
 
244
239
  # 1. Make system prompt (can be None)
245
240
  formatted_sys_prompt = self._prompt_builder.make_sys_prompt(
@@ -264,16 +259,11 @@ class LLMAgent(
264
259
  self._print_sys_msg(state=state, prev_mh_len=prev_mh_len, ctx=ctx)
265
260
 
266
261
  # 3. Make and add user messages (can be empty)
267
- _in_args_batch: Sequence[InT] | None = None
268
- if in_message is not None:
269
- _in_args_batch = in_message.payloads
270
- elif in_args is not None:
271
- _in_args_batch = in_args if isinstance(in_args, Sequence) else [in_args] # type: ignore[assignment]
272
262
 
273
263
  user_message_batch = self._prompt_builder.make_user_messages(
274
264
  chat_inputs=chat_inputs,
265
+ in_args=resolved_in_args,
275
266
  usr_args=usr_args,
276
- in_args_batch=_in_args_batch,
277
267
  entry_point=entry_point,
278
268
  ctx=ctx,
279
269
  )
@@ -290,18 +280,23 @@ class LLMAgent(
290
280
  await self._tool_orchestrator.run_loop(state=state, ctx=ctx)
291
281
 
292
282
  # 5. Parse outputs
293
- batch_size = state.message_history.batch_size
294
- in_args_batch = in_message.payloads if in_message else batch_size * [None]
295
- val_output_batch = [
296
- self._out_type_adapter.validate_python(
297
- self._parse_output(conversation=conv, in_args=in_args, ctx=ctx)
298
- )
299
- for conv, in_args in zip(
300
- state.message_history.batched_conversations,
301
- in_args_batch,
302
- strict=False,
283
+
284
+ val_output_batch: list[OutT] = []
285
+ for i, _conv in enumerate(state.message_history.batched_conversations):
286
+ if isinstance(resolved_in_args, Sequence):
287
+ _resolved_in_args = cast("Sequence[InT]", resolved_in_args)
288
+ _in_args = _resolved_in_args[min(i, len(_resolved_in_args) - 1)]
289
+ else:
290
+ _resolved_in_args = cast("InT | None", resolved_in_args)
291
+ _in_args = _resolved_in_args
292
+
293
+ val_output_batch.append(
294
+ self._out_type_adapter.validate_python(
295
+ self._parse_output(
296
+ conversation=_conv, in_args=_in_args, batch_idx=i, ctx=ctx
297
+ )
298
+ )
303
299
  )
304
- ]
305
300
 
306
301
  # 6. Write interaction history to context
307
302
 
@@ -316,7 +311,7 @@ class LLMAgent(
316
311
  in_prompt=self.in_prompt,
317
312
  sys_args=sys_args,
318
313
  usr_args=usr_args,
319
- in_args=(in_message.payloads if in_message is not None else None),
314
+ in_args=resolved_in_args, # type: ignore[valid-type]
320
315
  outputs=val_output_batch,
321
316
  state=state,
322
317
  )
@@ -1,11 +1,11 @@
1
1
  from collections.abc import Sequence
2
2
  from copy import deepcopy
3
- from typing import ClassVar, Generic, Protocol
3
+ from typing import ClassVar, Generic, Protocol, cast
4
4
 
5
5
  from pydantic import BaseModel, TypeAdapter
6
6
 
7
7
  from .generics_utils import AutoInstanceAttributesMixin
8
- from .run_context import CtxT, RunContextWrapper, UserRunArgs
8
+ from .run_context import CtxT, RunContextWrapper
9
9
  from .typing.content import ImageData
10
10
  from .typing.io import (
11
11
  InT,
@@ -142,11 +142,13 @@ class PromptBuilder(AutoInstanceAttributesMixin, Generic[InT, CtxT]):
142
142
  def _usr_messages_from_prompt_template(
143
143
  self,
144
144
  in_prompt: LLMPrompt,
145
- usr_args: UserRunArgs | None = None,
146
145
  in_args_batch: Sequence[InT] | None = None,
146
+ usr_args_batch: Sequence[LLMPromptArgs] | None = None,
147
147
  ctx: RunContextWrapper[CtxT] | None = None,
148
148
  ) -> Sequence[UserMessage]:
149
- usr_args_batch_, in_args_batch_ = self._make_batched(usr_args, in_args_batch)
149
+ usr_args_batch_, in_args_batch_ = self._align_in_usr_batches(
150
+ in_args_batch, usr_args_batch
151
+ )
150
152
 
151
153
  val_usr_args_batch_ = [
152
154
  self.usr_args_schema.model_validate(u) for u in usr_args_batch_
@@ -174,8 +176,8 @@ class PromptBuilder(AutoInstanceAttributesMixin, Generic[InT, CtxT]):
174
176
  def make_user_messages(
175
177
  self,
176
178
  chat_inputs: LLMPrompt | Sequence[str | ImageData] | None = None,
177
- usr_args: UserRunArgs | None = None,
178
- in_args_batch: Sequence[InT] | None = None,
179
+ in_args: InT | Sequence[InT] | None = None,
180
+ usr_args: LLMPromptArgs | Sequence[LLMPromptArgs] | None = None,
179
181
  entry_point: bool = False,
180
182
  ctx: RunContextWrapper[CtxT] | None = None,
181
183
  ) -> Sequence[UserMessage]:
@@ -196,10 +198,24 @@ class PromptBuilder(AutoInstanceAttributesMixin, Generic[InT, CtxT]):
196
198
  if isinstance(chat_inputs, Sequence) and chat_inputs:
197
199
  return self._usr_messages_from_content_parts(chat_inputs)
198
200
 
201
+ in_args_batch = cast(
202
+ "Sequence[InT] | None",
203
+ in_args if (isinstance(in_args, Sequence) or not in_args) else [in_args],
204
+ )
205
+
199
206
  # 2) No input prompt template + received args → raw JSON messages
200
207
  if self.in_prompt is None and in_args_batch:
201
208
  return self._usr_messages_from_in_args(in_args_batch)
202
209
 
210
+ usr_args_batch = cast(
211
+ "Sequence[LLMPromptArgs] | None",
212
+ (
213
+ usr_args
214
+ if (isinstance(usr_args, Sequence) or not usr_args)
215
+ else [usr_args]
216
+ ),
217
+ )
218
+
203
219
  # 3) Input prompt template + any args → batch & format
204
220
  if self.in_prompt is not None:
205
221
  if in_args_batch and not isinstance(in_args_batch[0], BaseModel):
@@ -209,21 +225,19 @@ class PromptBuilder(AutoInstanceAttributesMixin, Generic[InT, CtxT]):
209
225
  )
210
226
  return self._usr_messages_from_prompt_template(
211
227
  in_prompt=self.in_prompt,
212
- usr_args=usr_args,
228
+ usr_args_batch=usr_args_batch,
213
229
  in_args_batch=in_args_batch,
214
230
  ctx=ctx,
215
231
  )
216
232
 
217
233
  return []
218
234
 
219
- def _make_batched(
235
+ def _align_in_usr_batches(
220
236
  self,
221
- usr_args: UserRunArgs | None = None,
222
237
  in_args_batch: Sequence[InT] | None = None,
238
+ usr_args_batch: Sequence[LLMPromptArgs] | None = None,
223
239
  ) -> tuple[Sequence[LLMPromptArgs | DummySchema], Sequence[InT | DummySchema]]:
224
- usr_args_batch_ = (
225
- usr_args if isinstance(usr_args, list) else [usr_args or DummySchema()]
226
- )
240
+ usr_args_batch_ = usr_args_batch or [DummySchema()]
227
241
  in_args_batch_ = in_args_batch or [DummySchema()]
228
242
 
229
243
  # Broadcast singleton → match lengths
@@ -17,13 +17,10 @@ from .typing.io import (
17
17
  )
18
18
  from .usage_tracker import UsageTracker
19
19
 
20
- SystemRunArgs: TypeAlias = LLMPromptArgs
21
- UserRunArgs: TypeAlias = LLMPromptArgs | list[LLMPromptArgs]
22
-
23
20
 
24
21
  class RunArgs(BaseModel):
25
- sys: SystemRunArgs = Field(default_factory=LLMPromptArgs)
26
- usr: UserRunArgs = Field(default_factory=LLMPromptArgs)
22
+ sys: LLMPromptArgs = Field(default_factory=LLMPromptArgs)
23
+ usr: LLMPromptArgs | Sequence[LLMPromptArgs] = Field(default_factory=LLMPromptArgs)
27
24
 
28
25
  model_config = ConfigDict(extra="forbid")
29
26
 
@@ -35,9 +32,9 @@ class InteractionRecord(BaseModel, Generic[InT, OutT, StateT]):
35
32
  chat_inputs: LLMPrompt | Sequence[str | ImageData] | None = None
36
33
  sys_prompt: LLMPrompt | None = None
37
34
  in_prompt: LLMPrompt | None = None
38
- sys_args: SystemRunArgs | None = None
39
- usr_args: UserRunArgs | None = None
40
- in_args: Sequence[InT] | None = None
35
+ sys_args: LLMPromptArgs | None = None
36
+ usr_args: LLMPromptArgs | Sequence[LLMPromptArgs] | None = None
37
+ in_args: InT | Sequence[InT] | None = None
41
38
  outputs: Sequence[OutT]
42
39
 
43
40
  model_config = ConfigDict(extra="forbid", frozen=True)
@@ -1,4 +1,5 @@
1
1
  from collections.abc import Sequence
2
+ from itertools import pairwise
2
3
  from logging import getLogger
3
4
  from typing import Any, ClassVar, Generic, Protocol, TypeVar, cast, final
4
5
 
@@ -49,6 +50,19 @@ class LoopedWorkflowAgent(WorkflowAgent[InT, OutT, CtxT], Generic[InT, OutT, Ctx
49
50
  dynamic_routing=dynamic_routing,
50
51
  )
51
52
 
53
+ for prev_agent, agent in pairwise(subagents):
54
+ if prev_agent.out_type != agent.in_type:
55
+ raise ValueError(
56
+ f"Output type {prev_agent.out_type} of agent {prev_agent.agent_id} "
57
+ f"does not match input type {agent.in_type} of agent "
58
+ f"{agent.agent_id}"
59
+ )
60
+ if subagents[-1].out_type != subagents[0].in_type:
61
+ raise ValueError(
62
+ f"Looped workflow's last agent output type {subagents[-1].out_type} "
63
+ f"does not match first agent input type {subagents[0].in_type}"
64
+ )
65
+
52
66
  self._max_iterations = max_iterations
53
67
 
54
68
  self._exit_workflow_loop_impl: ExitWorkflowLoopHandler[OutT, CtxT] | None = None
@@ -1,4 +1,5 @@
1
1
  from collections.abc import Sequence
2
+ from itertools import pairwise
2
3
  from typing import Any, ClassVar, Generic, cast, final
3
4
 
4
5
  from ..agent_message_pool import AgentMessage, AgentMessagePool
@@ -33,6 +34,14 @@ class SequentialWorkflowAgent(WorkflowAgent[InT, OutT, CtxT], Generic[InT, OutT,
33
34
  dynamic_routing=dynamic_routing,
34
35
  )
35
36
 
37
+ for prev_agent, agent in pairwise(subagents):
38
+ if prev_agent.out_type != agent.in_type:
39
+ raise ValueError(
40
+ f"Output type {prev_agent.out_type} of agent {prev_agent.agent_id} "
41
+ f"does not match input type {agent.in_type} of agent "
42
+ f"{agent.agent_id}"
43
+ )
44
+
36
45
  @final
37
46
  async def run(
38
47
  self,
@@ -27,29 +27,44 @@ class WorkflowAgent(
27
27
  dynamic_routing: bool = False,
28
28
  **kwargs: Any, # noqa: ARG002
29
29
  ) -> None:
30
- if not subagents:
31
- raise ValueError("At least one step is required")
30
+ super().__init__(
31
+ agent_id=agent_id,
32
+ message_pool=message_pool,
33
+ recipient_ids=recipient_ids,
34
+ dynamic_routing=dynamic_routing,
35
+ )
36
+
37
+ if len(subagents) < 2:
38
+ raise ValueError("At least two steps are required")
32
39
  if start_agent not in subagents:
33
40
  raise ValueError("Start agent must be in the subagents list")
34
41
  if end_agent not in subagents:
35
42
  raise ValueError("End agent must be in the subagents list")
36
43
 
37
- self.subagents = subagents
44
+ if start_agent.in_type != self.in_type:
45
+ raise ValueError(
46
+ f"Start agent's input type {start_agent.in_type} does not "
47
+ f"match workflow's input type {self._in_type}"
48
+ )
49
+ if end_agent.out_type != self.out_type:
50
+ raise ValueError(
51
+ f"End agent's output type {end_agent.out_type} does not "
52
+ f"match workflow's output type {self._out_type}"
53
+ )
38
54
 
55
+ self._subagents = subagents
39
56
  self._start_agent = start_agent
40
57
  self._end_agent = end_agent
41
58
 
42
- super().__init__(
43
- agent_id=agent_id,
44
- message_pool=message_pool,
45
- recipient_ids=recipient_ids,
46
- dynamic_routing=dynamic_routing,
47
- )
48
59
  for subagent in subagents:
49
60
  assert not subagent.recipient_ids, (
50
61
  "Subagents must not have recipient_ids set."
51
62
  )
52
63
 
64
+ @property
65
+ def subagents(self) -> Sequence[CommunicatingAgent[Any, Any, Any, CtxT]]:
66
+ return self._subagents
67
+
53
68
  @property
54
69
  def start_agent(self) -> CommunicatingAgent[InT, Any, Any, CtxT]:
55
70
  return self._start_agent
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: grasp_agents
3
- Version: 0.2.10
3
+ Version: 0.2.11
4
4
  Summary: Grasp Agents Library
5
5
  License-File: LICENSE.md
6
6
  Requires-Python: <4,>=3.11.4
@@ -9,12 +9,12 @@ grasp_agents/generics_utils.py,sha256=kw4Odte6Nvl4c9U7-mKPgXCavWZXo009zYDHAA0BR3
9
9
  grasp_agents/grasp_logging.py,sha256=H1GYhXdQvVkmauFDZ-KDwvVmPQHZUUm9sRqX_ObK2xI,1111
10
10
  grasp_agents/http_client.py,sha256=KZva2MjJjuI5ohUeU8RdTAImUnQYaqBrV2jDH8smbJw,738
11
11
  grasp_agents/llm.py,sha256=YClNxN9GUGaFHXhTU72z1AqW_Y726OC7kyVRYCnfhZ8,3682
12
- grasp_agents/llm_agent.py,sha256=trKvxIS5z9GYusB9zNRv_75trGP1NwzdvK2_-rElYnE,16102
12
+ grasp_agents/llm_agent.py,sha256=UvmcvuLl8EWe51isSw9Oa9lxA9gQBJngRXj8LUMAd4w,16097
13
13
  grasp_agents/llm_agent_state.py,sha256=L54zUW5nAT-ubvEB7XNAQ84ExOgRlUFzc-Q49mUXUT0,2390
14
14
  grasp_agents/memory.py,sha256=X1YtVX8XxP5KnGPMW8BqjID8QK4hTG2obxoyhnnZ4pU,5575
15
15
  grasp_agents/printer.py,sha256=ZENcTITCcgSizvcUXnIQiFl_NcVHg-801Z-sbT0D8rg,5030
16
- grasp_agents/prompt_builder.py,sha256=uN21GAYdmbNiIkv60Rs_ixJyybR1JWR0lsOM9HVkiLE,8411
17
- grasp_agents/run_context.py,sha256=4v6IcddHSWUAMYX8M9hQRXNMfPf7gUv45nn_Fb5lLaI,2233
16
+ grasp_agents/prompt_builder.py,sha256=A4z8B5nIfgn_cHnZCNyYjwA904UDiOaCyBc5VPu8168,8867
17
+ grasp_agents/run_context.py,sha256=OpTdo32-WPD7XwE16LQphh2C7yrrkZ3C4Ia80pWeBDQ,2192
18
18
  grasp_agents/tool_orchestrator.py,sha256=XgKCOytPs4y4SBpS5i4eogCB428XaDnNIB3VxzJC_k0,6095
19
19
  grasp_agents/usage_tracker.py,sha256=4gy0XtfIBAjQHblEFpQuPelmxMGDvE7vw4c8Ccr1msk,3471
20
20
  grasp_agents/utils.py,sha256=KzoInW0sq-pwwUtgjtYMW8b0ivBH6MR0Zxv6Kqq3k3M,4510
@@ -37,10 +37,10 @@ grasp_agents/typing/io.py,sha256=uxSvbD05UK5nIhPfDvXIoGuU6xRMW4USZq_4IgBeGCY,609
37
37
  grasp_agents/typing/message.py,sha256=DC24XMcUPG1MHDNZKT67yKWyaMTQNF5B0yDiUg8b54Q,3833
38
38
  grasp_agents/typing/tool.py,sha256=e0pTMnRcpMpGNVQ8muE9wnh7LdIgh92AqXDo9hMDxf0,1960
39
39
  grasp_agents/workflow/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
40
- grasp_agents/workflow/looped_agent.py,sha256=5tvlkPyhPdNGVSEaxTNPX0QQfPwP01OpZrvmv4bWXsA,4046
41
- grasp_agents/workflow/sequential_agent.py,sha256=SbeAKHCbhMy-2CmQWs1f5VMSRj88MLiAQc76g_ivGvA,2131
42
- grasp_agents/workflow/workflow_agent.py,sha256=E6jzfFmdUfCoYehH9ZsAa_MTICjEaP4NGwNxhc3NR_E,2436
43
- grasp_agents-0.2.10.dist-info/METADATA,sha256=lBHs7V6zm6E63yspRA9OOHS80M7hCtyVxuCNkaNce5Y,7188
44
- grasp_agents-0.2.10.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
45
- grasp_agents-0.2.10.dist-info/licenses/LICENSE.md,sha256=-nNNdWqGB8gJ2O-peFQ2Irshv5tW5pHKyTcYkwvH7CE,1201
46
- grasp_agents-0.2.10.dist-info/RECORD,,
40
+ grasp_agents/workflow/looped_agent.py,sha256=f7R-yVq4cP8zMTlvAF0GCpD0Izkyz5jUvUrxaXycusw,4710
41
+ grasp_agents/workflow/sequential_agent.py,sha256=yU3X28rqZo-xVdkbEETb9pBCBG7v1GK7oxtbl5xLdCo,2526
42
+ grasp_agents/workflow/workflow_agent.py,sha256=T1DpSipIkcrC_WJKR6Ho_cSHmVltCcF6Ve1F7isPRp0,3031
43
+ grasp_agents-0.2.11.dist-info/METADATA,sha256=xXvakAAM5nycFuJf70buNFqQ5BMNyMbI7rdFZHHwphc,7188
44
+ grasp_agents-0.2.11.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
45
+ grasp_agents-0.2.11.dist-info/licenses/LICENSE.md,sha256=-nNNdWqGB8gJ2O-peFQ2Irshv5tW5pHKyTcYkwvH7CE,1201
46
+ grasp_agents-0.2.11.dist-info/RECORD,,