grasp_agents 0.1.18__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -2,6 +2,7 @@ import asyncio
2
2
  import functools
3
3
  import logging
4
4
  from collections.abc import Callable, Coroutine
5
+ from dataclasses import dataclass
5
6
  from time import monotonic
6
7
  from typing import Any, Generic, overload
7
8
 
@@ -9,21 +10,25 @@ from tqdm.autonotebook import tqdm
9
10
 
10
11
  from ..utils import asyncio_gather_with_pbar
11
12
  from .types import (
12
- QueryP,
13
- QueryR,
14
- QueryT,
15
- RateLimDecoratorWithArgsList,
16
- RateLimDecoratorWithArgsSingle,
17
- RateLimiterState,
18
- RetrievalCallableList,
19
- RetrievalCallableSingle,
13
+ P,
14
+ ProcessorCallableList,
15
+ ProcessorCallableSingle,
16
+ R,
17
+ RateLimWrapperWithArgsList,
18
+ RateLimWrapperWithArgsSingle,
19
+ T,
20
20
  )
21
- from .utils import partial_retrieval_callable, split_pos_args
21
+ from .utils import partial_processor_callable, split_pos_args
22
22
 
23
23
  logger = logging.getLogger(__name__)
24
24
 
25
25
 
26
- class RateLimiterC(Generic[QueryT, QueryR]):
26
+ @dataclass
27
+ class RateLimiterState:
28
+ next_request_time: float = 0.0
29
+
30
+
31
+ class RateLimiterC(Generic[T, R]):
27
32
  def __init__(
28
33
  self,
29
34
  rpm: float,
@@ -40,9 +45,9 @@ class RateLimiterC(Generic[QueryT, QueryR]):
40
45
 
41
46
  async def process_input(
42
47
  self,
43
- func_partial: Callable[[QueryT], Coroutine[Any, Any, QueryR]],
44
- inp: QueryT,
45
- ) -> QueryR:
48
+ func_partial: Callable[[T], Coroutine[Any, Any, R]],
49
+ inp: T,
50
+ ) -> R:
46
51
  async with self._semaphore:
47
52
  async with self._lock:
48
53
  now = monotonic()
@@ -53,11 +58,11 @@ class RateLimiterC(Generic[QueryT, QueryR]):
53
58
 
54
59
  async def process_inputs(
55
60
  self,
56
- func_partial: Callable[[QueryT], Coroutine[Any, Any, QueryR]],
57
- inputs: list[QueryT],
61
+ func_partial: Callable[[T], Coroutine[Any, Any, R]],
62
+ inputs: list[T],
58
63
  no_tqdm: bool = False,
59
- ) -> list[QueryR]:
60
- results: list[QueryR] = []
64
+ ) -> list[R]:
65
+ results: list[R] = []
61
66
  for i in tqdm(
62
67
  range(0, len(inputs), self._chunk_size),
63
68
  disable=no_tqdm,
@@ -95,33 +100,30 @@ class RateLimiterC(Generic[QueryT, QueryR]):
95
100
 
96
101
  @overload
97
102
  def limit_rate(
98
- call: RetrievalCallableSingle[QueryT, QueryP, QueryR],
99
- rate_limiter: RateLimiterC[QueryT, QueryR] | None = None,
100
- ) -> RetrievalCallableSingle[QueryT, QueryP, QueryR]: ...
103
+ call: ProcessorCallableSingle[T, P, R],
104
+ rate_limiter: RateLimiterC[T, R] | None = None,
105
+ ) -> ProcessorCallableSingle[T, P, R]: ...
101
106
 
102
107
 
103
108
  @overload
104
109
  def limit_rate(
105
110
  call: None = None,
106
- rate_limiter: RateLimiterC[QueryT, QueryR] | None = None,
107
- ) -> RateLimDecoratorWithArgsSingle[QueryT, QueryP, QueryR]: ...
111
+ rate_limiter: RateLimiterC[T, R] | None = None,
112
+ ) -> RateLimWrapperWithArgsSingle[T, P, R]: ...
108
113
 
