grasp_agents 0.1.18__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,16 +1,8 @@
1
1
  import inspect
2
2
  from collections.abc import Callable, Coroutine, Sequence
3
- from typing import (
4
- Any,
5
- )
3
+ from typing import Any, overload
6
4
 
7
- from .types import (
8
- QueryP,
9
- QueryR,
10
- QueryT,
11
- RetrievalCallableList,
12
- RetrievalCallableSingle,
13
- )
5
+ from .types import P, ProcessorCallableList, ProcessorCallableSingle, R, T
14
6
 
15
7
 
16
8
  def is_bound_method(func: Callable[..., Any], self_candidate: Any) -> bool:
@@ -19,13 +11,24 @@ def is_bound_method(func: Callable[..., Any], self_candidate: Any) -> bool:
19
11
  )
20
12
 
21
13
 
14
+ @overload
22
15
  def split_pos_args(
23
- call: (
24
- RetrievalCallableSingle[QueryT, QueryP, QueryR]
25
- | RetrievalCallableList[QueryT, QueryP, QueryR]
26
- ),
16
+ call: ProcessorCallableSingle[T, P, R],
27
17
  args: Sequence[Any],
28
- ) -> tuple[Any | None, QueryT | list[QueryT], Sequence[Any]]:
18
+ ) -> tuple[Any | None, T, Sequence[Any]]: ...
19
+
20
+
21
+ @overload
22
+ def split_pos_args(
23
+ call: ProcessorCallableList[T, P, R],
24
+ args: Sequence[Any],
25
+ ) -> tuple[Any | None, list[T], Sequence[Any]]: ...
26
+
27
+
28
+ def split_pos_args(
29
+ call: (ProcessorCallableSingle[T, P, R] | ProcessorCallableList[T, P, R]),
30
+ args: Sequence[Any],
31
+ ) -> tuple[Any | None, T | list[T], Sequence[Any]]:
29
32
  if not args:
30
33
  raise ValueError("No positional arguments passed.")
31
34
  maybe_self = args[0]
