grasp_agents 0.1.15__py3-none-any.whl → 0.1.17__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
grasp_agents/llm_agent.py CHANGED
@@ -33,6 +33,8 @@ from .typing.io import (
33
33
  AgentPayload,
34
34
  AgentState,
35
35
  InT,
36
+ LLMFormattedArgs,
37
+ LLMFormattedSystemArgs,
36
38
  LLMPrompt,
37
39
  LLMPromptArgs,
38
40
  OutT,
@@ -67,7 +69,7 @@ class LLMAgent(
67
69
  # Output schema
68
70
  out_schema: type[OutT] = cast("type[OutT]", AgentPayload),
69
71
  # Tools
70
- tools: list[BaseTool[BaseModel, BaseModel, CtxT]] | None = None,
72
+ tools: list[BaseTool[Any, Any, CtxT]] | None = None,
71
73
  max_turns: int = 1000,
72
74
  react_mode: bool = False,
73
75
  # Agent state management
@@ -75,7 +77,6 @@ class LLMAgent(
75
77
  # Multi-agent routing
76
78
  message_pool: AgentMessagePool[CtxT] | None = None,
77
79
  recipient_ids: list[AgentID] | None = None,
78
- dynamic_routing: bool = False,
79
80
  ) -> None:
80
81
  super().__init__(
81
82
  agent_id=agent_id,
@@ -83,7 +84,6 @@ class LLMAgent(
83
84
  rcv_args_schema=rcv_args_schema,
84
85
  message_pool=message_pool,
85
86
  recipient_ids=recipient_ids,
86
- dynamic_routing=dynamic_routing,
87
87
  )
88
88
 
89
89
  # Agent state
@@ -114,12 +114,24 @@ class LLMAgent(
114
114
 
115
115
  self.no_tqdm = getattr(llm, "no_tqdm", False)
116
116
 
117
+ if type(self)._format_sys_args is not LLMAgent[Any, Any, Any]._format_sys_args: # noqa: SLF001
118
+ self._prompt_builder.format_sys_args_impl = self._format_sys_args
119
+
120
+ if type(self)._format_inp_args is not LLMAgent[Any, Any, Any]._format_inp_args: # noqa: SLF001
121
+ self._prompt_builder.format_inp_args_impl = self._format_inp_args
122
+
123
+ if (
124
+ type(self)._tool_call_loop_exit # noqa: SLF001
125
+ is not LLMAgent[Any, Any, Any]._tool_call_loop_exit # noqa: SLF001
126
+ ):
127
+ self._tool_orchestrator.tool_call_loop_exit_impl = self._tool_call_loop_exit
128
+
117
129
  @property
118
130
  def llm(self) -> LLM[LLMSettings, Converters]:
119
131
  return self._tool_orchestrator.llm
120
132
 
121
133
  @property
122
- def tools(self) -> dict[str, BaseTool[BaseModel, BaseModel, CtxT]]:
134
+ def tools(self) -> dict[str, BaseTool[BaseModel, Any, CtxT]]:
123
135
  return self._tool_orchestrator.tools
124
136
 
125
137
  @property
@@ -142,37 +154,10 @@ class LLMAgent(
142
154
  def inp_prompt(self) -> LLMPrompt | None:
143
155
  return self._prompt_builder.inp_prompt
144
156
 
145
- def format_sys_args_handler(
146
- self, func: FormatSystemArgsHandler[CtxT]
147
- ) -> FormatSystemArgsHandler[CtxT]:
148
- self._prompt_builder.format_sys_args_impl = func
149
-
150
- return func
151
-
152
- def format_inp_args_handler(
153
- self, func: FormatInputArgsHandler[InT, CtxT]
154
- ) -> FormatInputArgsHandler[InT, CtxT]:
155
- self._prompt_builder.format_inp_args_impl = func
156
-
157
- return func
158
-
159
- def make_custom_agent_state_handler(
160
- self, func: MakeCustomAgentState
161
- ) -> MakeCustomAgentState:
162
- self._make_custom_agent_state_impl = func
163
-
164
- return func
165
-
166
- def tool_call_loop_exit_handler(
167
- self, func: ToolCallLoopExitHandler[CtxT]
168
- ) -> ToolCallLoopExitHandler[CtxT]:
169
- self._tool_orchestrator.tool_call_loop_exit_impl = func
170
-
171
- return func
172
-
173
157
  def _parse_output(
174
158
  self,
175
159
  conversation: Conversation,
160
+ *,
176
161
  rcv_args: InT | None = None,
177
162
  ctx: RunContextWrapper[CtxT] | None = None,
178
163
  **kwargs: Any,
@@ -274,10 +259,7 @@ class LLMAgent(
274
259
 
275
260
  # 6. Write interaction history to context
276
261
 
277
- if self.dynamic_routing:
278
- recipient_ids = self._validate_dynamic_routing(val_output_batch)
279
- else:
280
- recipient_ids = self._validate_static_routing(val_output_batch)
262
+ recipient_ids = self._validate_routing(val_output_batch)
281
263
 
282
264
  if ctx:
283
265
  interaction_record = InteractionRecord(
@@ -332,30 +314,66 @@ class LLMAgent(
332
314
  ):
333
315
  self._print_msgs([state.message_history[0][0]], ctx=ctx)
334
316
 
335
- # def _format_sys_args(
336
- # self,
337
- # sys_args: LLMPromptArgs,
338
- # ctx: RunContextWrapper[CtxT] | None = None,
339
- # ) -> LLMFormattedSystemArgs:
340
- # return self._prompt_builder.format_sys_args(sys_args=sys_args, ctx=ctx)
341
-
342
- # def _format_inp_args(
343
- # self,
344
- # usr_args: LLMPromptArgs,
345
- # rcv_args: InT,
346
- # ctx: RunContextWrapper[CtxT] | None = None,
347
- # ) -> LLMFormattedArgs:
348
- # return self._prompt_builder.format_inp_args(
349
- # usr_args=usr_args, rcv_args=rcv_args, ctx=ctx
350
- # )
351
-
352
- # def _tool_call_loop_exit(
353
- # self,
354
- # conversation: Conversation,
355
- # *,
356
- # ctx: RunContextWrapper[CtxT] | None = None,
357
- # **kwargs: Any,
358
- # ) -> bool:
359
- # return self._tool_orchestrator.tool_call_loop_exit(
360
- # conversation=conversation, ctx=ctx, **kwargs
361
- # )
317
+ # -- Handlers for custom implementations --
318
+
319
+ def format_sys_args_handler(
320
+ self, func: FormatSystemArgsHandler[CtxT]
321
+ ) -> FormatSystemArgsHandler[CtxT]:
322
+ self._prompt_builder.format_sys_args_impl = func
323
+
324
+ return func
325
+
326
+ def format_inp_args_handler(
327
+ self, func: FormatInputArgsHandler[InT, CtxT]
328
+ ) -> FormatInputArgsHandler[InT, CtxT]:
329
+ self._prompt_builder.format_inp_args_impl = func
330
+
331
+ return func
332
+
333
+ def make_custom_agent_state_handler(
334
+ self, func: MakeCustomAgentState
335
+ ) -> MakeCustomAgentState:
336
+ self._make_custom_agent_state_impl = func
337
+
338
+ return func
339
+
340
+ def tool_call_loop_exit_handler(
341
+ self, func: ToolCallLoopExitHandler[CtxT]
342
+ ) -> ToolCallLoopExitHandler[CtxT]:
343
+ self._tool_orchestrator.tool_call_loop_exit_impl = func
344
+
345
+ return func
346
+
347
+ # -- Override these methods in subclasses if needed --
348
+
349
+ def _format_sys_args(
350
+ self,
351
+ sys_args: LLMPromptArgs,
352
+ *,
353
+ ctx: RunContextWrapper[CtxT] | None = None,
354
+ ) -> LLMFormattedSystemArgs:
355
+ raise NotImplementedError(
356
+ "LLMAgent._format_sys_args must be overridden by a subclass "
357
+ "if it's intended to be used as the system arguments formatter."
358
+ )
359
+
360
+ def _format_inp_args(
361
+ self,
362
+ usr_args: LLMPromptArgs,
363
+ rcv_args: InT,
364
+ ctx: RunContextWrapper[CtxT] | None = None,
365
+ ) -> LLMFormattedArgs:
366
+ raise NotImplementedError(
367
+ "LLMAgent._format_inp_args must be overridden by a subclass"
368
+ )
369
+
370
+ def _tool_call_loop_exit(
371
+ self,
372
+ conversation: Conversation,
373
+ *,
374
+ ctx: RunContextWrapper[CtxT] | None = None,
375
+ **kwargs: Any,
376
+ ) -> bool:
377
+ raise NotImplementedError(
378
+ "LLMAgent._tool_call_loop_exit must be overridden by a subclass"
379
+ )
grasp_agents/memory.py CHANGED
@@ -142,9 +142,3 @@ class MessageHistory:
142
142
 
143
143
  def erase(self) -> None:
144
144
  self._batched_conversations = [[]]
145
-
146
- # def get_batch(self, batch_id: int) -> list[Message]:
147
- # return self._batched_conversations[batch_id]
148
-
149
- # def iterate_conversations(self) -> Iterator[list[Message]]:
150
- # return iter(self._batched_conversations)
@@ -13,13 +13,14 @@ def from_api_completion(
13
13
  api_completion: ChatCompletion, model_id: str | None = None
14
14
  ) -> Completion:
15
15
  choices: list[CompletionChoice] = []
16
- # TODO: add custom error type
17
16
  if api_completion.choices is None: # type: ignore
18
- raise ValueError(
17
+ # Choices can sometimes be None for some providers using the OpenAI API
18
+ # TODO: add custom error types
19
+ raise RuntimeError(
19
20
  f"Completion API error: {getattr(api_completion, 'error', None)}"
20
21
  )
21
22
  for api_choice in api_completion.choices:
22
- # TODO: no way to assign individual message usages when len(choices) > 1
23
+ # TODO: currently no way to assign individual message usages when len(choices) > 1
23
24
  message = from_api_assistant_message(
24
25
  api_choice.message, api_completion.usage, model_id=model_id
25
26
  )
@@ -6,12 +6,7 @@ from pydantic import BaseModel
6
6
  from ..typing.completion import Completion, CompletionChunk
7
7
  from ..typing.content import Content
8
8
  from ..typing.converters import Converters
9
- from ..typing.message import (
10
- AssistantMessage,
11
- SystemMessage,
12
- ToolMessage,
13
- UserMessage,
14
- )
9
+ from ..typing.message import AssistantMessage, SystemMessage, ToolMessage, UserMessage
15
10
  from ..typing.tool import BaseTool, ToolChoice
16
11
  from . import (
17
12
  ChatCompletion,
@@ -19,7 +14,6 @@ from . import (
19
14
  ChatCompletionAsyncStream, # type: ignore[import]
20
15
  ChatCompletionChunk,
21
16
  ChatCompletionContentPartParam,
22
- # ChatCompletionDeveloperMessageParam,
23
17
  ChatCompletionMessage,
24
18
  ChatCompletionSystemMessageParam,
25
19
  ChatCompletionToolChoiceOptionParam,
@@ -110,7 +104,7 @@ class OpenAIConverters(Converters):
110
104
 
111
105
  @staticmethod
112
106
  def to_tool(
113
- tool: BaseTool[BaseModel, BaseModel, Any], **kwargs: Any
107
+ tool: BaseTool[BaseModel, Any, Any], **kwargs: Any
114
108
  ) -> ChatCompletionToolParam:
115
109
  return to_api_tool(tool, **kwargs)
116
110
 
@@ -132,9 +132,6 @@ def to_api_system_message(
132
132
  message: SystemMessage,
133
133
  ) -> ChatCompletionSystemMessageParam:
134
134
  return ChatCompletionSystemMessageParam(role="system", content=message.content)
135
- # return ChatCompletionSystemMessageParam(
136
- # role="system", content=message.content
137
- # )
138
135
 
139
136
 
140
137
  def from_api_tool_message(
@@ -149,7 +146,5 @@ def from_api_tool_message(
149
146
 
150
147
  def to_api_tool_message(message: ToolMessage) -> ChatCompletionToolMessageParam:
151
148
  return ChatCompletionToolMessageParam(
152
- role="tool",
153
- content=message.content,
154
- tool_call_id=message.tool_call_id,
149
+ role="tool", content=message.content, tool_call_id=message.tool_call_id
155
150
  )
@@ -3,14 +3,14 @@ from collections.abc import Iterable
3
3
  from copy import deepcopy
4
4
  from typing import Any, Literal
5
5
 
6
- from openai import AsyncOpenAI
7
- from openai._types import NOT_GIVEN # noqa: PLC2701 # type: ignore[import]
8
6
  from pydantic import BaseModel
9
7
 
10
- from ..data_retrieval.rate_limiter_chunked import RateLimiterC
8
+ from openai import AsyncOpenAI
9
+ from openai._types import NOT_GIVEN # type: ignore[import]
11
10
 
12
11
  from ..cloud_llm import APIProvider, CloudLLM, CloudLLMSettings
13
12
  from ..http_client import AsyncHTTPClientParams
13
+ from ..rate_limiting.rate_limiter_chunked import RateLimiterC
14
14
  from ..typing.message import AssistantMessage, Conversation
15
15
  from ..typing.tool import BaseTool
16
16
  from . import (
@@ -67,7 +67,7 @@ class OpenAILLM(CloudLLM[OpenAILLMSettings, OpenAIConverters]):
67
67
  model_name: str,
68
68
  model_id: str | None = None,
69
69
  llm_settings: OpenAILLMSettings | None = None,
70
- tools: list[BaseTool[BaseModel, BaseModel, Any]] | None = None,
70
+ tools: list[BaseTool[BaseModel, Any, Any]] | None = None,
71
71
  response_format: type | None = None,
72
72
  # Connection settings
73
73
  api_provider: APIProvider = "openai",
@@ -113,8 +113,6 @@ class OpenAILLM(CloudLLM[OpenAILLMSettings, OpenAIConverters]):
113
113
  base_url=self._base_url,
114
114
  api_key=self._api_key,
115
115
  **async_openai_client_params_,
116
- # timeout=10.0,
117
- # max_retries=3,
118
116
  )
119
117
 
120
118
  async def _get_completion(
@@ -13,7 +13,7 @@ from . import (
13
13
 
14
14
 
15
15
  def to_api_tool(
16
- tool: BaseTool[BaseModel, BaseModel, Any],
16
+ tool: BaseTool[BaseModel, Any, Any],
17
17
  ) -> ChatCompletionToolParam:
18
18
  return ChatCompletionToolParam(
19
19
  type="function",
@@ -3,16 +3,11 @@ import functools
3
3
  import logging
4
4
  from collections.abc import Callable, Coroutine
5
5
  from time import monotonic
6
- from typing import (
7
- Any,
8
- Generic,
9
- overload,
10
- )
6
+ from typing import Any, Generic, overload
11
7
 
12
8
  from tqdm.autonotebook import tqdm
13
9
 
14
10
  from ..utils import asyncio_gather_with_pbar
15
-
16
11
  from .types import (
17
12
  QueryP,
18
13
  QueryR,
@@ -54,9 +49,7 @@ class RateLimiterC(Generic[QueryT, QueryR]):
54
49
  if now < self._state.next_request_time:
55
50
  await asyncio.sleep(self._state.next_request_time - now)
56
51
  self._state.next_request_time = monotonic() + 1.01 * 60.0 / self._rpm
57
- result = await func_partial(inp)
58
-
59
- return result
52
+ return await func_partial(inp)
60
53
 
61
54
  async def process_inputs(
62
55
  self,
@@ -14,11 +14,16 @@ from .types import (
14
14
 
15
15
 
16
16
  def is_bound_method(func: Callable[..., Any], self_candidate: Any) -> bool:
17
- return (inspect.ismethod(func) and (func.__self__ is self_candidate)) or hasattr(self_candidate, func.__name__)
17
+ return (inspect.ismethod(func) and (func.__self__ is self_candidate)) or hasattr(
18
+ self_candidate, func.__name__
19
+ )
18
20
 
19
21
 
20
22
  def split_pos_args(
21
- call: (RetrievalCallableSingle[QueryT, QueryP, QueryR] | RetrievalCallableList[QueryT, QueryP, QueryR]),
23
+ call: (
24
+ RetrievalCallableSingle[QueryT, QueryP, QueryR]
25
+ | RetrievalCallableList[QueryT, QueryP, QueryR]
26
+ ),
22
27
  args: Sequence[Any],
23
28
  ) -> tuple[Any | None, QueryT | list[QueryT], Sequence[Any]]:
24
29
  if not args:
@@ -28,12 +33,15 @@ def split_pos_args(
28
33
  # Case: Bound instance method with signature (self, inp, *rest)
29
34
  if len(args) < 2:
30
35
  raise ValueError(
31
- "Must pass at least `self` and an input (or a list of inputs) " + "for a bound instance method."
36
+ "Must pass at least `self` and an input (or a list of inputs) "
37
+ "for a bound instance method."
32
38
  )
33
39
  return maybe_self, args[1], args[2:]
34
40
  # Case: Standalone function with signature (inp, *rest)
35
41
  if not args:
36
- raise ValueError("Must pass an input (or a list of inputs) " + "for a standalone function.")
42
+ raise ValueError(
43
+ "Must pass an input (or a list of inputs) " + "for a standalone function."
44
+ )
37
45
  return None, args[0], args[1:]
38
46
 
39
47
 
@@ -53,5 +61,7 @@ def partial_retrieval_callable(
53
61
  return wrapper
54
62
 
55
63
 
56
- def expected_exec_time_from_max_concurrency_and_rpm(rpm: float, max_concurrency: int) -> float:
64
+ def expected_exec_time_from_max_concurrency_and_rpm(
65
+ rpm: float, max_concurrency: int
66
+ ) -> float:
57
67
  return 60.0 / (rpm / max_concurrency)
@@ -31,7 +31,7 @@ class ToolOrchestrator(Generic[CtxT]):
31
31
  self,
32
32
  agent_id: str,
33
33
  llm: LLM[LLMSettings, Converters],
34
- tools: list[BaseTool[BaseModel, BaseModel, CtxT]] | None,
34
+ tools: list[BaseTool[BaseModel, Any, CtxT]] | None,
35
35
  max_turns: int,
36
36
  react_mode: bool = False,
37
37
  ) -> None:
@@ -55,7 +55,7 @@ class ToolOrchestrator(Generic[CtxT]):
55
55
  return self._llm
56
56
 
57
57
  @property
58
- def tools(self) -> dict[str, BaseTool[BaseModel, BaseModel, CtxT]]:
58
+ def tools(self) -> dict[str, BaseTool[BaseModel, Any, CtxT]]:
59
59
  return self._llm.tools or {}
60
60
 
61
61
  @property
@@ -2,17 +2,9 @@ from abc import ABC, abstractmethod
2
2
  from collections.abc import AsyncIterator
3
3
  from typing import Any
4
4
 
5
- from pydantic import BaseModel
6
-
7
5
  from .completion import Completion, CompletionChunk
8
6
  from .content import Content
9
- from .message import (
10
- AssistantMessage,
11
- Message,
12
- SystemMessage,
13
- ToolMessage,
14
- UserMessage,
15
- )
7
+ from .message import AssistantMessage, Message, SystemMessage, ToolMessage, UserMessage
16
8
  from .tool import BaseTool, ToolChoice
17
9
 
18
10
 
@@ -72,7 +64,7 @@ class Converters(ABC):
72
64
 
73
65
  @staticmethod
74
66
  @abstractmethod
75
- def to_tool(tool: BaseTool[BaseModel, BaseModel, Any], **kwargs: Any) -> Any:
67
+ def to_tool(tool: BaseTool[Any, Any, Any], **kwargs: Any) -> Any:
76
68
  pass
77
69
 
78
70
  @staticmethod
grasp_agents/typing/io.py CHANGED
@@ -1,8 +1,6 @@
1
- from collections.abc import Sequence
2
1
  from typing import TypeAlias, TypeVar
3
2
 
4
3
  from pydantic import BaseModel
5
- from pydantic.json_schema import SkipJsonSchema
6
4
 
7
5
  from .content import ImageData
8
6
 
@@ -10,8 +8,7 @@ AgentID: TypeAlias = str
10
8
 
11
9
 
12
10
  class AgentPayload(BaseModel):
13
- # TODO: do we need conversation?
14
- selected_recipient_ids: SkipJsonSchema[Sequence[AgentID] | None] = None
11
+ pass
15
12
 
16
13
 
17
14
  class AgentState(BaseModel):
@@ -1,9 +1,11 @@
1
+ import json
1
2
  from collections.abc import Hashable, Sequence
2
3
  from enum import StrEnum
3
- from typing import Annotated, Literal, TypeAlias
4
+ from typing import Annotated, Any, Literal, TypeAlias
4
5
  from uuid import uuid4
5
6
 
6
7
  from pydantic import BaseModel, Field, NonNegativeFloat, NonNegativeInt
8
+ from pydantic.json import pydantic_encoder
7
9
 
8
10
  from .content import Content, ImageData
9
11
  from .tool import ToolCall
@@ -110,13 +112,13 @@ class ToolMessage(MessageBase):
110
112
  @classmethod
111
113
  def from_tool_output(
112
114
  cls,
113
- tool_output: BaseModel,
115
+ tool_output: Any,
114
116
  tool_call: ToolCall,
115
117
  model_id: str | None = None,
116
118
  indent: int = 2,
117
119
  ) -> "ToolMessage":
118
120
  return cls(
119
- content=tool_output.model_dump_json(indent=indent),
121
+ content=json.dumps(tool_output, default=pydantic_encoder, indent=indent),
120
122
  tool_call_id=tool_call.id,
121
123
  model_id=model_id,
122
124
  )
@@ -1,9 +1,11 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import asyncio
3
4
  from abc import ABC, abstractmethod
5
+ from collections.abc import Sequence
4
6
  from typing import TYPE_CHECKING, Any, Generic, Literal, TypeAlias, TypeVar
5
7
 
6
- from pydantic import BaseModel
8
+ from pydantic import BaseModel, TypeAdapter
7
9
 
8
10
  if TYPE_CHECKING:
9
11
  from ..run_context import CtxT, RunContextWrapper
@@ -14,8 +16,8 @@ else:
14
16
  """Runtime placeholder so RunContextWrapper[CtxT] works"""
15
17
 
16
18
 
17
- ToolInT = TypeVar("ToolInT", bound=BaseModel, contravariant=True) # noqa: PLC0105
18
- ToolOutT = TypeVar("ToolOutT", bound=BaseModel, covariant=True) # noqa: PLC0105
19
+ _ToolInT = TypeVar("_ToolInT", bound=BaseModel, contravariant=True) # noqa: PLC0105
20
+ _ToolOutT = TypeVar("_ToolOutT", covariant=True) # noqa: PLC0105
19
21
 
20
22
 
21
23
  class ToolCall(BaseModel):
@@ -24,29 +26,34 @@ class ToolCall(BaseModel):
24
26
  tool_arguments: str
25
27
 
26
28
 
27
- class BaseTool(BaseModel, ABC, Generic[ToolInT, ToolOutT, CtxT]):
29
+ class BaseTool(BaseModel, ABC, Generic[_ToolInT, _ToolOutT, CtxT]):
28
30
  name: str
29
31
  description: str
30
- in_schema: type[ToolInT]
31
- out_schema: type[ToolOutT]
32
+ in_schema: type[_ToolInT]
33
+ out_schema: type[_ToolOutT]
32
34
 
33
35
  # Supported by OpenAI API
34
36
  strict: bool | None = None
35
37
 
36
38
  @abstractmethod
37
39
  async def run(
38
- self, inp: ToolInT, ctx: RunContextWrapper[CtxT] | None = None
39
- ) -> ToolOutT:
40
+ self, inp: _ToolInT, ctx: RunContextWrapper[CtxT] | None = None
41
+ ) -> _ToolOutT:
40
42
  pass
41
43
 
44
+ async def run_batch(
45
+ self, inp_batch: Sequence[_ToolInT], ctx: RunContextWrapper[CtxT] | None = None
46
+ ) -> Sequence[_ToolOutT]:
47
+ return await asyncio.gather(*[self.run(inp, ctx=ctx) for inp in inp_batch])
48
+
42
49
  async def __call__(
43
50
  self, ctx: RunContextWrapper[CtxT] | None = None, **kwargs: Any
44
- ) -> ToolOutT:
51
+ ) -> _ToolOutT:
45
52
  result = await self.run(self.in_schema(**kwargs), ctx=ctx)
46
53
 
47
- return self.out_schema.model_validate(result)
54
+ return TypeAdapter(self.out_schema).validate_python(result)
48
55
 
49
56
 
50
57
  ToolChoice: TypeAlias = (
51
- Literal["none", "auto", "required"] | BaseTool[BaseModel, BaseModel, Any]
58
+ Literal["none", "auto", "required"] | BaseTool[BaseModel, Any, Any]
52
59
  )