109
114
 
110
115
  def limit_rate(
111
- call: RetrievalCallableSingle[QueryT, QueryP, QueryR] | None = None,
112
- rate_limiter: RateLimiterC[QueryT, QueryR] | None = None,
113
- ) -> (
114
- RetrievalCallableSingle[QueryT, QueryP, QueryR]
115
- | RateLimDecoratorWithArgsSingle[QueryT, QueryP, QueryR]
116
- ):
116
+ call: ProcessorCallableSingle[T, P, R] | None = None,
117
+ rate_limiter: RateLimiterC[T, R] | None = None,
118
+ ) -> ProcessorCallableSingle[T, P, R] | RateLimWrapperWithArgsSingle[T, P, R]:
117
119
  if call is None:
118
120
  return functools.partial(limit_rate, rate_limiter=rate_limiter)
119
121
 
120
122
  @functools.wraps(call) # type: ignore
121
- async def wrapper(*args: Any, **kwargs: Any) -> QueryR:
122
- inp: QueryT
123
- self_obj, inp, other_args = split_pos_args(call, args) # type: ignore
124
- call_partial = partial_retrieval_callable(call, self_obj, *other_args, **kwargs)
123
+ async def wrapper(*args: Any, **kwargs: Any) -> R:
124
+ inp: T
125
+ self_obj, inp, other_args = split_pos_args(call, args)
126
+ call_partial = partial_processor_callable(call, self_obj, *other_args, **kwargs)
125
127
 
126
128
  _rate_limiter = rate_limiter
127
129
  if _rate_limiter is None:
@@ -136,39 +138,36 @@ def limit_rate(
136
138
 
137
139
  @overload
138
140
  def limit_rate_chunked(
139
- call: RetrievalCallableSingle[QueryT, QueryP, QueryR],
140
- rate_limiter: RateLimiterC[QueryT, QueryR] | None = None,
141
+ call: ProcessorCallableList[T, P, R],
142
+ rate_limiter: RateLimiterC[T, R] | None = None,
141
143
  no_tqdm: bool | None = None,
142
- ) -> RetrievalCallableList[QueryT, QueryP, QueryR]: ...
144
+ ) -> ProcessorCallableList[T, P, R]: ...
143
145
 
144
146
 
145
147
  @overload
146
148
  def limit_rate_chunked(
147
149
  call: None = None,
148
- rate_limiter: RateLimiterC[QueryT, QueryR] | None = None,
150
+ rate_limiter: RateLimiterC[T, R] | None = None,
149
151
  no_tqdm: bool | None = None,
150
- ) -> RateLimDecoratorWithArgsList[QueryT, QueryP, QueryR]: ...
152
+ ) -> RateLimWrapperWithArgsList[T, P, R]: ...
151
153
 
152
154
 
153
155
  def limit_rate_chunked(
154
- call: RetrievalCallableSingle[QueryT, QueryP, QueryR] | None = None,
155
- rate_limiter: RateLimiterC[QueryT, QueryR] | None = None,
156
+ call: ProcessorCallableList[T, P, R] | None = None,
157
+ rate_limiter: RateLimiterC[T, R] | None = None,
156
158
  no_tqdm: bool | None = None,
157
- ) -> (
158
- RetrievalCallableList[QueryT, QueryP, QueryR]
159
- | RateLimDecoratorWithArgsList[QueryT, QueryP, QueryR]
160
- ):
159
+ ) -> ProcessorCallableList[T, P, R] | RateLimWrapperWithArgsList[T, P, R]:
161
160
  if call is None:
162
161
  return functools.partial(
163
162
  limit_rate_chunked, rate_limiter=rate_limiter, no_tqdm=no_tqdm
164
- )
163
+ ) # type: ignore
165
164
 