@@ -45,13 +48,13 @@ def split_pos_args(
45
48
  return None, args[0], args[1:]
46
49
 
47
50
 
48
- def partial_retrieval_callable(
49
- call: Callable[..., Coroutine[Any, Any, QueryR]],
51
+ def partial_processor_callable(
52
+ call: Callable[..., Coroutine[Any, Any, R]],
50
53
  self_obj: Any,
51
- *args: QueryP.args,
52
- **kwargs: QueryP.kwargs,
53
- ) -> Callable[[QueryT], Coroutine[Any, Any, QueryR]]:
54
- async def wrapper(inp: QueryT) -> QueryR:
54
+ *args: Any,
55
+ **kwargs: Any,
56
+ ) -> Callable[[Any], Coroutine[Any, Any, R]]:
57
+ async def wrapper(inp: Any) -> R:
55
58
  if self_obj is not None:
56
59
  # `call` is a method
57
60
  return await call(self_obj, inp, *args, **kwargs)
@@ -59,9 +62,3 @@ def partial_retrieval_callable(
59
62
  return await call(inp, *args, **kwargs)
60
63
 
61
64
  return wrapper
62
-
63
-
64
- def expected_exec_time_from_max_concurrency_and_rpm(
65
- rpm: float, max_concurrency: int
66
- ) -> float:
67
- return 60.0 / (rpm / max_concurrency)
@@ -8,7 +8,6 @@ from .printer import Printer
8
8
  from .typing.content import ImageData
9
9
  from .typing.io import (
10
10
  AgentID,
11
- AgentPayload,
12
11
  AgentState,
13
12
  InT,
14
13
  LLMPrompt,
@@ -44,9 +43,7 @@ class InteractionRecord(BaseModel, Generic[InT, OutT, StateT]):
44
43
  model_config = ConfigDict(extra="forbid", frozen=True)
45
44
 
46
45
 
47
- InteractionHistory: TypeAlias = list[
48
- InteractionRecord[AgentPayload, AgentPayload, AgentState]
49
- ]
46
+ InteractionHistory: TypeAlias = list[InteractionRecord[Any, Any, AgentState]]
50
47
 
51
48
 
52
49
  CtxT = TypeVar("CtxT")
@@ -56,23 +53,13 @@ class RunContextWrapper(BaseModel, Generic[CtxT]):
56
53
  context: CtxT | None = None
57
54
  run_id: str = Field(default_factory=lambda: str(uuid4())[:8], frozen=True)
58
55
  run_args: dict[AgentID, RunArgs] = Field(default_factory=dict)
59
- interaction_history: InteractionHistory = Field(default_factory=list)
56
+ interaction_history: InteractionHistory = Field(default_factory=list) # type: ignore[valid-type]
60
57
 
61
58
  print_messages: bool = False
62
59
 
63
60
  _usage_tracker: UsageTracker = PrivateAttr()
64
61
  _printer: Printer = PrivateAttr()
65
62
 
66
- # usage_tracker: Optional[UsageTracker] = None
67
- # printer: Optional[Printer] = None
68
-
69
- # @model_validator(mode="after")
70
- # def set_usage_tracker_and_printer(self) -> "RunContextWrapper":
71
- # self.usage_tracker = UsageTracker(source_id=self.run_id)
72
- # self.printer = Printer(source_id=self.run_id)
73
-
74
- # return self
75
-
76
63
  def model_post_init(self, context: Any) -> None: # noqa: ARG002
77
64
  self._usage_tracker = UsageTracker(source_id=self.run_id)
78
65
  self._printer = Printer(
@@ -16,7 +16,7 @@ from .typing.tool import BaseTool, ToolCall, ToolChoice
16
16
  logger = getLogger(__name__)
17
17
 
18
18
 
19
- class ToolCallLoopExitHandler(Protocol[CtxT]):
19
+ class ExitToolCallLoopHandler(Protocol[CtxT]):
20
20
  def __call__(
21
21
  self,
22
22
  conversation: Conversation,
@@ -26,6 +26,16 @@ class ToolCallLoopExitHandler(Protocol[CtxT]):
26
26
  ) -> bool: ...
27
27
 
28
28
 
29
+ class ManageAgentStateHandler(Protocol[CtxT]):
30
+ def __call__(
31
+ self,
32
+ agent_state: LLMAgentState,
33
+ *,
34
+ ctx: RunContextWrapper[CtxT] | None,
35
+ **kwargs: Any,
36
+ ) -> None: ...
37
+
38
+
29
39
  class ToolOrchestrator(Generic[CtxT]):
30
40
  def __init__(
31
41
  self,
@@ -38,13 +48,13 @@ class ToolOrchestrator(Generic[CtxT]):
38
48
  self._agent_id = agent_id
39
49
 
40
50
  self._llm = llm
41
- self._tools = tools
42
- self.llm.tools = tools
51
+ self._llm.tools = tools
43
52
 
44
53
  self._max_turns = max_turns
45
54
  self._react_mode = react_mode
46
55
 
47
- self.tool_call_loop_exit_impl: ToolCallLoopExitHandler[CtxT] | None = None
56
+ self.exit_tool_call_loop_impl: ExitToolCallLoopHandler[CtxT] | None = None
57
+ self.manage_agent_state_impl: ManageAgentStateHandler[CtxT] | None = None
48
58
 
49
59
  @property
50
60
  def agent_id(self) -> str:
@@ -62,15 +72,15 @@ class ToolOrchestrator(Generic[CtxT]):
62
72
  def max_turns(self) -> int:
63
73
  return self._max_turns
64
74
 
65
- def _tool_call_loop_exit(
75
+ def _exit_tool_call_loop(
66
76
  self,
67
77
  conversation: Conversation,
68
78
  *,
69
79
  ctx: RunContextWrapper[CtxT] | None = None,
70
80
  **kwargs: Any,
71
81
  ) -> bool:
72
- if self.tool_call_loop_exit_impl:
73
- return self.tool_call_loop_exit_impl(
82
+ if self.exit_tool_call_loop_impl:
83
+ return self.exit_tool_call_loop_impl(
74
84
  conversation=conversation, ctx=ctx, **kwargs
75
85
  )
76
86
 
@@ -81,6 +91,16 @@ class ToolOrchestrator(Generic[CtxT]):
81
91
 
82
92
  return not bool(conversation[-1].tool_calls)
83
93
 
94
+ def _manage_agent_state(
95
+ self,
96
+ agent_state: LLMAgentState,
97
+ *,
98
+ ctx: RunContextWrapper[CtxT] | None = None,
99
+ **kwargs: Any,
100
+ ) -> None:
101
+ if self.manage_agent_state_impl:
102
+ self.manage_agent_state_impl(agent_state=agent_state, ctx=ctx, **kwargs)
103
+
84
104
  async def generate_once(
85
105
  self,
86
106
  agent_state: LLMAgentState,
@@ -117,7 +137,9 @@ class ToolOrchestrator(Generic[CtxT]):
117
137
  turns = 0
118
138
 
119
139
  while True:
120
- if self._tool_call_loop_exit(
140
+ self._manage_agent_state(agent_state=agent_state, ctx=ctx, num_turns=turns)
141
+
142
+ if self._exit_tool_call_loop(
121
143
  message_history.batched_conversations[0], ctx=ctx, num_turns=turns
122
144
  ):
123
145
  return
@@ -2,6 +2,8 @@ from abc import ABC, abstractmethod
2
2
  from collections.abc import AsyncIterator
3
3
  from typing import Any
4
4
 
5
+ from pydantic import BaseModel
6
+
5
7
  from .completion import Completion, CompletionChunk
6
8
  from .content import Content
7
9
  from .message import AssistantMessage, Message, SystemMessage, ToolMessage, UserMessage
@@ -64,7 +66,7 @@ class Converters(ABC):
64
66
 
65
67
  @staticmethod
66
68
  @abstractmethod
67
- def to_tool(tool: BaseTool[Any, Any, Any], **kwargs: Any) -> Any:
69
+ def to_tool(tool: BaseTool[BaseModel, Any, Any], **kwargs: Any) -> Any:
68
70
  pass
69
71
 
70
72
  @staticmethod
grasp_agents/typing/io.py CHANGED
@@ -7,23 +7,18 @@ from .content import ImageData
7
7
  AgentID: TypeAlias = str
8
8
 
9
9
 
10
- class AgentPayload(BaseModel):
11
- pass
12
-
13
-
14
10
  class AgentState(BaseModel):
15
11
  pass
16
12
 
17
13
 
18
- InT = TypeVar("InT", bound=AgentPayload, contravariant=True) # noqa: PLC0105
19
- OutT = TypeVar("OutT", bound=AgentPayload, covariant=True) # noqa: PLC0105
20
- StateT = TypeVar("StateT", bound=AgentState, covariant=True) # noqa: PLC0105
21
-
22
-
23
14
  class LLMPromptArgs(BaseModel):
24
15
  pass
25
16
 
26
17
 
18
+ InT = TypeVar("InT", contravariant=True) # noqa: PLC0105
19
+ OutT = TypeVar("OutT", covariant=True) # noqa: PLC0105
20
+ StateT = TypeVar("StateT", bound=AgentState, covariant=True) # noqa: PLC0105
21
+
27
22
  LLMPrompt: TypeAlias = str
28
23
  LLMFormattedSystemArgs: TypeAlias = dict[str, str]
29
24
  LLMFormattedArgs: TypeAlias = dict[str, str | ImageData]
@@ -3,9 +3,11 @@ from __future__ import annotations
3
3
  import asyncio
4
4
  from abc import ABC, abstractmethod
5
5
  from collections.abc import Sequence
6
- from typing import TYPE_CHECKING, Any, Generic, Literal, TypeAlias, TypeVar
6
+ from typing import TYPE_CHECKING, Any, ClassVar, Generic, Literal, TypeAlias, TypeVar
7
7
 
8
- from pydantic import BaseModel, TypeAdapter
8
+ from pydantic import BaseModel, PrivateAttr, TypeAdapter
9
+
10
+ from ..generics_utils import AutoInstanceAttributesMixin
9
11
 
10
12
  if TYPE_CHECKING:
11
13
  from ..run_context import CtxT, RunContextWrapper
@@ -26,15 +28,32 @@ class ToolCall(BaseModel):
26
28
  tool_arguments: str
27
29
 
28
30
 
29
- class BaseTool(BaseModel, ABC, Generic[_ToolInT, _ToolOutT, CtxT]):
31
+ class BaseTool(
32
+ AutoInstanceAttributesMixin, BaseModel, ABC, Generic[_ToolInT, _ToolOutT, CtxT]
33
+ ):
34
+ _generic_arg_to_instance_attr_map: ClassVar[dict[int, str]] = {
35
+ 0: "_in_schema",
36
+ 1: "_out_schema",
37
+ }
38
+
30
39
  name: str
31
40
  description: str
32
- in_schema: type[_ToolInT]
33
- out_schema: type[_ToolOutT]
41
+
42
+ _in_schema: type[_ToolInT] = PrivateAttr()
43
+ _out_schema: type[_ToolOutT] = PrivateAttr()
34
44
 
35
45
  # Supported by OpenAI API
36
46
  strict: bool | None = None
37
47
 
48
+ @property
49
+ def in_schema(self) -> type[_ToolInT]: # type: ignore[reportInvalidTypeVarUse]
50
+ # Exposing the type of a contravariant variable only, should be type safe
51
+ return self._in_schema
52
+
53
+ @property
54
+ def out_schema(self) -> type[_ToolOutT]:
55
+ return self._out_schema
56
+
38
57
  @abstractmethod
39
58
  async def run(
40
59
  self, inp: _ToolInT, ctx: RunContextWrapper[CtxT] | None = None
@@ -49,9 +68,9 @@ class BaseTool(BaseModel, ABC, Generic[_ToolInT, _ToolOutT, CtxT]):
49
68
  async def __call__(
50
69
  self, ctx: RunContextWrapper[CtxT] | None = None, **kwargs: Any
51
70
  ) -> _ToolOutT:
52
- result = await self.run(self.in_schema(**kwargs), ctx=ctx)
71
+ result = await self.run(self._in_schema(**kwargs), ctx=ctx)
53
72
 
54
- return TypeAdapter(self.out_schema).validate_python(result)
73
+ return TypeAdapter(self._out_schema).validate_python(result)
55
74
 
56
75
 
57
76
  ToolChoice: TypeAlias = (
grasp_agents/utils.py CHANGED
@@ -1,34 +1,25 @@
1
1
  import ast
2
2
  import asyncio
3
- import functools
4
3
  import json
5
4
  import re
6
- from collections.abc import Callable, Coroutine
7
- from copy import deepcopy
5
+ from collections.abc import Coroutine
8
6
  from datetime import datetime
9
7
  from logging import getLogger
10
8
  from pathlib import Path
11
- from typing import Any, TypeVar, cast
12
-
13
- from pydantic import BaseModel, GetCoreSchemaHandler, TypeAdapter, create_model
14
- from pydantic.fields import FieldInfo
9
+ from typing import Any, TypeVar
10
+
11
+ from pydantic import (
12
+ BaseModel,
13
+ GetCoreSchemaHandler,
14
+ TypeAdapter,
15
+ ValidationError,
16
+ )
15
17
  from pydantic_core import core_schema
16
18
  from tqdm.autonotebook import tqdm
17
19
 
18
20
  logger = getLogger(__name__)
19
21
 
20
-
21
- def merge_pydantic_models(*models: type[BaseModel]) -> type[BaseModel]:
22
- fields_dict: dict[str, FieldInfo] = {}
23
- for model in models:
24
- for field_name, field_info in model.model_fields.items():
25
- if field_name in fields_dict:
26
- raise ValueError(
27
- f"Field conflict detected: '{field_name}' exists in multiple models"
28
- )
29
- fields_dict[field_name] = field_info
30
-
31
- return create_model("MergedModel", __module__=__name__, **fields_dict) # type: ignore
22
+ T = TypeVar("T")
32
23
 
33
24
 
34
25
  def filter_fields(data: dict[str, Any], model: type[BaseModel]) -> dict[str, Any]:
@@ -57,7 +48,7 @@ def format_json_string(text: str) -> str:
57
48
  return text
58
49
 
59
50
 
60
- def read_json_string(
51
+ def parse_json_or_py_string(
61
52
  json_str: str, return_none_on_failure: bool = False
62
53
  ) -> dict[str, Any] | list[Any] | None:
63
54
  try:
@@ -79,7 +70,21 @@ def read_json_string(
79
70
  def extract_json(
80
71
  json_str: str, return_none_on_failure: bool = False
81
72
  ) -> dict[str, Any] | list[Any] | None:
82
- return read_json_string(format_json_string(json_str), return_none_on_failure)
73
+ return parse_json_or_py_string(format_json_string(json_str), return_none_on_failure)
74
+
75
+
76
+ def validate_obj_from_json_or_py_string(s: str, adapter: TypeAdapter[T]) -> T:
77
+ s_fmt = re.sub(r"```[a-zA-Z0-9]*\n|```", "", s).strip()
78
+ try:
79
+ parsed = json.loads(s_fmt)
80
+ return adapter.validate_python(parsed)
81
+ except (json.JSONDecodeError, ValidationError):
82
+ try:
83
+ return adapter.validate_python(s_fmt)
84
+ except ValidationError as exc:
85
+ raise ValueError(
86
+ f"Invalid JSON or Python string:\n{s}\nExpected type: {adapter._type}", # type: ignore[arg-type]
87
+ ) from exc
83
88
 
84
89
 
85
90
  def extract_xml_list(text: str) -> list[str]:
@@ -131,24 +136,6 @@ def make_conditional_parsed_output_type(
131
136
  return ParsedOutput
132
137
 
133
138
 
134
- T = TypeVar("T", bound=Callable[..., Any])
135
-
136
-
137
- def forbid_state_change(method: T) -> T:
138
- @functools.wraps(method)
139
- def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
140
- before = deepcopy(self.__dict__)
141
- result = method(self, *args, **kwargs)
142
- after = self.__dict__
143
- if before != after:
144
- raise RuntimeError(
145
- f"Method '{method.__name__}' modified the instance state."
146
- )
147
- return result
148
-
149
- return cast("T", wrapper)
150
-
151
-
152
139
  def read_contents_from_file(
153
140
  file_path: str | Path,
154
141
  binary_mode: bool = False,
@@ -1,16 +1,16 @@
1
1
  from collections.abc import Sequence
2
2
  from logging import getLogger
3
- from typing import Any, Generic, Protocol, TypeVar, cast, final
3
+ from typing import Any, ClassVar, Generic, Protocol, TypeVar, cast, final
4
4
 
5
5
  from ..agent_message_pool import AgentMessage, AgentMessagePool
6
6
  from ..comm_agent import CommunicatingAgent
7
7
  from ..run_context import CtxT, RunContextWrapper
8
- from ..typing.io import AgentID, AgentPayload, AgentState, InT, OutT
8
+ from ..typing.io import AgentID, AgentState, InT, OutT
9
9
  from .workflow_agent import WorkflowAgent
10
10
 
11
11
  logger = getLogger(__name__)
12
12
 
13
- _EH_OutT = TypeVar("_EH_OutT", bound=AgentPayload, contravariant=True) # noqa: PLC0105
13
+ _EH_OutT = TypeVar("_EH_OutT", contravariant=True) # noqa: PLC0105
14
14
 
15
15
 
16
16
  class WorkflowLoopExitHandler(Protocol[_EH_OutT, CtxT]):
@@ -23,13 +23,16 @@ class WorkflowLoopExitHandler(Protocol[_EH_OutT, CtxT]):
23
23
 
24
24
 
25
25
  class LoopedWorkflowAgent(WorkflowAgent[InT, OutT, CtxT], Generic[InT, OutT, CtxT]):
26
+ _generic_arg_to_instance_attr_map: ClassVar[dict[int, str]] = {
27
+ 0: "_in_type",
28
+ 1: "_out_type",
29
+ }
30
+
26
31
  def __init__(
27
32
  self,
28
33
  agent_id: AgentID,
29
- subagents: Sequence[
30
- CommunicatingAgent[AgentPayload, AgentPayload, AgentState, CtxT]
31
- ],
32
- exit_agent: CommunicatingAgent[AgentPayload, OutT, AgentState, CtxT],
34
+ subagents: Sequence[CommunicatingAgent[Any, Any, AgentState, CtxT]],
35
+ exit_agent: CommunicatingAgent[Any, OutT, AgentState, CtxT],
33
36
  message_pool: AgentMessagePool[CtxT] | None = None,
34
37
  recipient_ids: list[AgentID] | None = None,
35
38
  dynamic_routing: bool = False,
@@ -61,7 +64,7 @@ class LoopedWorkflowAgent(WorkflowAgent[InT, OutT, CtxT], Generic[InT, OutT, Ctx
61
64
 
62
65
  return func
63
66
 
64
- def _workflow_loop_exit(
67
+ def _exit_workflow_loop(
65
68
  self,
66
69
  output_message: AgentMessage[OutT, AgentState],
67
70
  ctx: RunContextWrapper[CtxT] | None,
@@ -101,7 +104,7 @@ class LoopedWorkflowAgent(WorkflowAgent[InT, OutT, CtxT], Generic[InT, OutT, Ctx
101
104
  if subagent is self._end_agent:
102
105
  num_iterations += 1
103
106
  exit_message = cast("AgentMessage[OutT, AgentState]", agent_message)
104
- if self._workflow_loop_exit(exit_message, ctx=ctx):
107
+ if self._exit_workflow_loop(exit_message, ctx=ctx):
105
108
  return exit_message
106
109
  if num_iterations >= self._max_iterations:
107
110
  logger.info(
@@ -1,20 +1,23 @@
1
1
  from collections.abc import Sequence
2
- from typing import Any, Generic, cast, final
2
+ from typing import Any, ClassVar, Generic, cast, final
3
3
 
4
4
  from ..agent_message_pool import AgentMessage, AgentMessagePool
5
5
  from ..comm_agent import CommunicatingAgent
6
6
  from ..run_context import CtxT, RunContextWrapper
7
- from ..typing.io import AgentID, AgentPayload, AgentState, InT, OutT
7
+ from ..typing.io import AgentID, AgentState, InT, OutT
8
8
  from .workflow_agent import WorkflowAgent
9
9
 
10
10
 
11
11
  class SequentialWorkflowAgent(WorkflowAgent[InT, OutT, CtxT], Generic[InT, OutT, CtxT]):
12
+ _generic_arg_to_instance_attr_map: ClassVar[dict[int, str]] = {
13
+ 0: "_in_type",
14
+ 1: "_out_type",
15
+ }
16
+
12
17
  def __init__(
13
18
  self,
14
19
  agent_id: AgentID,
15
- subagents: Sequence[
16
- CommunicatingAgent[AgentPayload, AgentPayload, AgentState, CtxT]
17
- ],
20
+ subagents: Sequence[CommunicatingAgent[Any, Any, AgentState, CtxT]],
18
21
  message_pool: AgentMessagePool[CtxT] | None = None,
19
22
  recipient_ids: list[AgentID] | None = None,
20
23
  dynamic_routing: bool = False,
@@ -23,7 +26,7 @@ class SequentialWorkflowAgent(WorkflowAgent[InT, OutT, CtxT], Generic[InT, OutT,
23
26
  super().__init__(
24
27
  subagents=subagents,
25
28
  start_agent=subagents[0],
26
- end_agent=subagents[-1], # type: ignore[assignment]
29
+ end_agent=subagents[-1],
27
30
  agent_id=agent_id,
28
31
  message_pool=message_pool,
29
32
  recipient_ids=recipient_ids,
@@ -1,11 +1,11 @@
1
1
  from abc import ABC, abstractmethod
2
2
  from collections.abc import Sequence
3
- from typing import Any, Generic
3
+ from typing import Any, ClassVar, Generic
4
4
 
5
5
  from ..agent_message_pool import AgentMessage, AgentMessagePool
6
6
  from ..comm_agent import CommunicatingAgent
7
7
  from ..run_context import CtxT, RunContextWrapper
8
- from ..typing.io import AgentID, AgentPayload, AgentState, InT, OutT
8
+ from ..typing.io import AgentID, AgentState, InT, OutT
9
9
 
10
10
 
11
11
  class WorkflowAgent(
@@ -13,14 +13,17 @@ class WorkflowAgent(
13
13
  ABC,
14
14
  Generic[InT, OutT, CtxT],
15
15
  ):
16
+ _generic_arg_to_instance_attr_map: ClassVar[dict[int, str]] = {
17
+ 0: "_in_type",
18
+ 1: "_out_type",
19
+ }
20
+
16
21
  def __init__(
17
22
  self,
18
23
  agent_id: AgentID,
19
- subagents: Sequence[
20
- CommunicatingAgent[AgentPayload, AgentPayload, AgentState, CtxT]
21
- ],
22
- start_agent: CommunicatingAgent[InT, AgentPayload, AgentState, CtxT],
23
- end_agent: CommunicatingAgent[AgentPayload, OutT, AgentState, CtxT],
24
+ subagents: Sequence[CommunicatingAgent[Any, Any, AgentState, CtxT]],
25
+ start_agent: CommunicatingAgent[InT, Any, AgentState, CtxT],
26
+ end_agent: CommunicatingAgent[Any, OutT, AgentState, CtxT],
24
27
  message_pool: AgentMessagePool[CtxT] | None = None,
25
28
  recipient_ids: list[AgentID] | None = None,
26
29
  dynamic_routing: bool = False,
@@ -28,6 +31,10 @@ class WorkflowAgent(
28
31
  ) -> None:
29
32
  if not subagents:
30
33
  raise ValueError("At least one step is required")
34
+ if start_agent not in subagents:
35
+ raise ValueError("Start agent must be in the subagents list")
36
+ if end_agent not in subagents:
37
+ raise ValueError("End agent must be in the subagents list")
31
38
 
32
39
  self.subagents = subagents
33
40
 
@@ -36,8 +43,6 @@ class WorkflowAgent(
36
43
 
37
44
  super().__init__(
38
45
  agent_id=agent_id,
39
- out_schema=end_agent.out_schema,
40
- rcv_args_schema=start_agent.rcv_args_schema,
41
46
  message_pool=message_pool,
42
47
  recipient_ids=recipient_ids,
43
48
  dynamic_routing=dynamic_routing,
@@ -48,11 +53,11 @@ class WorkflowAgent(
48
53
  )
49
54
 
50
55
  @property
51
- def start_agent(self) -> CommunicatingAgent[InT, AgentPayload, AgentState, CtxT]:
56
+ def start_agent(self) -> CommunicatingAgent[InT, Any, AgentState, CtxT]:
52
57
  return self._start_agent
53
58
 
54
59
  @property
55
- def end_agent(self) -> CommunicatingAgent[AgentPayload, OutT, AgentState, CtxT]:
60
+ def end_agent(self) -> CommunicatingAgent[Any, OutT, AgentState, CtxT]:
56
61
  return self._end_agent
57
62
 
58
63
  @abstractmethod