grasp_agents 0.1.15__py3-none-any.whl → 0.1.17__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- grasp_agents/agent_message.py +0 -1
- grasp_agents/base_agent.py +1 -1
- grasp_agents/cloud_llm.py +83 -40
- grasp_agents/comm_agent.py +40 -49
- grasp_agents/llm.py +6 -6
- grasp_agents/llm_agent.py +81 -63
- grasp_agents/memory.py +0 -6
- grasp_agents/openai/completion_converters.py +4 -3
- grasp_agents/openai/converters.py +2 -8
- grasp_agents/openai/message_converters.py +1 -6
- grasp_agents/openai/openai_llm.py +4 -6
- grasp_agents/openai/tool_converters.py +1 -1
- grasp_agents/{data_retrieval → rate_limiting}/rate_limiter_chunked.py +2 -9
- grasp_agents/{data_retrieval → rate_limiting}/utils.py +15 -5
- grasp_agents/tool_orchestrator.py +2 -2
- grasp_agents/typing/converters.py +2 -10
- grasp_agents/typing/io.py +1 -4
- grasp_agents/typing/message.py +5 -3
- grasp_agents/typing/tool.py +18 -11
- grasp_agents/utils.py +114 -65
- grasp_agents-0.1.17.dist-info/METADATA +212 -0
- grasp_agents-0.1.17.dist-info/RECORD +44 -0
- grasp_agents-0.1.15.dist-info/METADATA +0 -152
- grasp_agents-0.1.15.dist-info/RECORD +0 -44
- /grasp_agents/{data_retrieval → rate_limiting}/__init__.py +0 -0
- /grasp_agents/{data_retrieval → rate_limiting}/types.py +0 -0
- {grasp_agents-0.1.15.dist-info → grasp_agents-0.1.17.dist-info}/WHEEL +0 -0
- {grasp_agents-0.1.15.dist-info → grasp_agents-0.1.17.dist-info}/licenses/LICENSE.md +0 -0
grasp_agents/llm_agent.py
CHANGED
@@ -33,6 +33,8 @@ from .typing.io import (
|
|
33
33
|
AgentPayload,
|
34
34
|
AgentState,
|
35
35
|
InT,
|
36
|
+
LLMFormattedArgs,
|
37
|
+
LLMFormattedSystemArgs,
|
36
38
|
LLMPrompt,
|
37
39
|
LLMPromptArgs,
|
38
40
|
OutT,
|
@@ -67,7 +69,7 @@ class LLMAgent(
|
|
67
69
|
# Output schema
|
68
70
|
out_schema: type[OutT] = cast("type[OutT]", AgentPayload),
|
69
71
|
# Tools
|
70
|
-
tools: list[BaseTool[
|
72
|
+
tools: list[BaseTool[Any, Any, CtxT]] | None = None,
|
71
73
|
max_turns: int = 1000,
|
72
74
|
react_mode: bool = False,
|
73
75
|
# Agent state management
|
@@ -75,7 +77,6 @@ class LLMAgent(
|
|
75
77
|
# Multi-agent routing
|
76
78
|
message_pool: AgentMessagePool[CtxT] | None = None,
|
77
79
|
recipient_ids: list[AgentID] | None = None,
|
78
|
-
dynamic_routing: bool = False,
|
79
80
|
) -> None:
|
80
81
|
super().__init__(
|
81
82
|
agent_id=agent_id,
|
@@ -83,7 +84,6 @@ class LLMAgent(
|
|
83
84
|
rcv_args_schema=rcv_args_schema,
|
84
85
|
message_pool=message_pool,
|
85
86
|
recipient_ids=recipient_ids,
|
86
|
-
dynamic_routing=dynamic_routing,
|
87
87
|
)
|
88
88
|
|
89
89
|
# Agent state
|
@@ -114,12 +114,24 @@ class LLMAgent(
|
|
114
114
|
|
115
115
|
self.no_tqdm = getattr(llm, "no_tqdm", False)
|
116
116
|
|
117
|
+
if type(self)._format_sys_args is not LLMAgent[Any, Any, Any]._format_sys_args: # noqa: SLF001
|
118
|
+
self._prompt_builder.format_sys_args_impl = self._format_sys_args
|
119
|
+
|
120
|
+
if type(self)._format_inp_args is not LLMAgent[Any, Any, Any]._format_inp_args: # noqa: SLF001
|
121
|
+
self._prompt_builder.format_inp_args_impl = self._format_inp_args
|
122
|
+
|
123
|
+
if (
|
124
|
+
type(self)._tool_call_loop_exit # noqa: SLF001
|
125
|
+
is not LLMAgent[Any, Any, Any]._tool_call_loop_exit # noqa: SLF001
|
126
|
+
):
|
127
|
+
self._tool_orchestrator.tool_call_loop_exit_impl = self._tool_call_loop_exit
|
128
|
+
|
117
129
|
@property
|
118
130
|
def llm(self) -> LLM[LLMSettings, Converters]:
|
119
131
|
return self._tool_orchestrator.llm
|
120
132
|
|
121
133
|
@property
|
122
|
-
def tools(self) -> dict[str, BaseTool[BaseModel,
|
134
|
+
def tools(self) -> dict[str, BaseTool[BaseModel, Any, CtxT]]:
|
123
135
|
return self._tool_orchestrator.tools
|
124
136
|
|
125
137
|
@property
|
@@ -142,37 +154,10 @@ class LLMAgent(
|
|
142
154
|
def inp_prompt(self) -> LLMPrompt | None:
|
143
155
|
return self._prompt_builder.inp_prompt
|
144
156
|
|
145
|
-
def format_sys_args_handler(
|
146
|
-
self, func: FormatSystemArgsHandler[CtxT]
|
147
|
-
) -> FormatSystemArgsHandler[CtxT]:
|
148
|
-
self._prompt_builder.format_sys_args_impl = func
|
149
|
-
|
150
|
-
return func
|
151
|
-
|
152
|
-
def format_inp_args_handler(
|
153
|
-
self, func: FormatInputArgsHandler[InT, CtxT]
|
154
|
-
) -> FormatInputArgsHandler[InT, CtxT]:
|
155
|
-
self._prompt_builder.format_inp_args_impl = func
|
156
|
-
|
157
|
-
return func
|
158
|
-
|
159
|
-
def make_custom_agent_state_handler(
|
160
|
-
self, func: MakeCustomAgentState
|
161
|
-
) -> MakeCustomAgentState:
|
162
|
-
self._make_custom_agent_state_impl = func
|
163
|
-
|
164
|
-
return func
|
165
|
-
|
166
|
-
def tool_call_loop_exit_handler(
|
167
|
-
self, func: ToolCallLoopExitHandler[CtxT]
|
168
|
-
) -> ToolCallLoopExitHandler[CtxT]:
|
169
|
-
self._tool_orchestrator.tool_call_loop_exit_impl = func
|
170
|
-
|
171
|
-
return func
|
172
|
-
|
173
157
|
def _parse_output(
|
174
158
|
self,
|
175
159
|
conversation: Conversation,
|
160
|
+
*,
|
176
161
|
rcv_args: InT | None = None,
|
177
162
|
ctx: RunContextWrapper[CtxT] | None = None,
|
178
163
|
**kwargs: Any,
|
@@ -274,10 +259,7 @@ class LLMAgent(
|
|
274
259
|
|
275
260
|
# 6. Write interaction history to context
|
276
261
|
|
277
|
-
|
278
|
-
recipient_ids = self._validate_dynamic_routing(val_output_batch)
|
279
|
-
else:
|
280
|
-
recipient_ids = self._validate_static_routing(val_output_batch)
|
262
|
+
recipient_ids = self._validate_routing(val_output_batch)
|
281
263
|
|
282
264
|
if ctx:
|
283
265
|
interaction_record = InteractionRecord(
|
@@ -332,30 +314,66 @@ class LLMAgent(
|
|
332
314
|
):
|
333
315
|
self._print_msgs([state.message_history[0][0]], ctx=ctx)
|
334
316
|
|
335
|
-
#
|
336
|
-
|
337
|
-
|
338
|
-
|
339
|
-
|
340
|
-
|
341
|
-
|
342
|
-
|
343
|
-
|
344
|
-
|
345
|
-
|
346
|
-
|
347
|
-
|
348
|
-
|
349
|
-
|
350
|
-
|
351
|
-
|
352
|
-
|
353
|
-
|
354
|
-
|
355
|
-
|
356
|
-
|
357
|
-
|
358
|
-
|
359
|
-
|
360
|
-
|
361
|
-
|
317
|
+
# -- Handlers for custom implementations --
|
318
|
+
|
319
|
+
def format_sys_args_handler(
|
320
|
+
self, func: FormatSystemArgsHandler[CtxT]
|
321
|
+
) -> FormatSystemArgsHandler[CtxT]:
|
322
|
+
self._prompt_builder.format_sys_args_impl = func
|
323
|
+
|
324
|
+
return func
|
325
|
+
|
326
|
+
def format_inp_args_handler(
|
327
|
+
self, func: FormatInputArgsHandler[InT, CtxT]
|
328
|
+
) -> FormatInputArgsHandler[InT, CtxT]:
|
329
|
+
self._prompt_builder.format_inp_args_impl = func
|
330
|
+
|
331
|
+
return func
|
332
|
+
|
333
|
+
def make_custom_agent_state_handler(
|
334
|
+
self, func: MakeCustomAgentState
|
335
|
+
) -> MakeCustomAgentState:
|
336
|
+
self._make_custom_agent_state_impl = func
|
337
|
+
|
338
|
+
return func
|
339
|
+
|
340
|
+
def tool_call_loop_exit_handler(
|
341
|
+
self, func: ToolCallLoopExitHandler[CtxT]
|
342
|
+
) -> ToolCallLoopExitHandler[CtxT]:
|
343
|
+
self._tool_orchestrator.tool_call_loop_exit_impl = func
|
344
|
+
|
345
|
+
return func
|
346
|
+
|
347
|
+
# -- Override these methods in subclasses if needed --
|
348
|
+
|
349
|
+
def _format_sys_args(
|
350
|
+
self,
|
351
|
+
sys_args: LLMPromptArgs,
|
352
|
+
*,
|
353
|
+
ctx: RunContextWrapper[CtxT] | None = None,
|
354
|
+
) -> LLMFormattedSystemArgs:
|
355
|
+
raise NotImplementedError(
|
356
|
+
"LLMAgent._format_sys_args must be overridden by a subclass "
|
357
|
+
"if it's intended to be used as the system arguments formatter."
|
358
|
+
)
|
359
|
+
|
360
|
+
def _format_inp_args(
|
361
|
+
self,
|
362
|
+
usr_args: LLMPromptArgs,
|
363
|
+
rcv_args: InT,
|
364
|
+
ctx: RunContextWrapper[CtxT] | None = None,
|
365
|
+
) -> LLMFormattedArgs:
|
366
|
+
raise NotImplementedError(
|
367
|
+
"LLMAgent._format_inp_args must be overridden by a subclass"
|
368
|
+
)
|
369
|
+
|
370
|
+
def _tool_call_loop_exit(
|
371
|
+
self,
|
372
|
+
conversation: Conversation,
|
373
|
+
*,
|
374
|
+
ctx: RunContextWrapper[CtxT] | None = None,
|
375
|
+
**kwargs: Any,
|
376
|
+
) -> bool:
|
377
|
+
raise NotImplementedError(
|
378
|
+
"LLMAgent._tool_call_loop_exit must be overridden by a subclass"
|
379
|
+
)
|
grasp_agents/memory.py
CHANGED
@@ -142,9 +142,3 @@ class MessageHistory:
|
|
142
142
|
|
143
143
|
def erase(self) -> None:
|
144
144
|
self._batched_conversations = [[]]
|
145
|
-
|
146
|
-
# def get_batch(self, batch_id: int) -> list[Message]:
|
147
|
-
# return self._batched_conversations[batch_id]
|
148
|
-
|
149
|
-
# def iterate_conversations(self) -> Iterator[list[Message]]:
|
150
|
-
# return iter(self._batched_conversations)
|
@@ -13,13 +13,14 @@ def from_api_completion(
|
|
13
13
|
api_completion: ChatCompletion, model_id: str | None = None
|
14
14
|
) -> Completion:
|
15
15
|
choices: list[CompletionChoice] = []
|
16
|
-
# TODO: add custom error type
|
17
16
|
if api_completion.choices is None: # type: ignore
|
18
|
-
|
17
|
+
# Choices can sometimes be None for some providers using the OpenAI API
|
18
|
+
# TODO: add custom error types
|
19
|
+
raise RuntimeError(
|
19
20
|
f"Completion API error: {getattr(api_completion, 'error', None)}"
|
20
21
|
)
|
21
22
|
for api_choice in api_completion.choices:
|
22
|
-
# TODO: no way to assign individual message usages when len(choices) > 1
|
23
|
+
# TODO: currently no way to assign individual message usages when len(choices) > 1
|
23
24
|
message = from_api_assistant_message(
|
24
25
|
api_choice.message, api_completion.usage, model_id=model_id
|
25
26
|
)
|
@@ -6,12 +6,7 @@ from pydantic import BaseModel
|
|
6
6
|
from ..typing.completion import Completion, CompletionChunk
|
7
7
|
from ..typing.content import Content
|
8
8
|
from ..typing.converters import Converters
|
9
|
-
from ..typing.message import
|
10
|
-
AssistantMessage,
|
11
|
-
SystemMessage,
|
12
|
-
ToolMessage,
|
13
|
-
UserMessage,
|
14
|
-
)
|
9
|
+
from ..typing.message import AssistantMessage, SystemMessage, ToolMessage, UserMessage
|
15
10
|
from ..typing.tool import BaseTool, ToolChoice
|
16
11
|
from . import (
|
17
12
|
ChatCompletion,
|
@@ -19,7 +14,6 @@ from . import (
|
|
19
14
|
ChatCompletionAsyncStream, # type: ignore[import]
|
20
15
|
ChatCompletionChunk,
|
21
16
|
ChatCompletionContentPartParam,
|
22
|
-
# ChatCompletionDeveloperMessageParam,
|
23
17
|
ChatCompletionMessage,
|
24
18
|
ChatCompletionSystemMessageParam,
|
25
19
|
ChatCompletionToolChoiceOptionParam,
|
@@ -110,7 +104,7 @@ class OpenAIConverters(Converters):
|
|
110
104
|
|
111
105
|
@staticmethod
|
112
106
|
def to_tool(
|
113
|
-
tool: BaseTool[BaseModel,
|
107
|
+
tool: BaseTool[BaseModel, Any, Any], **kwargs: Any
|
114
108
|
) -> ChatCompletionToolParam:
|
115
109
|
return to_api_tool(tool, **kwargs)
|
116
110
|
|
@@ -132,9 +132,6 @@ def to_api_system_message(
|
|
132
132
|
message: SystemMessage,
|
133
133
|
) -> ChatCompletionSystemMessageParam:
|
134
134
|
return ChatCompletionSystemMessageParam(role="system", content=message.content)
|
135
|
-
# return ChatCompletionSystemMessageParam(
|
136
|
-
# role="system", content=message.content
|
137
|
-
# )
|
138
135
|
|
139
136
|
|
140
137
|
def from_api_tool_message(
|
@@ -149,7 +146,5 @@ def from_api_tool_message(
|
|
149
146
|
|
150
147
|
def to_api_tool_message(message: ToolMessage) -> ChatCompletionToolMessageParam:
|
151
148
|
return ChatCompletionToolMessageParam(
|
152
|
-
role="tool",
|
153
|
-
content=message.content,
|
154
|
-
tool_call_id=message.tool_call_id,
|
149
|
+
role="tool", content=message.content, tool_call_id=message.tool_call_id
|
155
150
|
)
|
@@ -3,14 +3,14 @@ from collections.abc import Iterable
|
|
3
3
|
from copy import deepcopy
|
4
4
|
from typing import Any, Literal
|
5
5
|
|
6
|
-
from openai import AsyncOpenAI
|
7
|
-
from openai._types import NOT_GIVEN # noqa: PLC2701 # type: ignore[import]
|
8
6
|
from pydantic import BaseModel
|
9
7
|
|
10
|
-
from
|
8
|
+
from openai import AsyncOpenAI
|
9
|
+
from openai._types import NOT_GIVEN # type: ignore[import]
|
11
10
|
|
12
11
|
from ..cloud_llm import APIProvider, CloudLLM, CloudLLMSettings
|
13
12
|
from ..http_client import AsyncHTTPClientParams
|
13
|
+
from ..rate_limiting.rate_limiter_chunked import RateLimiterC
|
14
14
|
from ..typing.message import AssistantMessage, Conversation
|
15
15
|
from ..typing.tool import BaseTool
|
16
16
|
from . import (
|
@@ -67,7 +67,7 @@ class OpenAILLM(CloudLLM[OpenAILLMSettings, OpenAIConverters]):
|
|
67
67
|
model_name: str,
|
68
68
|
model_id: str | None = None,
|
69
69
|
llm_settings: OpenAILLMSettings | None = None,
|
70
|
-
tools: list[BaseTool[BaseModel,
|
70
|
+
tools: list[BaseTool[BaseModel, Any, Any]] | None = None,
|
71
71
|
response_format: type | None = None,
|
72
72
|
# Connection settings
|
73
73
|
api_provider: APIProvider = "openai",
|
@@ -113,8 +113,6 @@ class OpenAILLM(CloudLLM[OpenAILLMSettings, OpenAIConverters]):
|
|
113
113
|
base_url=self._base_url,
|
114
114
|
api_key=self._api_key,
|
115
115
|
**async_openai_client_params_,
|
116
|
-
# timeout=10.0,
|
117
|
-
# max_retries=3,
|
118
116
|
)
|
119
117
|
|
120
118
|
async def _get_completion(
|
@@ -3,16 +3,11 @@ import functools
|
|
3
3
|
import logging
|
4
4
|
from collections.abc import Callable, Coroutine
|
5
5
|
from time import monotonic
|
6
|
-
from typing import
|
7
|
-
Any,
|
8
|
-
Generic,
|
9
|
-
overload,
|
10
|
-
)
|
6
|
+
from typing import Any, Generic, overload
|
11
7
|
|
12
8
|
from tqdm.autonotebook import tqdm
|
13
9
|
|
14
10
|
from ..utils import asyncio_gather_with_pbar
|
15
|
-
|
16
11
|
from .types import (
|
17
12
|
QueryP,
|
18
13
|
QueryR,
|
@@ -54,9 +49,7 @@ class RateLimiterC(Generic[QueryT, QueryR]):
|
|
54
49
|
if now < self._state.next_request_time:
|
55
50
|
await asyncio.sleep(self._state.next_request_time - now)
|
56
51
|
self._state.next_request_time = monotonic() + 1.01 * 60.0 / self._rpm
|
57
|
-
|
58
|
-
|
59
|
-
return result
|
52
|
+
return await func_partial(inp)
|
60
53
|
|
61
54
|
async def process_inputs(
|
62
55
|
self,
|
@@ -14,11 +14,16 @@ from .types import (
|
|
14
14
|
|
15
15
|
|
16
16
|
def is_bound_method(func: Callable[..., Any], self_candidate: Any) -> bool:
|
17
|
-
return (inspect.ismethod(func) and (func.__self__ is self_candidate)) or hasattr(
|
17
|
+
return (inspect.ismethod(func) and (func.__self__ is self_candidate)) or hasattr(
|
18
|
+
self_candidate, func.__name__
|
19
|
+
)
|
18
20
|
|
19
21
|
|
20
22
|
def split_pos_args(
|
21
|
-
call: (
|
23
|
+
call: (
|
24
|
+
RetrievalCallableSingle[QueryT, QueryP, QueryR]
|
25
|
+
| RetrievalCallableList[QueryT, QueryP, QueryR]
|
26
|
+
),
|
22
27
|
args: Sequence[Any],
|
23
28
|
) -> tuple[Any | None, QueryT | list[QueryT], Sequence[Any]]:
|
24
29
|
if not args:
|
@@ -28,12 +33,15 @@ def split_pos_args(
|
|
28
33
|
# Case: Bound instance method with signature (self, inp, *rest)
|
29
34
|
if len(args) < 2:
|
30
35
|
raise ValueError(
|
31
|
-
"Must pass at least `self` and an input (or a list of inputs) "
|
36
|
+
"Must pass at least `self` and an input (or a list of inputs) "
|
37
|
+
"for a bound instance method."
|
32
38
|
)
|
33
39
|
return maybe_self, args[1], args[2:]
|
34
40
|
# Case: Standalone function with signature (inp, *rest)
|
35
41
|
if not args:
|
36
|
-
raise ValueError(
|
42
|
+
raise ValueError(
|
43
|
+
"Must pass an input (or a list of inputs) " + "for a standalone function."
|
44
|
+
)
|
37
45
|
return None, args[0], args[1:]
|
38
46
|
|
39
47
|
|
@@ -53,5 +61,7 @@ def partial_retrieval_callable(
|
|
53
61
|
return wrapper
|
54
62
|
|
55
63
|
|
56
|
-
def expected_exec_time_from_max_concurrency_and_rpm(
|
64
|
+
def expected_exec_time_from_max_concurrency_and_rpm(
|
65
|
+
rpm: float, max_concurrency: int
|
66
|
+
) -> float:
|
57
67
|
return 60.0 / (rpm / max_concurrency)
|
@@ -31,7 +31,7 @@ class ToolOrchestrator(Generic[CtxT]):
|
|
31
31
|
self,
|
32
32
|
agent_id: str,
|
33
33
|
llm: LLM[LLMSettings, Converters],
|
34
|
-
tools: list[BaseTool[BaseModel,
|
34
|
+
tools: list[BaseTool[BaseModel, Any, CtxT]] | None,
|
35
35
|
max_turns: int,
|
36
36
|
react_mode: bool = False,
|
37
37
|
) -> None:
|
@@ -55,7 +55,7 @@ class ToolOrchestrator(Generic[CtxT]):
|
|
55
55
|
return self._llm
|
56
56
|
|
57
57
|
@property
|
58
|
-
def tools(self) -> dict[str, BaseTool[BaseModel,
|
58
|
+
def tools(self) -> dict[str, BaseTool[BaseModel, Any, CtxT]]:
|
59
59
|
return self._llm.tools or {}
|
60
60
|
|
61
61
|
@property
|
@@ -2,17 +2,9 @@ from abc import ABC, abstractmethod
|
|
2
2
|
from collections.abc import AsyncIterator
|
3
3
|
from typing import Any
|
4
4
|
|
5
|
-
from pydantic import BaseModel
|
6
|
-
|
7
5
|
from .completion import Completion, CompletionChunk
|
8
6
|
from .content import Content
|
9
|
-
from .message import
|
10
|
-
AssistantMessage,
|
11
|
-
Message,
|
12
|
-
SystemMessage,
|
13
|
-
ToolMessage,
|
14
|
-
UserMessage,
|
15
|
-
)
|
7
|
+
from .message import AssistantMessage, Message, SystemMessage, ToolMessage, UserMessage
|
16
8
|
from .tool import BaseTool, ToolChoice
|
17
9
|
|
18
10
|
|
@@ -72,7 +64,7 @@ class Converters(ABC):
|
|
72
64
|
|
73
65
|
@staticmethod
|
74
66
|
@abstractmethod
|
75
|
-
def to_tool(tool: BaseTool[
|
67
|
+
def to_tool(tool: BaseTool[Any, Any, Any], **kwargs: Any) -> Any:
|
76
68
|
pass
|
77
69
|
|
78
70
|
@staticmethod
|
grasp_agents/typing/io.py
CHANGED
@@ -1,8 +1,6 @@
|
|
1
|
-
from collections.abc import Sequence
|
2
1
|
from typing import TypeAlias, TypeVar
|
3
2
|
|
4
3
|
from pydantic import BaseModel
|
5
|
-
from pydantic.json_schema import SkipJsonSchema
|
6
4
|
|
7
5
|
from .content import ImageData
|
8
6
|
|
@@ -10,8 +8,7 @@ AgentID: TypeAlias = str
|
|
10
8
|
|
11
9
|
|
12
10
|
class AgentPayload(BaseModel):
|
13
|
-
|
14
|
-
selected_recipient_ids: SkipJsonSchema[Sequence[AgentID] | None] = None
|
11
|
+
pass
|
15
12
|
|
16
13
|
|
17
14
|
class AgentState(BaseModel):
|
grasp_agents/typing/message.py
CHANGED
@@ -1,9 +1,11 @@
|
|
1
|
+
import json
|
1
2
|
from collections.abc import Hashable, Sequence
|
2
3
|
from enum import StrEnum
|
3
|
-
from typing import Annotated, Literal, TypeAlias
|
4
|
+
from typing import Annotated, Any, Literal, TypeAlias
|
4
5
|
from uuid import uuid4
|
5
6
|
|
6
7
|
from pydantic import BaseModel, Field, NonNegativeFloat, NonNegativeInt
|
8
|
+
from pydantic.json import pydantic_encoder
|
7
9
|
|
8
10
|
from .content import Content, ImageData
|
9
11
|
from .tool import ToolCall
|
@@ -110,13 +112,13 @@ class ToolMessage(MessageBase):
|
|
110
112
|
@classmethod
|
111
113
|
def from_tool_output(
|
112
114
|
cls,
|
113
|
-
tool_output:
|
115
|
+
tool_output: Any,
|
114
116
|
tool_call: ToolCall,
|
115
117
|
model_id: str | None = None,
|
116
118
|
indent: int = 2,
|
117
119
|
) -> "ToolMessage":
|
118
120
|
return cls(
|
119
|
-
content=
|
121
|
+
content=json.dumps(tool_output, default=pydantic_encoder, indent=indent),
|
120
122
|
tool_call_id=tool_call.id,
|
121
123
|
model_id=model_id,
|
122
124
|
)
|
grasp_agents/typing/tool.py
CHANGED
@@ -1,9 +1,11 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
|
+
import asyncio
|
3
4
|
from abc import ABC, abstractmethod
|
5
|
+
from collections.abc import Sequence
|
4
6
|
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeAlias, TypeVar
|
5
7
|
|
6
|
-
from pydantic import BaseModel
|
8
|
+
from pydantic import BaseModel, TypeAdapter
|
7
9
|
|
8
10
|
if TYPE_CHECKING:
|
9
11
|
from ..run_context import CtxT, RunContextWrapper
|
@@ -14,8 +16,8 @@ else:
|
|
14
16
|
"""Runtime placeholder so RunContextWrapper[CtxT] works"""
|
15
17
|
|
16
18
|
|
17
|
-
|
18
|
-
|
19
|
+
_ToolInT = TypeVar("_ToolInT", bound=BaseModel, contravariant=True) # noqa: PLC0105
|
20
|
+
_ToolOutT = TypeVar("_ToolOutT", covariant=True) # noqa: PLC0105
|
19
21
|
|
20
22
|
|
21
23
|
class ToolCall(BaseModel):
|
@@ -24,29 +26,34 @@ class ToolCall(BaseModel):
|
|
24
26
|
tool_arguments: str
|
25
27
|
|
26
28
|
|
27
|
-
class BaseTool(BaseModel, ABC, Generic[
|
29
|
+
class BaseTool(BaseModel, ABC, Generic[_ToolInT, _ToolOutT, CtxT]):
|
28
30
|
name: str
|
29
31
|
description: str
|
30
|
-
in_schema: type[
|
31
|
-
out_schema: type[
|
32
|
+
in_schema: type[_ToolInT]
|
33
|
+
out_schema: type[_ToolOutT]
|
32
34
|
|
33
35
|
# Supported by OpenAI API
|
34
36
|
strict: bool | None = None
|
35
37
|
|
36
38
|
@abstractmethod
|
37
39
|
async def run(
|
38
|
-
self, inp:
|
39
|
-
) ->
|
40
|
+
self, inp: _ToolInT, ctx: RunContextWrapper[CtxT] | None = None
|
41
|
+
) -> _ToolOutT:
|
40
42
|
pass
|
41
43
|
|
44
|
+
async def run_batch(
|
45
|
+
self, inp_batch: Sequence[_ToolInT], ctx: RunContextWrapper[CtxT] | None = None
|
46
|
+
) -> Sequence[_ToolOutT]:
|
47
|
+
return await asyncio.gather(*[self.run(inp, ctx=ctx) for inp in inp_batch])
|
48
|
+
|
42
49
|
async def __call__(
|
43
50
|
self, ctx: RunContextWrapper[CtxT] | None = None, **kwargs: Any
|
44
|
-
) ->
|
51
|
+
) -> _ToolOutT:
|
45
52
|
result = await self.run(self.in_schema(**kwargs), ctx=ctx)
|
46
53
|
|
47
|
-
return self.out_schema.
|
54
|
+
return TypeAdapter(self.out_schema).validate_python(result)
|
48
55
|
|
49
56
|
|
50
57
|
ToolChoice: TypeAlias = (
|
51
|
-
Literal["none", "auto", "required"] | BaseTool[BaseModel,
|
58
|
+
Literal["none", "auto", "required"] | BaseTool[BaseModel, Any, Any]
|
52
59
|
)
|