166
165
  @functools.wraps(call) # type: ignore
167
- async def wrapper(*args: Any, **kwargs: Any) -> list[QueryR]:
166
+ async def wrapper(*args: Any, **kwargs: Any) -> list[R]:
168
167
  assert call is not None
169
168
 
170
- self_obj, inputs, other_args = split_pos_args(call, args) # type: ignore
171
- call_partial = partial_retrieval_callable(call, self_obj, *other_args, **kwargs)
169
+ self_obj, inputs, other_args = split_pos_args(call, args)
170
+ call_partial = partial_processor_callable(call, self_obj, *other_args, **kwargs)
172
171
 
173
172
  _no_tqdm = no_tqdm
174
173
  _rate_limiter = rate_limiter
@@ -182,7 +181,9 @@ def limit_rate_chunked(
182
181
  *[call_partial(inp) for inp in inputs], no_tqdm=_no_tqdm
183
182
  )
184
183
  return await _rate_limiter.process_inputs(
185
- func_partial=call_partial, inputs=inputs, no_tqdm=_no_tqdm
184
+ func_partial=call_partial, # type: ignore
185
+ inputs=inputs,
186
+ no_tqdm=_no_tqdm,
186
187
  )
187
188
 
188
189
  return wrapper
@@ -1,57 +1,36 @@
1
1
  from collections.abc import Callable, Coroutine
2
- from dataclasses import dataclass
3
- from typing import (
4
- Any,
5
- Concatenate,
6
- ParamSpec,
7
- TypeAlias,
8
- TypeVar,
9
- )
10
-
11
- MAX_RPM = 1e10
12
-
2
+ from typing import Any, Concatenate, ParamSpec, TypeAlias, TypeVar
13
3
 
14
- @dataclass
15
- class RateLimiterState:
16
- next_request_time: float = 0.0
4
+ T = TypeVar("T")
5
+ R = TypeVar("R")
6
+ P = ParamSpec("P")
17
7
 
