grasp_agents 0.1.18__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- grasp_agents/agent_message.py +2 -2
- grasp_agents/agent_message_pool.py +6 -8
- grasp_agents/base_agent.py +15 -36
- grasp_agents/cloud_llm.py +10 -6
- grasp_agents/comm_agent.py +39 -43
- grasp_agents/generics_utils.py +159 -0
- grasp_agents/llm.py +4 -0
- grasp_agents/llm_agent.py +126 -46
- grasp_agents/llm_agent_state.py +18 -12
- grasp_agents/prompt_builder.py +55 -28
- grasp_agents/rate_limiting/rate_limiter_chunked.py +49 -48
- grasp_agents/rate_limiting/types.py +19 -40
- grasp_agents/rate_limiting/utils.py +24 -27
- grasp_agents/run_context.py +2 -15
- grasp_agents/tool_orchestrator.py +34 -12
- grasp_agents/typing/content.py +2 -2
- grasp_agents/typing/converters.py +3 -1
- grasp_agents/typing/io.py +7 -11
- grasp_agents/typing/message.py +2 -2
- grasp_agents/typing/tool.py +26 -14
- grasp_agents/utils.py +90 -96
- grasp_agents/workflow/looped_agent.py +12 -9
- grasp_agents/workflow/sequential_agent.py +9 -6
- grasp_agents/workflow/workflow_agent.py +16 -11
- {grasp_agents-0.1.18.dist-info → grasp_agents-0.2.1.dist-info}/METADATA +37 -33
- grasp_agents-0.2.1.dist-info/RECORD +45 -0
- {grasp_agents-0.1.18.dist-info → grasp_agents-0.2.1.dist-info}/licenses/LICENSE.md +1 -1
- grasp_agents-0.1.18.dist-info/RECORD +0 -44
- {grasp_agents-0.1.18.dist-info → grasp_agents-0.2.1.dist-info}/WHEEL +0 -0
@@ -2,6 +2,7 @@ import asyncio
|
|
2
2
|
import functools
|
3
3
|
import logging
|
4
4
|
from collections.abc import Callable, Coroutine
|
5
|
+
from dataclasses import dataclass
|
5
6
|
from time import monotonic
|
6
7
|
from typing import Any, Generic, overload
|
7
8
|
|
@@ -9,21 +10,25 @@ from tqdm.autonotebook import tqdm
|
|
9
10
|
|
10
11
|
from ..utils import asyncio_gather_with_pbar
|
11
12
|
from .types import (
|
12
|
-
|
13
|
-
|
14
|
-
|
15
|
-
|
16
|
-
|
17
|
-
|
18
|
-
|
19
|
-
RetrievalCallableSingle,
|
13
|
+
P,
|
14
|
+
ProcessorCallableList,
|
15
|
+
ProcessorCallableSingle,
|
16
|
+
R,
|
17
|
+
RateLimWrapperWithArgsList,
|
18
|
+
RateLimWrapperWithArgsSingle,
|
19
|
+
T,
|
20
20
|
)
|
21
|
-
from .utils import
|
21
|
+
from .utils import partial_processor_callable, split_pos_args
|
22
22
|
|
23
23
|
logger = logging.getLogger(__name__)
|
24
24
|
|
25
25
|
|
26
|
-
|
26
|
+
@dataclass
|
27
|
+
class RateLimiterState:
|
28
|
+
next_request_time: float = 0.0
|
29
|
+
|
30
|
+
|
31
|
+
class RateLimiterC(Generic[T, R]):
|
27
32
|
def __init__(
|
28
33
|
self,
|
29
34
|
rpm: float,
|
@@ -40,9 +45,9 @@ class RateLimiterC(Generic[QueryT, QueryR]):
|
|
40
45
|
|
41
46
|
async def process_input(
|
42
47
|
self,
|
43
|
-
func_partial: Callable[[
|
44
|
-
inp:
|
45
|
-
) ->
|
48
|
+
func_partial: Callable[[T], Coroutine[Any, Any, R]],
|
49
|
+
inp: T,
|
50
|
+
) -> R:
|
46
51
|
async with self._semaphore:
|
47
52
|
async with self._lock:
|
48
53
|
now = monotonic()
|
@@ -53,11 +58,11 @@ class RateLimiterC(Generic[QueryT, QueryR]):
|
|
53
58
|
|
54
59
|
async def process_inputs(
|
55
60
|
self,
|
56
|
-
func_partial: Callable[[
|
57
|
-
inputs: list[
|
61
|
+
func_partial: Callable[[T], Coroutine[Any, Any, R]],
|
62
|
+
inputs: list[T],
|
58
63
|
no_tqdm: bool = False,
|
59
|
-
) -> list[
|
60
|
-
results: list[
|
64
|
+
) -> list[R]:
|
65
|
+
results: list[R] = []
|
61
66
|
for i in tqdm(
|
62
67
|
range(0, len(inputs), self._chunk_size),
|
63
68
|
disable=no_tqdm,
|
@@ -95,33 +100,30 @@ class RateLimiterC(Generic[QueryT, QueryR]):
|
|
95
100
|
|
96
101
|
@overload
|
97
102
|
def limit_rate(
|
98
|
-
call:
|
99
|
-
rate_limiter: RateLimiterC[
|
100
|
-
) ->
|
103
|
+
call: ProcessorCallableSingle[T, P, R],
|
104
|
+
rate_limiter: RateLimiterC[T, R] | None = None,
|
105
|
+
) -> ProcessorCallableSingle[T, P, R]: ...
|
101
106
|
|
102
107
|
|
103
108
|
@overload
|
104
109
|
def limit_rate(
|
105
110
|
call: None = None,
|
106
|
-
rate_limiter: RateLimiterC[
|
107
|
-
) ->
|
111
|
+
rate_limiter: RateLimiterC[T, R] | None = None,
|
112
|
+
) -> RateLimWrapperWithArgsSingle[T, P, R]: ...
|
108
113
|
|
109
114
|
|
110
115
|
def limit_rate(
|
111
|
-
call:
|
112
|
-
rate_limiter: RateLimiterC[
|
113
|
-
) ->
|
114
|
-
RetrievalCallableSingle[QueryT, QueryP, QueryR]
|
115
|
-
| RateLimDecoratorWithArgsSingle[QueryT, QueryP, QueryR]
|
116
|
-
):
|
116
|
+
call: ProcessorCallableSingle[T, P, R] | None = None,
|
117
|
+
rate_limiter: RateLimiterC[T, R] | None = None,
|
118
|
+
) -> ProcessorCallableSingle[T, P, R] | RateLimWrapperWithArgsSingle[T, P, R]:
|
117
119
|
if call is None:
|
118
120
|
return functools.partial(limit_rate, rate_limiter=rate_limiter)
|
119
121
|
|
120
122
|
@functools.wraps(call) # type: ignore
|
121
|
-
async def wrapper(*args: Any, **kwargs: Any) ->
|
122
|
-
inp:
|
123
|
-
self_obj, inp, other_args = split_pos_args(call, args)
|
124
|
-
call_partial =
|
123
|
+
async def wrapper(*args: Any, **kwargs: Any) -> R:
|
124
|
+
inp: T
|
125
|
+
self_obj, inp, other_args = split_pos_args(call, args)
|
126
|
+
call_partial = partial_processor_callable(call, self_obj, *other_args, **kwargs)
|
125
127
|
|
126
128
|
_rate_limiter = rate_limiter
|
127
129
|
if _rate_limiter is None:
|
@@ -136,39 +138,36 @@ def limit_rate(
|
|
136
138
|
|
137
139
|
@overload
|
138
140
|
def limit_rate_chunked(
|
139
|
-
call:
|
140
|
-
rate_limiter: RateLimiterC[
|
141
|
+
call: ProcessorCallableList[T, P, R],
|
142
|
+
rate_limiter: RateLimiterC[T, R] | None = None,
|
141
143
|
no_tqdm: bool | None = None,
|
142
|
-
) ->
|
144
|
+
) -> ProcessorCallableList[T, P, R]: ...
|
143
145
|
|
144
146
|
|
145
147
|
@overload
|
146
148
|
def limit_rate_chunked(
|
147
149
|
call: None = None,
|
148
|
-
rate_limiter: RateLimiterC[
|
150
|
+
rate_limiter: RateLimiterC[T, R] | None = None,
|
149
151
|
no_tqdm: bool | None = None,
|
150
|
-
) ->
|
152
|
+
) -> RateLimWrapperWithArgsList[T, P, R]: ...
|
151
153
|
|
152
154
|
|
153
155
|
def limit_rate_chunked(
|
154
|
-
call:
|
155
|
-
rate_limiter: RateLimiterC[
|
156
|
+
call: ProcessorCallableList[T, P, R] | None = None,
|
157
|
+
rate_limiter: RateLimiterC[T, R] | None = None,
|
156
158
|
no_tqdm: bool | None = None,
|
157
|
-
) ->
|
158
|
-
RetrievalCallableList[QueryT, QueryP, QueryR]
|
159
|
-
| RateLimDecoratorWithArgsList[QueryT, QueryP, QueryR]
|
160
|
-
):
|
159
|
+
) -> ProcessorCallableList[T, P, R] | RateLimWrapperWithArgsList[T, P, R]:
|
161
160
|
if call is None:
|
162
161
|
return functools.partial(
|
163
162
|
limit_rate_chunked, rate_limiter=rate_limiter, no_tqdm=no_tqdm
|
164
|
-
)
|
163
|
+
) # type: ignore
|
165
164
|
|
166
165
|
@functools.wraps(call) # type: ignore
|
167
|
-
async def wrapper(*args: Any, **kwargs: Any) -> list[
|
166
|
+
async def wrapper(*args: Any, **kwargs: Any) -> list[R]:
|
168
167
|
assert call is not None
|
169
168
|
|
170
|
-
self_obj, inputs, other_args = split_pos_args(call, args)
|
171
|
-
call_partial =
|
169
|
+
self_obj, inputs, other_args = split_pos_args(call, args)
|
170
|
+
call_partial = partial_processor_callable(call, self_obj, *other_args, **kwargs)
|
172
171
|
|
173
172
|
_no_tqdm = no_tqdm
|
174
173
|
_rate_limiter = rate_limiter
|
@@ -182,7 +181,9 @@ def limit_rate_chunked(
|
|
182
181
|
*[call_partial(inp) for inp in inputs], no_tqdm=_no_tqdm
|
183
182
|
)
|
184
183
|
return await _rate_limiter.process_inputs(
|
185
|
-
func_partial=call_partial,
|
184
|
+
func_partial=call_partial, # type: ignore
|
185
|
+
inputs=inputs,
|
186
|
+
no_tqdm=_no_tqdm,
|
186
187
|
)
|
187
188
|
|
188
189
|
return wrapper
|
@@ -1,57 +1,36 @@
|
|
1
1
|
from collections.abc import Callable, Coroutine
|
2
|
-
from
|
3
|
-
from typing import (
|
4
|
-
Any,
|
5
|
-
Concatenate,
|
6
|
-
ParamSpec,
|
7
|
-
TypeAlias,
|
8
|
-
TypeVar,
|
9
|
-
)
|
10
|
-
|
11
|
-
MAX_RPM = 1e10
|
12
|
-
|
2
|
+
from typing import Any, Concatenate, ParamSpec, TypeAlias, TypeVar
|
13
3
|
|
14
|
-
|
15
|
-
|
16
|
-
|
4
|
+
T = TypeVar("T")
|
5
|
+
R = TypeVar("R")
|
6
|
+
P = ParamSpec("P")
|
17
7
|
|
18
|
-
|
19
|
-
|
20
|
-
|
21
|
-
QueryP = ParamSpec("QueryP")
|
22
|
-
|
23
|
-
RetrievalFuncSingle: TypeAlias = Callable[
|
24
|
-
Concatenate[QueryT, QueryP], Coroutine[Any, Any, QueryR]
|
25
|
-
]
|
26
|
-
RetrievalFuncList: TypeAlias = Callable[
|
27
|
-
Concatenate[list[QueryT], QueryP], Coroutine[Any, Any, list[QueryR]]
|
8
|
+
ProcessorFuncSingle: TypeAlias = Callable[Concatenate[T, P], Coroutine[Any, Any, R]]
|
9
|
+
ProcessorFuncList: TypeAlias = Callable[
|
10
|
+
Concatenate[list[T], P], Coroutine[Any, Any, list[R]]
|
28
11
|
]
|
29
12
|
|
30
|
-
|
31
|
-
Concatenate[Any,
|
13
|
+
ProcessorMethodSingle: TypeAlias = Callable[
|
14
|
+
Concatenate[Any, T, P], Coroutine[Any, Any, R]
|
32
15
|
]
|
33
|
-
|
34
|
-
Concatenate[Any, list[
|
16
|
+
ProcessorMethodList: TypeAlias = Callable[
|
17
|
+
Concatenate[Any, list[T], P], Coroutine[Any, Any, list[R]]
|
35
18
|
]
|
36
19
|
|
37
|
-
|
38
|
-
|
39
|
-
| RetrievalMethodSingle[QueryT, QueryP, QueryR]
|
20
|
+
ProcessorCallableSingle: TypeAlias = (
|
21
|
+
ProcessorFuncSingle[T, P, R] | ProcessorMethodSingle[T, P, R]
|
40
22
|
)
|
41
23
|
|
42
|
-
|
43
|
-
|
44
|
-
| RetrievalMethodList[QueryT, QueryP, QueryR]
|
24
|
+
ProcessorCallableList: TypeAlias = (
|
25
|
+
ProcessorFuncList[T, P, R] | ProcessorMethodList[T, P, R]
|
45
26
|
)
|
46
27
|
|
47
28
|
|
48
|
-
|
49
|
-
[
|
50
|
-
RetrievalCallableSingle[QueryT, QueryP, QueryR],
|
29
|
+
RateLimWrapperWithArgsSingle = Callable[
|
30
|
+
[ProcessorCallableSingle[T, P, R]], ProcessorCallableSingle[T, P, R]
|
51
31
|
]
|
52
32
|
|
53
33
|
|
54
|
-
|
55
|
-
[
|
56
|
-
RetrievalCallableList[QueryT, QueryP, QueryR],
|
34
|
+
RateLimWrapperWithArgsList = Callable[
|
35
|
+
[ProcessorCallableList[T, P, R]], ProcessorCallableList[T, P, R]
|
57
36
|
]
|
@@ -1,16 +1,8 @@
|
|
1
1
|
import inspect
|
2
2
|
from collections.abc import Callable, Coroutine, Sequence
|
3
|
-
from typing import
|
4
|
-
Any,
|
5
|
-
)
|
3
|
+
from typing import Any, overload
|
6
4
|
|
7
|
-
from .types import
|
8
|
-
QueryP,
|
9
|
-
QueryR,
|
10
|
-
QueryT,
|
11
|
-
RetrievalCallableList,
|
12
|
-
RetrievalCallableSingle,
|
13
|
-
)
|
5
|
+
from .types import P, ProcessorCallableList, ProcessorCallableSingle, R, T
|
14
6
|
|
15
7
|
|
16
8
|
def is_bound_method(func: Callable[..., Any], self_candidate: Any) -> bool:
|
@@ -19,13 +11,24 @@ def is_bound_method(func: Callable[..., Any], self_candidate: Any) -> bool:
|
|
19
11
|
)
|
20
12
|
|
21
13
|
|
14
|
+
@overload
|
22
15
|
def split_pos_args(
|
23
|
-
call:
|
24
|
-
RetrievalCallableSingle[QueryT, QueryP, QueryR]
|
25
|
-
| RetrievalCallableList[QueryT, QueryP, QueryR]
|
26
|
-
),
|
16
|
+
call: ProcessorCallableSingle[T, P, R],
|
27
17
|
args: Sequence[Any],
|
28
|
-
) -> tuple[Any | None,
|
18
|
+
) -> tuple[Any | None, T, Sequence[Any]]: ...
|
19
|
+
|
20
|
+
|
21
|
+
@overload
|
22
|
+
def split_pos_args(
|
23
|
+
call: ProcessorCallableList[T, P, R],
|
24
|
+
args: Sequence[Any],
|
25
|
+
) -> tuple[Any | None, list[T], Sequence[Any]]: ...
|
26
|
+
|
27
|
+
|
28
|
+
def split_pos_args(
|
29
|
+
call: (ProcessorCallableSingle[T, P, R] | ProcessorCallableList[T, P, R]),
|
30
|
+
args: Sequence[Any],
|
31
|
+
) -> tuple[Any | None, T | list[T], Sequence[Any]]:
|
29
32
|
if not args:
|
30
33
|
raise ValueError("No positional arguments passed.")
|
31
34
|
maybe_self = args[0]
|
@@ -45,13 +48,13 @@ def split_pos_args(
|
|
45
48
|
return None, args[0], args[1:]
|
46
49
|
|
47
50
|
|
48
|
-
def
|
49
|
-
call: Callable[..., Coroutine[Any, Any,
|
51
|
+
def partial_processor_callable(
|
52
|
+
call: Callable[..., Coroutine[Any, Any, R]],
|
50
53
|
self_obj: Any,
|
51
|
-
*args:
|
52
|
-
**kwargs:
|
53
|
-
) -> Callable[[
|
54
|
-
async def wrapper(inp:
|
54
|
+
*args: Any,
|
55
|
+
**kwargs: Any,
|
56
|
+
) -> Callable[[Any], Coroutine[Any, Any, R]]:
|
57
|
+
async def wrapper(inp: Any) -> R:
|
55
58
|
if self_obj is not None:
|
56
59
|
# `call` is a method
|
57
60
|
return await call(self_obj, inp, *args, **kwargs)
|
@@ -59,9 +62,3 @@ def partial_retrieval_callable(
|
|
59
62
|
return await call(inp, *args, **kwargs)
|
60
63
|
|
61
64
|
return wrapper
|
62
|
-
|
63
|
-
|
64
|
-
def expected_exec_time_from_max_concurrency_and_rpm(
|
65
|
-
rpm: float, max_concurrency: int
|
66
|
-
) -> float:
|
67
|
-
return 60.0 / (rpm / max_concurrency)
|
grasp_agents/run_context.py
CHANGED
@@ -8,7 +8,6 @@ from .printer import Printer
|
|
8
8
|
from .typing.content import ImageData
|
9
9
|
from .typing.io import (
|
10
10
|
AgentID,
|
11
|
-
AgentPayload,
|
12
11
|
AgentState,
|
13
12
|
InT,
|
14
13
|
LLMPrompt,
|
@@ -44,9 +43,7 @@ class InteractionRecord(BaseModel, Generic[InT, OutT, StateT]):
|
|
44
43
|
model_config = ConfigDict(extra="forbid", frozen=True)
|
45
44
|
|
46
45
|
|
47
|
-
InteractionHistory: TypeAlias = list[
|
48
|
-
InteractionRecord[AgentPayload, AgentPayload, AgentState]
|
49
|
-
]
|
46
|
+
InteractionHistory: TypeAlias = list[InteractionRecord[Any, Any, AgentState]]
|
50
47
|
|
51
48
|
|
52
49
|
CtxT = TypeVar("CtxT")
|
@@ -56,23 +53,13 @@ class RunContextWrapper(BaseModel, Generic[CtxT]):
|
|
56
53
|
context: CtxT | None = None
|
57
54
|
run_id: str = Field(default_factory=lambda: str(uuid4())[:8], frozen=True)
|
58
55
|
run_args: dict[AgentID, RunArgs] = Field(default_factory=dict)
|
59
|
-
interaction_history: InteractionHistory = Field(default_factory=list)
|
56
|
+
interaction_history: InteractionHistory = Field(default_factory=list) # type: ignore[valid-type]
|
60
57
|
|
61
58
|
print_messages: bool = False
|
62
59
|
|
63
60
|
_usage_tracker: UsageTracker = PrivateAttr()
|
64
61
|
_printer: Printer = PrivateAttr()
|
65
62
|
|
66
|
-
# usage_tracker: Optional[UsageTracker] = None
|
67
|
-
# printer: Optional[Printer] = None
|
68
|
-
|
69
|
-
# @model_validator(mode="after")
|
70
|
-
# def set_usage_tracker_and_printer(self) -> "RunContextWrapper":
|
71
|
-
# self.usage_tracker = UsageTracker(source_id=self.run_id)
|
72
|
-
# self.printer = Printer(source_id=self.run_id)
|
73
|
-
|
74
|
-
# return self
|
75
|
-
|
76
63
|
def model_post_init(self, context: Any) -> None: # noqa: ARG002
|
77
64
|
self._usage_tracker = UsageTracker(source_id=self.run_id)
|
78
65
|
self._printer = Printer(
|
@@ -16,7 +16,7 @@ from .typing.tool import BaseTool, ToolCall, ToolChoice
|
|
16
16
|
logger = getLogger(__name__)
|
17
17
|
|
18
18
|
|
19
|
-
class
|
19
|
+
class ExitToolCallLoopHandler(Protocol[CtxT]):
|
20
20
|
def __call__(
|
21
21
|
self,
|
22
22
|
conversation: Conversation,
|
@@ -26,6 +26,16 @@ class ToolCallLoopExitHandler(Protocol[CtxT]):
|
|
26
26
|
) -> bool: ...
|
27
27
|
|
28
28
|
|
29
|
+
class ManageAgentStateHandler(Protocol[CtxT]):
|
30
|
+
def __call__(
|
31
|
+
self,
|
32
|
+
state: LLMAgentState,
|
33
|
+
*,
|
34
|
+
ctx: RunContextWrapper[CtxT] | None,
|
35
|
+
**kwargs: Any,
|
36
|
+
) -> None: ...
|
37
|
+
|
38
|
+
|
29
39
|
class ToolOrchestrator(Generic[CtxT]):
|
30
40
|
def __init__(
|
31
41
|
self,
|
@@ -38,13 +48,13 @@ class ToolOrchestrator(Generic[CtxT]):
|
|
38
48
|
self._agent_id = agent_id
|
39
49
|
|
40
50
|
self._llm = llm
|
41
|
-
self.
|
42
|
-
self.llm.tools = tools
|
51
|
+
self._llm.tools = tools
|
43
52
|
|
44
53
|
self._max_turns = max_turns
|
45
54
|
self._react_mode = react_mode
|
46
55
|
|
47
|
-
self.
|
56
|
+
self.exit_tool_call_loop_impl: ExitToolCallLoopHandler[CtxT] | None = None
|
57
|
+
self.manage_agent_state_impl: ManageAgentStateHandler[CtxT] | None = None
|
48
58
|
|
49
59
|
@property
|
50
60
|
def agent_id(self) -> str:
|
@@ -62,15 +72,15 @@ class ToolOrchestrator(Generic[CtxT]):
|
|
62
72
|
def max_turns(self) -> int:
|
63
73
|
return self._max_turns
|
64
74
|
|
65
|
-
def
|
75
|
+
def _exit_tool_call_loop(
|
66
76
|
self,
|
67
77
|
conversation: Conversation,
|
68
78
|
*,
|
69
79
|
ctx: RunContextWrapper[CtxT] | None = None,
|
70
80
|
**kwargs: Any,
|
71
81
|
) -> bool:
|
72
|
-
if self.
|
73
|
-
return self.
|
82
|
+
if self.exit_tool_call_loop_impl:
|
83
|
+
return self.exit_tool_call_loop_impl(
|
74
84
|
conversation=conversation, ctx=ctx, **kwargs
|
75
85
|
)
|
76
86
|
|
@@ -81,6 +91,16 @@ class ToolOrchestrator(Generic[CtxT]):
|
|
81
91
|
|
82
92
|
return not bool(conversation[-1].tool_calls)
|
83
93
|
|
94
|
+
def _manage_agent_state(
|
95
|
+
self,
|
96
|
+
state: LLMAgentState,
|
97
|
+
*,
|
98
|
+
ctx: RunContextWrapper[CtxT] | None = None,
|
99
|
+
**kwargs: Any,
|
100
|
+
) -> None:
|
101
|
+
if self.manage_agent_state_impl:
|
102
|
+
self.manage_agent_state_impl(state=state, ctx=ctx, **kwargs)
|
103
|
+
|
84
104
|
async def generate_once(
|
85
105
|
self,
|
86
106
|
agent_state: LLMAgentState,
|
@@ -99,10 +119,10 @@ class ToolOrchestrator(Generic[CtxT]):
|
|
99
119
|
|
100
120
|
async def run_loop(
|
101
121
|
self,
|
102
|
-
|
122
|
+
state: LLMAgentState,
|
103
123
|
ctx: RunContextWrapper[CtxT] | None = None,
|
104
124
|
) -> None:
|
105
|
-
message_history =
|
125
|
+
message_history = state.message_history
|
106
126
|
assert message_history.batch_size == 1, (
|
107
127
|
"Batch size must be 1 for tool call loop"
|
108
128
|
)
|
@@ -111,13 +131,15 @@ class ToolOrchestrator(Generic[CtxT]):
|
|
111
131
|
|
112
132
|
tool_choice = "none" if self._react_mode else "auto"
|
113
133
|
gen_message_batch = await self.generate_once(
|
114
|
-
|
134
|
+
state, tool_choice=tool_choice, ctx=ctx
|
115
135
|
)
|
116
136
|
|
117
137
|
turns = 0
|
118
138
|
|
119
139
|
while True:
|
120
|
-
|
140
|
+
self._manage_agent_state(state=state, ctx=ctx, num_turns=turns)
|
141
|
+
|
142
|
+
if self._exit_tool_call_loop(
|
121
143
|
message_history.batched_conversations[0], ctx=ctx, num_turns=turns
|
122
144
|
):
|
123
145
|
return
|
@@ -134,7 +156,7 @@ class ToolOrchestrator(Generic[CtxT]):
|
|
134
156
|
|
135
157
|
tool_choice = "none" if (self._react_mode and msg.tool_calls) else "auto"
|
136
158
|
gen_message_batch = await self.generate_once(
|
137
|
-
|
159
|
+
state, tool_choice=tool_choice, ctx=ctx
|
138
160
|
)
|
139
161
|
|
140
162
|
turns += 1
|
grasp_agents/typing/content.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
import base64
|
2
2
|
import re
|
3
|
-
from collections.abc import Iterable
|
3
|
+
from collections.abc import Iterable, Mapping
|
4
4
|
from enum import StrEnum
|
5
5
|
from pathlib import Path
|
6
6
|
from typing import Annotated, Any, Literal, TypeAlias
|
@@ -66,7 +66,7 @@ class Content(BaseModel):
|
|
66
66
|
def from_formatted_prompt(
|
67
67
|
cls,
|
68
68
|
prompt_template: str,
|
69
|
-
prompt_args:
|
69
|
+
prompt_args: Mapping[str, str | int | bool | ImageData] | None = None,
|
70
70
|
) -> "Content":
|
71
71
|
prompt_args = prompt_args or {}
|
72
72
|
image_args = {
|
@@ -2,6 +2,8 @@ from abc import ABC, abstractmethod
|
|
2
2
|
from collections.abc import AsyncIterator
|
3
3
|
from typing import Any
|
4
4
|
|
5
|
+
from pydantic import BaseModel
|
6
|
+
|
5
7
|
from .completion import Completion, CompletionChunk
|
6
8
|
from .content import Content
|
7
9
|
from .message import AssistantMessage, Message, SystemMessage, ToolMessage, UserMessage
|
@@ -64,7 +66,7 @@ class Converters(ABC):
|
|
64
66
|
|
65
67
|
@staticmethod
|
66
68
|
@abstractmethod
|
67
|
-
def to_tool(tool: BaseTool[
|
69
|
+
def to_tool(tool: BaseTool[BaseModel, Any, Any], **kwargs: Any) -> Any:
|
68
70
|
pass
|
69
71
|
|
70
72
|
@staticmethod
|
grasp_agents/typing/io.py
CHANGED
@@ -1,3 +1,4 @@
|
|
1
|
+
from collections.abc import Mapping
|
1
2
|
from typing import TypeAlias, TypeVar
|
2
3
|
|
3
4
|
from pydantic import BaseModel
|
@@ -7,23 +8,18 @@ from .content import ImageData
|
|
7
8
|
AgentID: TypeAlias = str
|
8
9
|
|
9
10
|
|
10
|
-
class AgentPayload(BaseModel):
|
11
|
-
pass
|
12
|
-
|
13
|
-
|
14
11
|
class AgentState(BaseModel):
|
15
12
|
pass
|
16
13
|
|
17
14
|
|
18
|
-
InT = TypeVar("InT", bound=AgentPayload, contravariant=True) # noqa: PLC0105
|
19
|
-
OutT = TypeVar("OutT", bound=AgentPayload, covariant=True) # noqa: PLC0105
|
20
|
-
StateT = TypeVar("StateT", bound=AgentState, covariant=True) # noqa: PLC0105
|
21
|
-
|
22
|
-
|
23
15
|
class LLMPromptArgs(BaseModel):
|
24
16
|
pass
|
25
17
|
|
26
18
|
|
19
|
+
InT = TypeVar("InT", contravariant=True) # noqa: PLC0105
|
20
|
+
OutT = TypeVar("OutT", covariant=True) # noqa: PLC0105
|
21
|
+
StateT = TypeVar("StateT", bound=AgentState, covariant=True) # noqa: PLC0105
|
22
|
+
|
27
23
|
LLMPrompt: TypeAlias = str
|
28
|
-
LLMFormattedSystemArgs: TypeAlias =
|
29
|
-
LLMFormattedArgs: TypeAlias =
|
24
|
+
LLMFormattedSystemArgs: TypeAlias = Mapping[str, str | int | bool]
|
25
|
+
LLMFormattedArgs: TypeAlias = Mapping[str, str | int | bool | ImageData]
|
grasp_agents/typing/message.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1
1
|
import json
|
2
|
-
from collections.abc import Hashable, Sequence
|
2
|
+
from collections.abc import Hashable, Mapping, Sequence
|
3
3
|
from enum import StrEnum
|
4
4
|
from typing import Annotated, Any, Literal, TypeAlias
|
5
5
|
from uuid import uuid4
|
@@ -79,7 +79,7 @@ class UserMessage(MessageBase):
|
|
79
79
|
def from_formatted_prompt(
|
80
80
|
cls,
|
81
81
|
prompt_template: str,
|
82
|
-
prompt_args:
|
82
|
+
prompt_args: Mapping[str, str | int | bool | ImageData] | None = None,
|
83
83
|
model_id: str | None = None,
|
84
84
|
) -> "UserMessage":
|
85
85
|
content = Content.from_formatted_prompt(
|
grasp_agents/typing/tool.py
CHANGED
@@ -1,11 +1,11 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
|
-
import asyncio
|
4
3
|
from abc import ABC, abstractmethod
|
5
|
-
from
|
6
|
-
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeAlias, TypeVar
|
4
|
+
from typing import TYPE_CHECKING, Any, ClassVar, Generic, Literal, TypeAlias, TypeVar
|
7
5
|
|
8
|
-
from pydantic import BaseModel, TypeAdapter
|
6
|
+
from pydantic import BaseModel, PrivateAttr, TypeAdapter
|
7
|
+
|
8
|
+
from ..generics_utils import AutoInstanceAttributesMixin
|
9
9
|
|
10
10
|
if TYPE_CHECKING:
|
11
11
|
from ..run_context import CtxT, RunContextWrapper
|
@@ -26,32 +26,44 @@ class ToolCall(BaseModel):
|
|
26
26
|
tool_arguments: str
|
27
27
|
|
28
28
|
|
29
|
-
class BaseTool(
|
29
|
+
class BaseTool(
|
30
|
+
AutoInstanceAttributesMixin, BaseModel, ABC, Generic[_ToolInT, _ToolOutT, CtxT]
|
31
|
+
):
|
32
|
+
_generic_arg_to_instance_attr_map: ClassVar[dict[int, str]] = {
|
33
|
+
0: "_in_schema",
|
34
|
+
1: "_out_schema",
|
35
|
+
}
|
36
|
+
|
30
37
|
name: str
|
31
38
|
description: str
|
32
|
-
|
33
|
-
|
39
|
+
|
40
|
+
_in_schema: type[_ToolInT] = PrivateAttr()
|
41
|
+
_out_schema: type[_ToolOutT] = PrivateAttr()
|
34
42
|
|
35
43
|
# Supported by OpenAI API
|
36
44
|
strict: bool | None = None
|
37
45
|
|
46
|
+
@property
|
47
|
+
def in_schema(self) -> type[_ToolInT]: # type: ignore[reportInvalidTypeVarUse]
|
48
|
+
# Exposing the type of a contravariant variable only, should be type safe
|
49
|
+
return self._in_schema
|
50
|
+
|
51
|
+
@property
|
52
|
+
def out_schema(self) -> type[_ToolOutT]:
|
53
|
+
return self._out_schema
|
54
|
+
|
38
55
|
@abstractmethod
|
39
56
|
async def run(
|
40
57
|
self, inp: _ToolInT, ctx: RunContextWrapper[CtxT] | None = None
|
41
58
|
) -> _ToolOutT:
|
42
59
|
pass
|
43
60
|
|
44
|
-
async def run_batch(
|
45
|
-
self, inp_batch: Sequence[_ToolInT], ctx: RunContextWrapper[CtxT] | None = None
|
46
|
-
) -> Sequence[_ToolOutT]:
|
47
|
-
return await asyncio.gather(*[self.run(inp, ctx=ctx) for inp in inp_batch])
|
48
|
-
|
49
61
|
async def __call__(
|
50
62
|
self, ctx: RunContextWrapper[CtxT] | None = None, **kwargs: Any
|
51
63
|
) -> _ToolOutT:
|
52
|
-
result = await self.run(self.
|
64
|
+
result = await self.run(self._in_schema(**kwargs), ctx=ctx)
|
53
65
|
|
54
|
-
return TypeAdapter(self.
|
66
|
+
return TypeAdapter(self._out_schema).validate_python(result)
|
55
67
|
|
56
68
|
|
57
69
|
ToolChoice: TypeAlias = (
|