grasp_agents 0.1.18__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- grasp_agents/agent_message.py +2 -2
- grasp_agents/agent_message_pool.py +6 -8
- grasp_agents/base_agent.py +15 -36
- grasp_agents/cloud_llm.py +9 -6
- grasp_agents/comm_agent.py +39 -43
- grasp_agents/generics_utils.py +159 -0
- grasp_agents/llm.py +4 -0
- grasp_agents/llm_agent.py +68 -37
- grasp_agents/llm_agent_state.py +9 -5
- grasp_agents/prompt_builder.py +45 -25
- grasp_agents/rate_limiting/rate_limiter_chunked.py +49 -48
- grasp_agents/rate_limiting/types.py +19 -40
- grasp_agents/rate_limiting/utils.py +24 -27
- grasp_agents/run_context.py +2 -15
- grasp_agents/tool_orchestrator.py +30 -8
- grasp_agents/typing/converters.py +3 -1
- grasp_agents/typing/io.py +4 -9
- grasp_agents/typing/tool.py +26 -7
- grasp_agents/utils.py +26 -39
- grasp_agents/workflow/looped_agent.py +12 -9
- grasp_agents/workflow/sequential_agent.py +9 -6
- grasp_agents/workflow/workflow_agent.py +16 -11
- {grasp_agents-0.1.18.dist-info → grasp_agents-0.2.0.dist-info}/METADATA +37 -33
- grasp_agents-0.2.0.dist-info/RECORD +45 -0
- grasp_agents-0.1.18.dist-info/RECORD +0 -44
- {grasp_agents-0.1.18.dist-info → grasp_agents-0.2.0.dist-info}/WHEEL +0 -0
- {grasp_agents-0.1.18.dist-info → grasp_agents-0.2.0.dist-info}/licenses/LICENSE.md +0 -0
@@ -1,16 +1,8 @@
|
|
1
1
|
import inspect
|
2
2
|
from collections.abc import Callable, Coroutine, Sequence
|
3
|
-
from typing import
|
4
|
-
Any,
|
5
|
-
)
|
3
|
+
from typing import Any, overload
|
6
4
|
|
7
|
-
from .types import
|
8
|
-
QueryP,
|
9
|
-
QueryR,
|
10
|
-
QueryT,
|
11
|
-
RetrievalCallableList,
|
12
|
-
RetrievalCallableSingle,
|
13
|
-
)
|
5
|
+
from .types import P, ProcessorCallableList, ProcessorCallableSingle, R, T
|
14
6
|
|
15
7
|
|
16
8
|
def is_bound_method(func: Callable[..., Any], self_candidate: Any) -> bool:
|
@@ -19,13 +11,24 @@ def is_bound_method(func: Callable[..., Any], self_candidate: Any) -> bool:
|
|
19
11
|
)
|
20
12
|
|
21
13
|
|
14
|
+
@overload
|
22
15
|
def split_pos_args(
|
23
|
-
call:
|
24
|
-
RetrievalCallableSingle[QueryT, QueryP, QueryR]
|
25
|
-
| RetrievalCallableList[QueryT, QueryP, QueryR]
|
26
|
-
),
|
16
|
+
call: ProcessorCallableSingle[T, P, R],
|
27
17
|
args: Sequence[Any],
|
28
|
-
) -> tuple[Any | None,
|
18
|
+
) -> tuple[Any | None, T, Sequence[Any]]: ...
|
19
|
+
|
20
|
+
|
21
|
+
@overload
|
22
|
+
def split_pos_args(
|
23
|
+
call: ProcessorCallableList[T, P, R],
|
24
|
+
args: Sequence[Any],
|
25
|
+
) -> tuple[Any | None, list[T], Sequence[Any]]: ...
|
26
|
+
|
27
|
+
|
28
|
+
def split_pos_args(
|
29
|
+
call: (ProcessorCallableSingle[T, P, R] | ProcessorCallableList[T, P, R]),
|
30
|
+
args: Sequence[Any],
|
31
|
+
) -> tuple[Any | None, T | list[T], Sequence[Any]]:
|
29
32
|
if not args:
|
30
33
|
raise ValueError("No positional arguments passed.")
|
31
34
|
maybe_self = args[0]
|
@@ -45,13 +48,13 @@ def split_pos_args(
|
|
45
48
|
return None, args[0], args[1:]
|
46
49
|
|
47
50
|
|
48
|
-
def
|
49
|
-
call: Callable[..., Coroutine[Any, Any,
|
51
|
+
def partial_processor_callable(
|
52
|
+
call: Callable[..., Coroutine[Any, Any, R]],
|
50
53
|
self_obj: Any,
|
51
|
-
*args:
|
52
|
-
**kwargs:
|
53
|
-
) -> Callable[[
|
54
|
-
async def wrapper(inp:
|
54
|
+
*args: Any,
|
55
|
+
**kwargs: Any,
|
56
|
+
) -> Callable[[Any], Coroutine[Any, Any, R]]:
|
57
|
+
async def wrapper(inp: Any) -> R:
|
55
58
|
if self_obj is not None:
|
56
59
|
# `call` is a method
|
57
60
|
return await call(self_obj, inp, *args, **kwargs)
|
@@ -59,9 +62,3 @@ def partial_retrieval_callable(
|
|
59
62
|
return await call(inp, *args, **kwargs)
|
60
63
|
|
61
64
|
return wrapper
|
62
|
-
|
63
|
-
|
64
|
-
def expected_exec_time_from_max_concurrency_and_rpm(
|
65
|
-
rpm: float, max_concurrency: int
|
66
|
-
) -> float:
|
67
|
-
return 60.0 / (rpm / max_concurrency)
|
grasp_agents/run_context.py
CHANGED
@@ -8,7 +8,6 @@ from .printer import Printer
|
|
8
8
|
from .typing.content import ImageData
|
9
9
|
from .typing.io import (
|
10
10
|
AgentID,
|
11
|
-
AgentPayload,
|
12
11
|
AgentState,
|
13
12
|
InT,
|
14
13
|
LLMPrompt,
|
@@ -44,9 +43,7 @@ class InteractionRecord(BaseModel, Generic[InT, OutT, StateT]):
|
|
44
43
|
model_config = ConfigDict(extra="forbid", frozen=True)
|
45
44
|
|
46
45
|
|
47
|
-
InteractionHistory: TypeAlias = list[
|
48
|
-
InteractionRecord[AgentPayload, AgentPayload, AgentState]
|
49
|
-
]
|
46
|
+
InteractionHistory: TypeAlias = list[InteractionRecord[Any, Any, AgentState]]
|
50
47
|
|
51
48
|
|
52
49
|
CtxT = TypeVar("CtxT")
|
@@ -56,23 +53,13 @@ class RunContextWrapper(BaseModel, Generic[CtxT]):
|
|
56
53
|
context: CtxT | None = None
|
57
54
|
run_id: str = Field(default_factory=lambda: str(uuid4())[:8], frozen=True)
|
58
55
|
run_args: dict[AgentID, RunArgs] = Field(default_factory=dict)
|
59
|
-
interaction_history: InteractionHistory = Field(default_factory=list)
|
56
|
+
interaction_history: InteractionHistory = Field(default_factory=list) # type: ignore[valid-type]
|
60
57
|
|
61
58
|
print_messages: bool = False
|
62
59
|
|
63
60
|
_usage_tracker: UsageTracker = PrivateAttr()
|
64
61
|
_printer: Printer = PrivateAttr()
|
65
62
|
|
66
|
-
# usage_tracker: Optional[UsageTracker] = None
|
67
|
-
# printer: Optional[Printer] = None
|
68
|
-
|
69
|
-
# @model_validator(mode="after")
|
70
|
-
# def set_usage_tracker_and_printer(self) -> "RunContextWrapper":
|
71
|
-
# self.usage_tracker = UsageTracker(source_id=self.run_id)
|
72
|
-
# self.printer = Printer(source_id=self.run_id)
|
73
|
-
|
74
|
-
# return self
|
75
|
-
|
76
63
|
def model_post_init(self, context: Any) -> None: # noqa: ARG002
|
77
64
|
self._usage_tracker = UsageTracker(source_id=self.run_id)
|
78
65
|
self._printer = Printer(
|
@@ -16,7 +16,7 @@ from .typing.tool import BaseTool, ToolCall, ToolChoice
|
|
16
16
|
logger = getLogger(__name__)
|
17
17
|
|
18
18
|
|
19
|
-
class
|
19
|
+
class ExitToolCallLoopHandler(Protocol[CtxT]):
|
20
20
|
def __call__(
|
21
21
|
self,
|
22
22
|
conversation: Conversation,
|
@@ -26,6 +26,16 @@ class ToolCallLoopExitHandler(Protocol[CtxT]):
|
|
26
26
|
) -> bool: ...
|
27
27
|
|
28
28
|
|
29
|
+
class ManageAgentStateHandler(Protocol[CtxT]):
|
30
|
+
def __call__(
|
31
|
+
self,
|
32
|
+
agent_state: LLMAgentState,
|
33
|
+
*,
|
34
|
+
ctx: RunContextWrapper[CtxT] | None,
|
35
|
+
**kwargs: Any,
|
36
|
+
) -> None: ...
|
37
|
+
|
38
|
+
|
29
39
|
class ToolOrchestrator(Generic[CtxT]):
|
30
40
|
def __init__(
|
31
41
|
self,
|
@@ -38,13 +48,13 @@ class ToolOrchestrator(Generic[CtxT]):
|
|
38
48
|
self._agent_id = agent_id
|
39
49
|
|
40
50
|
self._llm = llm
|
41
|
-
self.
|
42
|
-
self.llm.tools = tools
|
51
|
+
self._llm.tools = tools
|
43
52
|
|
44
53
|
self._max_turns = max_turns
|
45
54
|
self._react_mode = react_mode
|
46
55
|
|
47
|
-
self.
|
56
|
+
self.exit_tool_call_loop_impl: ExitToolCallLoopHandler[CtxT] | None = None
|
57
|
+
self.manage_agent_state_impl: ManageAgentStateHandler[CtxT] | None = None
|
48
58
|
|
49
59
|
@property
|
50
60
|
def agent_id(self) -> str:
|
@@ -62,15 +72,15 @@ class ToolOrchestrator(Generic[CtxT]):
|
|
62
72
|
def max_turns(self) -> int:
|
63
73
|
return self._max_turns
|
64
74
|
|
65
|
-
def
|
75
|
+
def _exit_tool_call_loop(
|
66
76
|
self,
|
67
77
|
conversation: Conversation,
|
68
78
|
*,
|
69
79
|
ctx: RunContextWrapper[CtxT] | None = None,
|
70
80
|
**kwargs: Any,
|
71
81
|
) -> bool:
|
72
|
-
if self.
|
73
|
-
return self.
|
82
|
+
if self.exit_tool_call_loop_impl:
|
83
|
+
return self.exit_tool_call_loop_impl(
|
74
84
|
conversation=conversation, ctx=ctx, **kwargs
|
75
85
|
)
|
76
86
|
|
@@ -81,6 +91,16 @@ class ToolOrchestrator(Generic[CtxT]):
|
|
81
91
|
|
82
92
|
return not bool(conversation[-1].tool_calls)
|
83
93
|
|
94
|
+
def _manage_agent_state(
|
95
|
+
self,
|
96
|
+
agent_state: LLMAgentState,
|
97
|
+
*,
|
98
|
+
ctx: RunContextWrapper[CtxT] | None = None,
|
99
|
+
**kwargs: Any,
|
100
|
+
) -> None:
|
101
|
+
if self.manage_agent_state_impl:
|
102
|
+
self.manage_agent_state_impl(agent_state=agent_state, ctx=ctx, **kwargs)
|
103
|
+
|
84
104
|
async def generate_once(
|
85
105
|
self,
|
86
106
|
agent_state: LLMAgentState,
|
@@ -117,7 +137,9 @@ class ToolOrchestrator(Generic[CtxT]):
|
|
117
137
|
turns = 0
|
118
138
|
|
119
139
|
while True:
|
120
|
-
|
140
|
+
self._manage_agent_state(agent_state=agent_state, ctx=ctx, num_turns=turns)
|
141
|
+
|
142
|
+
if self._exit_tool_call_loop(
|
121
143
|
message_history.batched_conversations[0], ctx=ctx, num_turns=turns
|
122
144
|
):
|
123
145
|
return
|
@@ -2,6 +2,8 @@ from abc import ABC, abstractmethod
|
|
2
2
|
from collections.abc import AsyncIterator
|
3
3
|
from typing import Any
|
4
4
|
|
5
|
+
from pydantic import BaseModel
|
6
|
+
|
5
7
|
from .completion import Completion, CompletionChunk
|
6
8
|
from .content import Content
|
7
9
|
from .message import AssistantMessage, Message, SystemMessage, ToolMessage, UserMessage
|
@@ -64,7 +66,7 @@ class Converters(ABC):
|
|
64
66
|
|
65
67
|
@staticmethod
|
66
68
|
@abstractmethod
|
67
|
-
def to_tool(tool: BaseTool[
|
69
|
+
def to_tool(tool: BaseTool[BaseModel, Any, Any], **kwargs: Any) -> Any:
|
68
70
|
pass
|
69
71
|
|
70
72
|
@staticmethod
|
grasp_agents/typing/io.py
CHANGED
@@ -7,23 +7,18 @@ from .content import ImageData
|
|
7
7
|
AgentID: TypeAlias = str
|
8
8
|
|
9
9
|
|
10
|
-
class AgentPayload(BaseModel):
|
11
|
-
pass
|
12
|
-
|
13
|
-
|
14
10
|
class AgentState(BaseModel):
|
15
11
|
pass
|
16
12
|
|
17
13
|
|
18
|
-
InT = TypeVar("InT", bound=AgentPayload, contravariant=True) # noqa: PLC0105
|
19
|
-
OutT = TypeVar("OutT", bound=AgentPayload, covariant=True) # noqa: PLC0105
|
20
|
-
StateT = TypeVar("StateT", bound=AgentState, covariant=True) # noqa: PLC0105
|
21
|
-
|
22
|
-
|
23
14
|
class LLMPromptArgs(BaseModel):
|
24
15
|
pass
|
25
16
|
|
26
17
|
|
18
|
+
InT = TypeVar("InT", contravariant=True) # noqa: PLC0105
|
19
|
+
OutT = TypeVar("OutT", covariant=True) # noqa: PLC0105
|
20
|
+
StateT = TypeVar("StateT", bound=AgentState, covariant=True) # noqa: PLC0105
|
21
|
+
|
27
22
|
LLMPrompt: TypeAlias = str
|
28
23
|
LLMFormattedSystemArgs: TypeAlias = dict[str, str]
|
29
24
|
LLMFormattedArgs: TypeAlias = dict[str, str | ImageData]
|
grasp_agents/typing/tool.py
CHANGED
@@ -3,9 +3,11 @@ from __future__ import annotations
|
|
3
3
|
import asyncio
|
4
4
|
from abc import ABC, abstractmethod
|
5
5
|
from collections.abc import Sequence
|
6
|
-
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeAlias, TypeVar
|
6
|
+
from typing import TYPE_CHECKING, Any, ClassVar, Generic, Literal, TypeAlias, TypeVar
|
7
7
|
|
8
|
-
from pydantic import BaseModel, TypeAdapter
|
8
|
+
from pydantic import BaseModel, PrivateAttr, TypeAdapter
|
9
|
+
|
10
|
+
from ..generics_utils import AutoInstanceAttributesMixin
|
9
11
|
|
10
12
|
if TYPE_CHECKING:
|
11
13
|
from ..run_context import CtxT, RunContextWrapper
|
@@ -26,15 +28,32 @@ class ToolCall(BaseModel):
|
|
26
28
|
tool_arguments: str
|
27
29
|
|
28
30
|
|
29
|
-
class BaseTool(
|
31
|
+
class BaseTool(
|
32
|
+
AutoInstanceAttributesMixin, BaseModel, ABC, Generic[_ToolInT, _ToolOutT, CtxT]
|
33
|
+
):
|
34
|
+
_generic_arg_to_instance_attr_map: ClassVar[dict[int, str]] = {
|
35
|
+
0: "_in_schema",
|
36
|
+
1: "_out_schema",
|
37
|
+
}
|
38
|
+
|
30
39
|
name: str
|
31
40
|
description: str
|
32
|
-
|
33
|
-
|
41
|
+
|
42
|
+
_in_schema: type[_ToolInT] = PrivateAttr()
|
43
|
+
_out_schema: type[_ToolOutT] = PrivateAttr()
|
34
44
|
|
35
45
|
# Supported by OpenAI API
|
36
46
|
strict: bool | None = None
|
37
47
|
|
48
|
+
@property
|
49
|
+
def in_schema(self) -> type[_ToolInT]: # type: ignore[reportInvalidTypeVarUse]
|
50
|
+
# Exposing the type of a contravariant variable only, should be type safe
|
51
|
+
return self._in_schema
|
52
|
+
|
53
|
+
@property
|
54
|
+
def out_schema(self) -> type[_ToolOutT]:
|
55
|
+
return self._out_schema
|
56
|
+
|
38
57
|
@abstractmethod
|
39
58
|
async def run(
|
40
59
|
self, inp: _ToolInT, ctx: RunContextWrapper[CtxT] | None = None
|
@@ -49,9 +68,9 @@ class BaseTool(BaseModel, ABC, Generic[_ToolInT, _ToolOutT, CtxT]):
|
|
49
68
|
async def __call__(
|
50
69
|
self, ctx: RunContextWrapper[CtxT] | None = None, **kwargs: Any
|
51
70
|
) -> _ToolOutT:
|
52
|
-
result = await self.run(self.
|
71
|
+
result = await self.run(self._in_schema(**kwargs), ctx=ctx)
|
53
72
|
|
54
|
-
return TypeAdapter(self.
|
73
|
+
return TypeAdapter(self._out_schema).validate_python(result)
|
55
74
|
|
56
75
|
|
57
76
|
ToolChoice: TypeAlias = (
|
grasp_agents/utils.py
CHANGED
@@ -1,34 +1,25 @@
|
|
1
1
|
import ast
|
2
2
|
import asyncio
|
3
|
-
import functools
|
4
3
|
import json
|
5
4
|
import re
|
6
|
-
from collections.abc import
|
7
|
-
from copy import deepcopy
|
5
|
+
from collections.abc import Coroutine
|
8
6
|
from datetime import datetime
|
9
7
|
from logging import getLogger
|
10
8
|
from pathlib import Path
|
11
|
-
from typing import Any, TypeVar
|
12
|
-
|
13
|
-
from pydantic import
|
14
|
-
|
9
|
+
from typing import Any, TypeVar
|
10
|
+
|
11
|
+
from pydantic import (
|
12
|
+
BaseModel,
|
13
|
+
GetCoreSchemaHandler,
|
14
|
+
TypeAdapter,
|
15
|
+
ValidationError,
|
16
|
+
)
|
15
17
|
from pydantic_core import core_schema
|
16
18
|
from tqdm.autonotebook import tqdm
|
17
19
|
|
18
20
|
logger = getLogger(__name__)
|
19
21
|
|
20
|
-
|
21
|
-
def merge_pydantic_models(*models: type[BaseModel]) -> type[BaseModel]:
|
22
|
-
fields_dict: dict[str, FieldInfo] = {}
|
23
|
-
for model in models:
|
24
|
-
for field_name, field_info in model.model_fields.items():
|
25
|
-
if field_name in fields_dict:
|
26
|
-
raise ValueError(
|
27
|
-
f"Field conflict detected: '{field_name}' exists in multiple models"
|
28
|
-
)
|
29
|
-
fields_dict[field_name] = field_info
|
30
|
-
|
31
|
-
return create_model("MergedModel", __module__=__name__, **fields_dict) # type: ignore
|
22
|
+
T = TypeVar("T")
|
32
23
|
|
33
24
|
|
34
25
|
def filter_fields(data: dict[str, Any], model: type[BaseModel]) -> dict[str, Any]:
|
@@ -57,7 +48,7 @@ def format_json_string(text: str) -> str:
|
|
57
48
|
return text
|
58
49
|
|
59
50
|
|
60
|
-
def
|
51
|
+
def parse_json_or_py_string(
|
61
52
|
json_str: str, return_none_on_failure: bool = False
|
62
53
|
) -> dict[str, Any] | list[Any] | None:
|
63
54
|
try:
|
@@ -79,7 +70,21 @@ def read_json_string(
|
|
79
70
|
def extract_json(
|
80
71
|
json_str: str, return_none_on_failure: bool = False
|
81
72
|
) -> dict[str, Any] | list[Any] | None:
|
82
|
-
return
|
73
|
+
return parse_json_or_py_string(format_json_string(json_str), return_none_on_failure)
|
74
|
+
|
75
|
+
|
76
|
+
def validate_obj_from_json_or_py_string(s: str, adapter: TypeAdapter[T]) -> T:
|
77
|
+
s_fmt = re.sub(r"```[a-zA-Z0-9]*\n|```", "", s).strip()
|
78
|
+
try:
|
79
|
+
parsed = json.loads(s_fmt)
|
80
|
+
return adapter.validate_python(parsed)
|
81
|
+
except (json.JSONDecodeError, ValidationError):
|
82
|
+
try:
|
83
|
+
return adapter.validate_python(s_fmt)
|
84
|
+
except ValidationError as exc:
|
85
|
+
raise ValueError(
|
86
|
+
f"Invalid JSON or Python string:\n{s}\nExpected type: {adapter._type}", # type: ignore[arg-type]
|
87
|
+
) from exc
|
83
88
|
|
84
89
|
|
85
90
|
def extract_xml_list(text: str) -> list[str]:
|
@@ -131,24 +136,6 @@ def make_conditional_parsed_output_type(
|
|
131
136
|
return ParsedOutput
|
132
137
|
|
133
138
|
|
134
|
-
T = TypeVar("T", bound=Callable[..., Any])
|
135
|
-
|
136
|
-
|
137
|
-
def forbid_state_change(method: T) -> T:
|
138
|
-
@functools.wraps(method)
|
139
|
-
def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
|
140
|
-
before = deepcopy(self.__dict__)
|
141
|
-
result = method(self, *args, **kwargs)
|
142
|
-
after = self.__dict__
|
143
|
-
if before != after:
|
144
|
-
raise RuntimeError(
|
145
|
-
f"Method '{method.__name__}' modified the instance state."
|
146
|
-
)
|
147
|
-
return result
|
148
|
-
|
149
|
-
return cast("T", wrapper)
|
150
|
-
|
151
|
-
|
152
139
|
def read_contents_from_file(
|
153
140
|
file_path: str | Path,
|
154
141
|
binary_mode: bool = False,
|
@@ -1,16 +1,16 @@
|
|
1
1
|
from collections.abc import Sequence
|
2
2
|
from logging import getLogger
|
3
|
-
from typing import Any, Generic, Protocol, TypeVar, cast, final
|
3
|
+
from typing import Any, ClassVar, Generic, Protocol, TypeVar, cast, final
|
4
4
|
|
5
5
|
from ..agent_message_pool import AgentMessage, AgentMessagePool
|
6
6
|
from ..comm_agent import CommunicatingAgent
|
7
7
|
from ..run_context import CtxT, RunContextWrapper
|
8
|
-
from ..typing.io import AgentID,
|
8
|
+
from ..typing.io import AgentID, AgentState, InT, OutT
|
9
9
|
from .workflow_agent import WorkflowAgent
|
10
10
|
|
11
11
|
logger = getLogger(__name__)
|
12
12
|
|
13
|
-
_EH_OutT = TypeVar("_EH_OutT",
|
13
|
+
_EH_OutT = TypeVar("_EH_OutT", contravariant=True) # noqa: PLC0105
|
14
14
|
|
15
15
|
|
16
16
|
class WorkflowLoopExitHandler(Protocol[_EH_OutT, CtxT]):
|
@@ -23,13 +23,16 @@ class WorkflowLoopExitHandler(Protocol[_EH_OutT, CtxT]):
|
|
23
23
|
|
24
24
|
|
25
25
|
class LoopedWorkflowAgent(WorkflowAgent[InT, OutT, CtxT], Generic[InT, OutT, CtxT]):
|
26
|
+
_generic_arg_to_instance_attr_map: ClassVar[dict[int, str]] = {
|
27
|
+
0: "_in_type",
|
28
|
+
1: "_out_type",
|
29
|
+
}
|
30
|
+
|
26
31
|
def __init__(
|
27
32
|
self,
|
28
33
|
agent_id: AgentID,
|
29
|
-
subagents: Sequence[
|
30
|
-
|
31
|
-
],
|
32
|
-
exit_agent: CommunicatingAgent[AgentPayload, OutT, AgentState, CtxT],
|
34
|
+
subagents: Sequence[CommunicatingAgent[Any, Any, AgentState, CtxT]],
|
35
|
+
exit_agent: CommunicatingAgent[Any, OutT, AgentState, CtxT],
|
33
36
|
message_pool: AgentMessagePool[CtxT] | None = None,
|
34
37
|
recipient_ids: list[AgentID] | None = None,
|
35
38
|
dynamic_routing: bool = False,
|
@@ -61,7 +64,7 @@ class LoopedWorkflowAgent(WorkflowAgent[InT, OutT, CtxT], Generic[InT, OutT, Ctx
|
|
61
64
|
|
62
65
|
return func
|
63
66
|
|
64
|
-
def
|
67
|
+
def _exit_workflow_loop(
|
65
68
|
self,
|
66
69
|
output_message: AgentMessage[OutT, AgentState],
|
67
70
|
ctx: RunContextWrapper[CtxT] | None,
|
@@ -101,7 +104,7 @@ class LoopedWorkflowAgent(WorkflowAgent[InT, OutT, CtxT], Generic[InT, OutT, Ctx
|
|
101
104
|
if subagent is self._end_agent:
|
102
105
|
num_iterations += 1
|
103
106
|
exit_message = cast("AgentMessage[OutT, AgentState]", agent_message)
|
104
|
-
if self.
|
107
|
+
if self._exit_workflow_loop(exit_message, ctx=ctx):
|
105
108
|
return exit_message
|
106
109
|
if num_iterations >= self._max_iterations:
|
107
110
|
logger.info(
|
@@ -1,20 +1,23 @@
|
|
1
1
|
from collections.abc import Sequence
|
2
|
-
from typing import Any, Generic, cast, final
|
2
|
+
from typing import Any, ClassVar, Generic, cast, final
|
3
3
|
|
4
4
|
from ..agent_message_pool import AgentMessage, AgentMessagePool
|
5
5
|
from ..comm_agent import CommunicatingAgent
|
6
6
|
from ..run_context import CtxT, RunContextWrapper
|
7
|
-
from ..typing.io import AgentID,
|
7
|
+
from ..typing.io import AgentID, AgentState, InT, OutT
|
8
8
|
from .workflow_agent import WorkflowAgent
|
9
9
|
|
10
10
|
|
11
11
|
class SequentialWorkflowAgent(WorkflowAgent[InT, OutT, CtxT], Generic[InT, OutT, CtxT]):
|
12
|
+
_generic_arg_to_instance_attr_map: ClassVar[dict[int, str]] = {
|
13
|
+
0: "_in_type",
|
14
|
+
1: "_out_type",
|
15
|
+
}
|
16
|
+
|
12
17
|
def __init__(
|
13
18
|
self,
|
14
19
|
agent_id: AgentID,
|
15
|
-
subagents: Sequence[
|
16
|
-
CommunicatingAgent[AgentPayload, AgentPayload, AgentState, CtxT]
|
17
|
-
],
|
20
|
+
subagents: Sequence[CommunicatingAgent[Any, Any, AgentState, CtxT]],
|
18
21
|
message_pool: AgentMessagePool[CtxT] | None = None,
|
19
22
|
recipient_ids: list[AgentID] | None = None,
|
20
23
|
dynamic_routing: bool = False,
|
@@ -23,7 +26,7 @@ class SequentialWorkflowAgent(WorkflowAgent[InT, OutT, CtxT], Generic[InT, OutT,
|
|
23
26
|
super().__init__(
|
24
27
|
subagents=subagents,
|
25
28
|
start_agent=subagents[0],
|
26
|
-
end_agent=subagents[-1],
|
29
|
+
end_agent=subagents[-1],
|
27
30
|
agent_id=agent_id,
|
28
31
|
message_pool=message_pool,
|
29
32
|
recipient_ids=recipient_ids,
|
@@ -1,11 +1,11 @@
|
|
1
1
|
from abc import ABC, abstractmethod
|
2
2
|
from collections.abc import Sequence
|
3
|
-
from typing import Any, Generic
|
3
|
+
from typing import Any, ClassVar, Generic
|
4
4
|
|
5
5
|
from ..agent_message_pool import AgentMessage, AgentMessagePool
|
6
6
|
from ..comm_agent import CommunicatingAgent
|
7
7
|
from ..run_context import CtxT, RunContextWrapper
|
8
|
-
from ..typing.io import AgentID,
|
8
|
+
from ..typing.io import AgentID, AgentState, InT, OutT
|
9
9
|
|
10
10
|
|
11
11
|
class WorkflowAgent(
|
@@ -13,14 +13,17 @@ class WorkflowAgent(
|
|
13
13
|
ABC,
|
14
14
|
Generic[InT, OutT, CtxT],
|
15
15
|
):
|
16
|
+
_generic_arg_to_instance_attr_map: ClassVar[dict[int, str]] = {
|
17
|
+
0: "_in_type",
|
18
|
+
1: "_out_type",
|
19
|
+
}
|
20
|
+
|
16
21
|
def __init__(
|
17
22
|
self,
|
18
23
|
agent_id: AgentID,
|
19
|
-
subagents: Sequence[
|
20
|
-
|
21
|
-
],
|
22
|
-
start_agent: CommunicatingAgent[InT, AgentPayload, AgentState, CtxT],
|
23
|
-
end_agent: CommunicatingAgent[AgentPayload, OutT, AgentState, CtxT],
|
24
|
+
subagents: Sequence[CommunicatingAgent[Any, Any, AgentState, CtxT]],
|
25
|
+
start_agent: CommunicatingAgent[InT, Any, AgentState, CtxT],
|
26
|
+
end_agent: CommunicatingAgent[Any, OutT, AgentState, CtxT],
|
24
27
|
message_pool: AgentMessagePool[CtxT] | None = None,
|
25
28
|
recipient_ids: list[AgentID] | None = None,
|
26
29
|
dynamic_routing: bool = False,
|
@@ -28,6 +31,10 @@ class WorkflowAgent(
|
|
28
31
|
) -> None:
|
29
32
|
if not subagents:
|
30
33
|
raise ValueError("At least one step is required")
|
34
|
+
if start_agent not in subagents:
|
35
|
+
raise ValueError("Start agent must be in the subagents list")
|
36
|
+
if end_agent not in subagents:
|
37
|
+
raise ValueError("End agent must be in the subagents list")
|
31
38
|
|
32
39
|
self.subagents = subagents
|
33
40
|
|
@@ -36,8 +43,6 @@ class WorkflowAgent(
|
|
36
43
|
|
37
44
|
super().__init__(
|
38
45
|
agent_id=agent_id,
|
39
|
-
out_schema=end_agent.out_schema,
|
40
|
-
rcv_args_schema=start_agent.rcv_args_schema,
|
41
46
|
message_pool=message_pool,
|
42
47
|
recipient_ids=recipient_ids,
|
43
48
|
dynamic_routing=dynamic_routing,
|
@@ -48,11 +53,11 @@ class WorkflowAgent(
|
|
48
53
|
)
|
49
54
|
|
50
55
|
@property
|
51
|
-
def start_agent(self) -> CommunicatingAgent[InT,
|
56
|
+
def start_agent(self) -> CommunicatingAgent[InT, Any, AgentState, CtxT]:
|
52
57
|
return self._start_agent
|
53
58
|
|
54
59
|
@property
|
55
|
-
def end_agent(self) -> CommunicatingAgent[
|
60
|
+
def end_agent(self) -> CommunicatingAgent[Any, OutT, AgentState, CtxT]:
|
56
61
|
return self._end_agent
|
57
62
|
|
58
63
|
@abstractmethod
|