18
-
19
- QueryT = TypeVar("QueryT")
20
- QueryR = TypeVar("QueryR")
21
- QueryP = ParamSpec("QueryP")
22
-
23
- RetrievalFuncSingle: TypeAlias = Callable[
24
- Concatenate[QueryT, QueryP], Coroutine[Any, Any, QueryR]
25
- ]
26
- RetrievalFuncList: TypeAlias = Callable[
27
- Concatenate[list[QueryT], QueryP], Coroutine[Any, Any, list[QueryR]]
8
+ ProcessorFuncSingle: TypeAlias = Callable[Concatenate[T, P], Coroutine[Any, Any, R]]
9
+ ProcessorFuncList: TypeAlias = Callable[
10
+ Concatenate[list[T], P], Coroutine[Any, Any, list[R]]
28
11
  ]
29
12
 
30
- RetrievalMethodSingle: TypeAlias = Callable[
31
- Concatenate[Any, QueryT, QueryP], Coroutine[Any, Any, QueryR]
13
+ ProcessorMethodSingle: TypeAlias = Callable[
14
+ Concatenate[Any, T, P], Coroutine[Any, Any, R]
32
15
  ]
33
- RetrievalMethodList: TypeAlias = Callable[
34
- Concatenate[Any, list[QueryT], QueryP], Coroutine[Any, Any, list[QueryR]]
16
+ ProcessorMethodList: TypeAlias = Callable[
17
+ Concatenate[Any, list[T], P], Coroutine[Any, Any, list[R]]
35
18
  ]
36
19
 
37
- RetrievalCallableSingle: TypeAlias = (
38
- RetrievalFuncSingle[QueryT, QueryP, QueryR]
39
- | RetrievalMethodSingle[QueryT, QueryP, QueryR]
20
+ ProcessorCallableSingle: TypeAlias = (
21
+ ProcessorFuncSingle[T, P, R] | ProcessorMethodSingle[T, P, R]
40
22
  )
41
23
 
42
- RetrievalCallableList: TypeAlias = (
43
- RetrievalFuncList[QueryT, QueryP, QueryR]
44
- | RetrievalMethodList[QueryT, QueryP, QueryR]
24
+ ProcessorCallableList: TypeAlias = (
25
+ ProcessorFuncList[T, P, R] | ProcessorMethodList[T, P, R]
45
26
  )
46
27
 
47
28
 
48
- RateLimDecoratorWithArgsSingle = Callable[
49
- [RetrievalCallableSingle[QueryT, QueryP, QueryR]],
50
- RetrievalCallableSingle[QueryT, QueryP, QueryR],
29
+ RateLimWrapperWithArgsSingle = Callable[
30
+ [ProcessorCallableSingle[T, P, R]], ProcessorCallableSingle[T, P, R]
51
31
  ]
52
32
 
53
33
 
54
- RateLimDecoratorWithArgsList = Callable[
55
- [RetrievalCallableList[QueryT, QueryP, QueryR]],
56
- RetrievalCallableList[QueryT, QueryP, QueryR],
34
+ RateLimWrapperWithArgsList = Callable[
35
+ [ProcessorCallableList[T, P, R]], ProcessorCallableList[T, P, R]
57
36
  ]
@@ -1,16 +1,8 @@
1
1
  import inspect
2
2
  from collections.abc import Callable, Coroutine, Sequence
3
- from typing import (
4
- Any,
5
- )
3
+ from typing import Any, overload
6
4
 
7
- from .types import (
8
- QueryP,
9
- QueryR,
10
- QueryT,
11
- RetrievalCallableList,
12
- RetrievalCallableSingle,
13
- )
5
+ from .types import P, ProcessorCallableList, ProcessorCallableSingle, R, T
14
6
 
15
7
 
16
8
  def is_bound_method(func: Callable[..., Any], self_candidate: Any) -> bool:
@@ -19,13 +11,24 @@ def is_bound_method(func: Callable[..., Any], self_candidate: Any) -> bool:
19
11
  )
20
12
 
21
13
 
14
+ @overload
22
15
  def split_pos_args(
23
- call: (
24
- RetrievalCallableSingle[QueryT, QueryP, QueryR]
25
- | RetrievalCallableList[QueryT, QueryP, QueryR]
26
- ),
16
+ call: ProcessorCallableSingle[T, P, R],
27
17
  args: Sequence[Any],
28
- ) -> tuple[Any | None, QueryT | list[QueryT], Sequence[Any]]:
18
+ ) -> tuple[Any | None, T, Sequence[Any]]: ...
19
+
20
+
21
+ @overload
22
+ def split_pos_args(
23
+ call: ProcessorCallableList[T, P, R],
24
+ args: Sequence[Any],
25
+ ) -> tuple[Any | None, list[T], Sequence[Any]]: ...
26
+
27
+
28
+ def split_pos_args(
29
+ call: (ProcessorCallableSingle[T, P, R] | ProcessorCallableList[T, P, R]),
30
+ args: Sequence[Any],
31
+ ) -> tuple[Any | None, T | list[T], Sequence[Any]]:
29
32
  if not args:
30
33
  raise ValueError("No positional arguments passed.")
31
34
  maybe_self = args[0]
@@ -45,13 +48,13 @@ def split_pos_args(
45
48
  return None, args[0], args[1:]
46
49
 
47
50
 
48
- def partial_retrieval_callable(
49
- call: Callable[..., Coroutine[Any, Any, QueryR]],
51
+ def partial_processor_callable(
52
+ call: Callable[..., Coroutine[Any, Any, R]],
50
53
  self_obj: Any,
51
- *args: QueryP.args,
52
- **kwargs: QueryP.kwargs,
53
- ) -> Callable[[QueryT], Coroutine[Any, Any, QueryR]]:
54
- async def wrapper(inp: QueryT) -> QueryR:
54
+ *args: Any,
55
+ **kwargs: Any,
56
+ ) -> Callable[[Any], Coroutine[Any, Any, R]]:
57
+ async def wrapper(inp: Any) -> R:
55
58
  if self_obj is not None:
56
59
  # `call` is a method
57
60
  return await call(self_obj, inp, *args, **kwargs)
@@ -59,9 +62,3 @@ def partial_retrieval_callable(
59
62
  return await call(inp, *args, **kwargs)
60
63
 
61
64
  return wrapper
62
-
63
-
64
- def expected_exec_time_from_max_concurrency_and_rpm(
65
- rpm: float, max_concurrency: int
66
- ) -> float:
67
- return 60.0 / (rpm / max_concurrency)
@@ -8,7 +8,6 @@ from .printer import Printer
8
8
  from .typing.content import ImageData
9
9
  from .typing.io import (
10
10
  AgentID,
11
- AgentPayload,
12
11
  AgentState,
13
12
  InT,
14
13
  LLMPrompt,
@@ -44,9 +43,7 @@ class InteractionRecord(BaseModel, Generic[InT, OutT, StateT]):
44
43
  model_config = ConfigDict(extra="forbid", frozen=True)
45
44
 
46
45
 
47
- InteractionHistory: TypeAlias = list[
48
- InteractionRecord[AgentPayload, AgentPayload, AgentState]
49
- ]
46
+ InteractionHistory: TypeAlias = list[InteractionRecord[Any, Any, AgentState]]
50
47
 
51
48
 
52
49
  CtxT = TypeVar("CtxT")
@@ -56,23 +53,13 @@ class RunContextWrapper(BaseModel, Generic[CtxT]):
56
53
  context: CtxT | None = None
57
54
  run_id: str = Field(default_factory=lambda: str(uuid4())[:8], frozen=True)
58
55
  run_args: dict[AgentID, RunArgs] = Field(default_factory=dict)
59
- interaction_history: InteractionHistory = Field(default_factory=list)
56
+ interaction_history: InteractionHistory = Field(default_factory=list) # type: ignore[valid-type]
60
57
 
61
58
  print_messages: bool = False
62
59
 
63
60
  _usage_tracker: UsageTracker = PrivateAttr()
64
61
  _printer: Printer = PrivateAttr()
65
62
 
66
- # usage_tracker: Optional[UsageTracker] = None
67
- # printer: Optional[Printer] = None
68
-
69
- # @model_validator(mode="after")
70
- # def set_usage_tracker_and_printer(self) -> "RunContextWrapper":
71
- # self.usage_tracker = UsageTracker(source_id=self.run_id)
72
- # self.printer = Printer(source_id=self.run_id)
73
-
74
- # return self
75
-
76
63
  def model_post_init(self, context: Any) -> None: # noqa: ARG002
77
64
  self._usage_tracker = UsageTracker(source_id=self.run_id)
78
65
  self._printer = Printer(
@@ -16,7 +16,7 @@ from .typing.tool import BaseTool, ToolCall, ToolChoice
16
16
  logger = getLogger(__name__)
17
17
 
18
18
 
19
- class ToolCallLoopExitHandler(Protocol[CtxT]):
19
+ class ExitToolCallLoopHandler(Protocol[CtxT]):
20
20
  def __call__(
21
21
  self,
22
22
  conversation: Conversation,
@@ -26,6 +26,16 @@ class ToolCallLoopExitHandler(Protocol[CtxT]):
26
26
  ) -> bool: ...
27
27
 
28
28
 
29
+ class ManageAgentStateHandler(Protocol[CtxT]):
30
+ def __call__(
31
+ self,
32
+ state: LLMAgentState,
33
+ *,
34
+ ctx: RunContextWrapper[CtxT] | None,
35
+ **kwargs: Any,
36
+ ) -> None: ...
37
+
38
+
29
39
  class ToolOrchestrator(Generic[CtxT]):
30
40
  def __init__(
31
41
  self,
@@ -38,13 +48,13 @@ class ToolOrchestrator(Generic[CtxT]):
38
48
  self._agent_id = agent_id
39
49
 
40
50
  self._llm = llm
41
- self._tools = tools
42
- self.llm.tools = tools
51
+ self._llm.tools = tools
43
52
 
44
53
  self._max_turns = max_turns
45
54
  self._react_mode = react_mode
46
55
 
47
- self.tool_call_loop_exit_impl: ToolCallLoopExitHandler[CtxT] | None = None
56
+ self.exit_tool_call_loop_impl: ExitToolCallLoopHandler[CtxT] | None = None
57
+ self.manage_agent_state_impl: ManageAgentStateHandler[CtxT] | None = None
48
58
 
49
59
  @property
50
60
  def agent_id(self) -> str:
@@ -62,15 +72,15 @@ class ToolOrchestrator(Generic[CtxT]):
62
72
  def max_turns(self) -> int:
63
73
  return self._max_turns
64
74
 
65
- def _tool_call_loop_exit(
75
+ def _exit_tool_call_loop(
66
76
  self,
67
77
  conversation: Conversation,
68
78
  *,
69
79
  ctx: RunContextWrapper[CtxT] | None = None,
70
80
  **kwargs: Any,
71
81
  ) -> bool:
72
- if self.tool_call_loop_exit_impl:
73
- return self.tool_call_loop_exit_impl(
82
+ if self.exit_tool_call_loop_impl:
83
+ return self.exit_tool_call_loop_impl(
74
84
  conversation=conversation, ctx=ctx, **kwargs
75
85
  )
76
86
 
@@ -81,6 +91,16 @@ class ToolOrchestrator(Generic[CtxT]):
81
91
 
82
92
  return not bool(conversation[-1].tool_calls)
83
93
 
94
+ def _manage_agent_state(
95
+ self,
96
+ state: LLMAgentState,
97
+ *,
98
+ ctx: RunContextWrapper[CtxT] | None = None,
99
+ **kwargs: Any,
100
+ ) -> None:
101
+ if self.manage_agent_state_impl:
102
+ self.manage_agent_state_impl(state=state, ctx=ctx, **kwargs)
103
+
84
104
  async def generate_once(
85
105
  self,
86
106
  agent_state: LLMAgentState,
@@ -99,10 +119,10 @@ class ToolOrchestrator(Generic[CtxT]):
99
119
 
100
120
  async def run_loop(
101
121
  self,
102
- agent_state: LLMAgentState,
122
+ state: LLMAgentState,
103
123
  ctx: RunContextWrapper[CtxT] | None = None,
104
124
  ) -> None:
105
- message_history = agent_state.message_history
125
+ message_history = state.message_history
106
126
  assert message_history.batch_size == 1, (
107
127
  "Batch size must be 1 for tool call loop"
108
128
  )
@@ -111,13 +131,15 @@ class ToolOrchestrator(Generic[CtxT]):
111
131
 
112
132
  tool_choice = "none" if self._react_mode else "auto"
113
133
  gen_message_batch = await self.generate_once(
114
- agent_state, tool_choice=tool_choice, ctx=ctx
134
+ state, tool_choice=tool_choice, ctx=ctx
115
135
  )
116
136
 
117
137
  turns = 0
118
138
 
119
139
  while True:
120
- if self._tool_call_loop_exit(
140
+ self._manage_agent_state(state=state, ctx=ctx, num_turns=turns)
141
+
142
+ if self._exit_tool_call_loop(
121
143
  message_history.batched_conversations[0], ctx=ctx, num_turns=turns
122
144
  ):
123
145
  return
@@ -134,7 +156,7 @@ class ToolOrchestrator(Generic[CtxT]):
134
156
 
135
157
  tool_choice = "none" if (self._react_mode and msg.tool_calls) else "auto"
136
158
  gen_message_batch = await self.generate_once(
137
- agent_state, tool_choice=tool_choice, ctx=ctx
159
+ state, tool_choice=tool_choice, ctx=ctx
138
160
  )
139
161
 
140
162
  turns += 1
@@ -1,6 +1,6 @@
1
1
  import base64
2
2
  import re
3
- from collections.abc import Iterable
3
+ from collections.abc import Iterable, Mapping
4
4
  from enum import StrEnum
5
5
  from pathlib import Path
6
6
  from typing import Annotated, Any, Literal, TypeAlias
@@ -66,7 +66,7 @@ class Content(BaseModel):
66
66
  def from_formatted_prompt(
67
67
  cls,
68
68
  prompt_template: str,
69
- prompt_args: dict[str, str | ImageData] | None = None,
69
+ prompt_args: Mapping[str, str | int | bool | ImageData] | None = None,
70
70
  ) -> "Content":
71
71
  prompt_args = prompt_args or {}
72
72
  image_args = {
@@ -2,6 +2,8 @@ from abc import ABC, abstractmethod
2
2
  from collections.abc import AsyncIterator
3
3
  from typing import Any
4
4
 
5
+ from pydantic import BaseModel
6
+
5
7
  from .completion import Completion, CompletionChunk
6
8
  from .content import Content
7
9
  from .message import AssistantMessage, Message, SystemMessage, ToolMessage, UserMessage
@@ -64,7 +66,7 @@ class Converters(ABC):
64
66
 
65
67
  @staticmethod
66
68
  @abstractmethod
67
- def to_tool(tool: BaseTool[Any, Any, Any], **kwargs: Any) -> Any:
69
+ def to_tool(tool: BaseTool[BaseModel, Any, Any], **kwargs: Any) -> Any:
68
70
  pass
69
71
 
70
72
  @staticmethod
grasp_agents/typing/io.py CHANGED
@@ -1,3 +1,4 @@
1
+ from collections.abc import Mapping
1
2
  from typing import TypeAlias, TypeVar
2
3
 
3
4
  from pydantic import BaseModel
@@ -7,23 +8,18 @@ from .content import ImageData
7
8
  AgentID: TypeAlias = str
8
9
 
9
10
 
10
- class AgentPayload(BaseModel):
11
- pass
12
-
13
-
14
11
  class AgentState(BaseModel):
15
12
  pass
16
13
 
17
14
 
18
- InT = TypeVar("InT", bound=AgentPayload, contravariant=True) # noqa: PLC0105
19
- OutT = TypeVar("OutT", bound=AgentPayload, covariant=True) # noqa: PLC0105
20
- StateT = TypeVar("StateT", bound=AgentState, covariant=True) # noqa: PLC0105
21
-
22
-
23
15
  class LLMPromptArgs(BaseModel):
24
16
  pass
25
17
 
26
18
 
19
+ InT = TypeVar("InT", contravariant=True) # noqa: PLC0105
20
+ OutT = TypeVar("OutT", covariant=True) # noqa: PLC0105
21
+ StateT = TypeVar("StateT", bound=AgentState, covariant=True) # noqa: PLC0105
22
+
27
23
  LLMPrompt: TypeAlias = str
28
- LLMFormattedSystemArgs: TypeAlias = dict[str, str]
29
- LLMFormattedArgs: TypeAlias = dict[str, str | ImageData]
24
+ LLMFormattedSystemArgs: TypeAlias = Mapping[str, str | int | bool]
25
+ LLMFormattedArgs: TypeAlias = Mapping[str, str | int | bool | ImageData]
@@ -1,5 +1,5 @@
1
1
  import json
2
- from collections.abc import Hashable, Sequence
2
+ from collections.abc import Hashable, Mapping, Sequence
3
3
  from enum import StrEnum
4
4
  from typing import Annotated, Any, Literal, TypeAlias
5
5
  from uuid import uuid4
@@ -79,7 +79,7 @@ class UserMessage(MessageBase):
79
79
  def from_formatted_prompt(
80
80
  cls,
81
81
  prompt_template: str,
82
- prompt_args: dict[str, str | ImageData] | None = None,
82
+ prompt_args: Mapping[str, str | int | bool | ImageData] | None = None,
83
83
  model_id: str | None = None,
84
84
  ) -> "UserMessage":
85
85
  content = Content.from_formatted_prompt(
@@ -1,11 +1,11 @@
1
1
  from __future__ import annotations
2
2
 
3
- import asyncio
4
3
  from abc import ABC, abstractmethod
5
- from collections.abc import Sequence
6
- from typing import TYPE_CHECKING, Any, Generic, Literal, TypeAlias, TypeVar
4
+ from typing import TYPE_CHECKING, Any, ClassVar, Generic, Literal, TypeAlias, TypeVar
7
5
 
8
- from pydantic import BaseModel, TypeAdapter
6
+ from pydantic import BaseModel, PrivateAttr, TypeAdapter
7
+
8
+ from ..generics_utils import AutoInstanceAttributesMixin
9
9
 
10
10
  if TYPE_CHECKING:
11
11
  from ..run_context import CtxT, RunContextWrapper
@@ -26,32 +26,44 @@ class ToolCall(BaseModel):
26
26
  tool_arguments: str
27
27
 
28
28
 
29
- class BaseTool(BaseModel, ABC, Generic[_ToolInT, _ToolOutT, CtxT]):
29
+ class BaseTool(
30
+ AutoInstanceAttributesMixin, BaseModel, ABC, Generic[_ToolInT, _ToolOutT, CtxT]
31
+ ):
32
+ _generic_arg_to_instance_attr_map: ClassVar[dict[int, str]] = {
33
+ 0: "_in_schema",
34
+ 1: "_out_schema",
35
+ }
36
+
30
37
  name: str
31
38
  description: str
32
- in_schema: type[_ToolInT]
33
- out_schema: type[_ToolOutT]
39
+
40
+ _in_schema: type[_ToolInT] = PrivateAttr()
41
+ _out_schema: type[_ToolOutT] = PrivateAttr()
34
42
 
35
43
  # Supported by OpenAI API
36
44
  strict: bool | None = None
37
45
 
46
+ @property
47
+ def in_schema(self) -> type[_ToolInT]: # type: ignore[reportInvalidTypeVarUse]
48
+ # Exposing the type of a contravariant variable only, should be type safe
49
+ return self._in_schema
50
+
51
+ @property
52
+ def out_schema(self) -> type[_ToolOutT]:
53
+ return self._out_schema
54
+
38
55
  @abstractmethod
39
56
  async def run(
40
57
  self, inp: _ToolInT, ctx: RunContextWrapper[CtxT] | None = None
41
58
  ) -> _ToolOutT:
42
59
  pass
43
60
 
44
- async def run_batch(
45
- self, inp_batch: Sequence[_ToolInT], ctx: RunContextWrapper[CtxT] | None = None
46
- ) -> Sequence[_ToolOutT]:
47
- return await asyncio.gather(*[self.run(inp, ctx=ctx) for inp in inp_batch])
48
-
49
61
  async def __call__(
50
62
  self, ctx: RunContextWrapper[CtxT] | None = None, **kwargs: Any
51
63
  ) -> _ToolOutT:
52
- result = await self.run(self.in_schema(**kwargs), ctx=ctx)
64
+ result = await self.run(self._in_schema(**kwargs), ctx=ctx)
53
65
 
54
- return TypeAdapter(self.out_schema).validate_python(result)
66
+ return TypeAdapter(self._out_schema).validate_python(result)
55
67
 
56
68
 
57
69
  ToolChoice: TypeAlias = (