grasp_agents 0.1.17__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
grasp_agents/llm_agent.py CHANGED
@@ -1,6 +1,6 @@
1
1
  from collections.abc import Sequence
2
2
  from pathlib import Path
3
- from typing import Any, Generic, cast, final
3
+ from typing import Any, ClassVar, Generic, Protocol, cast, final
4
4
 
5
5
  from pydantic import BaseModel
6
6
 
@@ -25,12 +25,11 @@ from .run_context import (
25
25
  SystemRunArgs,
26
26
  UserRunArgs,
27
27
  )
28
- from .tool_orchestrator import ToolCallLoopExitHandler, ToolOrchestrator
28
+ from .tool_orchestrator import ExitToolCallLoopHandler, ToolOrchestrator
29
29
  from .typing.content import ImageData
30
30
  from .typing.converters import Converters
31
31
  from .typing.io import (
32
32
  AgentID,
33
- AgentPayload,
34
33
  AgentState,
35
34
  InT,
36
35
  LLMFormattedArgs,
@@ -41,13 +40,24 @@ from .typing.io import (
41
40
  )
42
41
  from .typing.message import Conversation, Message, SystemMessage
43
42
  from .typing.tool import BaseTool
44
- from .utils import get_prompt
43
+ from .utils import get_prompt, validate_obj_from_json_or_py_string
44
+
45
+
46
+ class ParseOutputHandler(Protocol[OutT, CtxT]):
47
+ def __call__(
48
+ self, *args: Any, ctx: RunContextWrapper[CtxT] | None, **kwargs: Any
49
+ ) -> OutT: ...
45
50
 
46
51
 
47
52
  class LLMAgent(
48
53
  CommunicatingAgent[InT, OutT, LLMAgentState, CtxT],
49
54
  Generic[InT, OutT, CtxT],
50
55
  ):
56
+ _generic_arg_to_instance_attr_map: ClassVar[dict[int, str]] = {
57
+ 0: "_in_type",
58
+ 1: "_out_type",
59
+ }
60
+
51
61
  def __init__(
52
62
  self,
53
63
  agent_id: AgentID,
@@ -64,10 +74,6 @@ class LLMAgent(
64
74
  sys_args_schema: type[LLMPromptArgs] = LLMPromptArgs,
65
75
  # User args (static args provided via RunContextWrapper)
66
76
  usr_args_schema: type[LLMPromptArgs] = LLMPromptArgs,
67
- # Received args (args from another agent)
68
- rcv_args_schema: type[InT] = cast("type[InT]", AgentPayload),
69
- # Output schema
70
- out_schema: type[OutT] = cast("type[OutT]", AgentPayload),
71
77
  # Tools
72
78
  tools: list[BaseTool[Any, Any, CtxT]] | None = None,
73
79
  max_turns: int = 1000,
@@ -79,11 +85,7 @@ class LLMAgent(
79
85
  recipient_ids: list[AgentID] | None = None,
80
86
  ) -> None:
81
87
  super().__init__(
82
- agent_id=agent_id,
83
- out_schema=out_schema,
84
- rcv_args_schema=rcv_args_schema,
85
- message_pool=message_pool,
86
- recipient_ids=recipient_ids,
88
+ agent_id=agent_id, message_pool=message_pool, recipient_ids=recipient_ids
87
89
  )
88
90
 
89
91
  # Agent state
@@ -103,28 +105,19 @@ class LLMAgent(
103
105
  # Prompt builder
104
106
  sys_prompt = get_prompt(prompt_text=sys_prompt, prompt_path=sys_prompt_path)
105
107
  inp_prompt = get_prompt(prompt_text=inp_prompt, prompt_path=inp_prompt_path)
106
- self._prompt_builder: PromptBuilder[InT, CtxT] = PromptBuilder[InT, CtxT](
108
+ self._prompt_builder: PromptBuilder[InT, CtxT] = PromptBuilder[
109
+ self.in_type, CtxT
110
+ ](
107
111
  agent_id=self._agent_id,
108
112
  sys_prompt=sys_prompt,
109
113
  inp_prompt=inp_prompt,
110
114
  sys_args_schema=sys_args_schema,
111
115
  usr_args_schema=usr_args_schema,
112
- rcv_args_schema=rcv_args_schema,
113
116
  )
114
117
 
115
118
  self.no_tqdm = getattr(llm, "no_tqdm", False)
116
119
 
117
- if type(self)._format_sys_args is not LLMAgent[Any, Any, Any]._format_sys_args: # noqa: SLF001
118
- self._prompt_builder.format_sys_args_impl = self._format_sys_args
119
-
120
- if type(self)._format_inp_args is not LLMAgent[Any, Any, Any]._format_inp_args: # noqa: SLF001
121
- self._prompt_builder.format_inp_args_impl = self._format_inp_args
122
-
123
- if (
124
- type(self)._tool_call_loop_exit # noqa: SLF001
125
- is not LLMAgent[Any, Any, Any]._tool_call_loop_exit # noqa: SLF001
126
- ):
127
- self._tool_orchestrator.tool_call_loop_exit_impl = self._tool_call_loop_exit
120
+ self._register_overridden_handlers()
128
121
 
129
122
  @property
130
123
  def llm(self) -> LLM[LLMSettings, Converters]:
@@ -166,10 +159,10 @@ class LLMAgent(
166
159
  return self._parse_output_impl(
167
160
  conversation=conversation, rcv_args=rcv_args, ctx=ctx, **kwargs
168
161
  )
169
- try:
170
- return self._out_schema.model_validate_json(str(conversation[-1].content))
171
- except Exception:
172
- return self._out_schema()
162
+
163
+ return validate_obj_from_json_or_py_string(
164
+ str(conversation[-1].content), self._out_type_adapter
165
+ )
173
166
 
174
167
  @final
175
168
  async def run(
@@ -177,7 +170,7 @@ class LLMAgent(
177
170
  inp_items: LLMPrompt | list[str | ImageData] | None = None,
178
171
  *,
179
172
  ctx: RunContextWrapper[CtxT] | None = None,
180
- rcv_message: AgentMessage[InT, LLMAgentState] | None = None,
173
+ rcv_message: AgentMessage[InT, AgentState] | None = None,
181
174
  entry_point: bool = False,
182
175
  forbid_state_change: bool = False,
183
176
  **gen_kwargs: Any, # noqa: ARG002
@@ -247,7 +240,7 @@ class LLMAgent(
247
240
  batch_size = state.message_history.batch_size
248
241
  rcv_args_batch = rcv_message.payloads if rcv_message else batch_size * [None]
249
242
  val_output_batch = [
250
- self.out_schema.model_validate(
243
+ self._out_type_adapter.validate_python(
251
244
  self._parse_output(conversation=conv, rcv_args=rcv_args, ctx=ctx)
252
245
  )
253
246
  for conv, rcv_args in zip(
@@ -276,7 +269,7 @@ class LLMAgent(
276
269
  )
277
270
  ctx.interaction_history.append(
278
271
  cast(
279
- "InteractionRecord[AgentPayload, AgentPayload, AgentState]",
272
+ "InteractionRecord[Any, Any, AgentState]",
280
273
  interaction_record,
281
274
  )
282
275
  )
@@ -330,6 +323,13 @@ class LLMAgent(
330
323
 
331
324
  return func
332
325
 
326
+ def parse_output_handler(
327
+ self, func: ParseOutputHandler[OutT, CtxT]
328
+ ) -> ParseOutputHandler[OutT, CtxT]:
329
+ self._parse_output_impl = func
330
+
331
+ return func
332
+
333
333
  def make_custom_agent_state_handler(
334
334
  self, func: MakeCustomAgentState
335
335
  ) -> MakeCustomAgentState:
@@ -337,15 +337,35 @@ class LLMAgent(
337
337
 
338
338
  return func
339
339
 
340
- def tool_call_loop_exit_handler(
341
- self, func: ToolCallLoopExitHandler[CtxT]
342
- ) -> ToolCallLoopExitHandler[CtxT]:
343
- self._tool_orchestrator.tool_call_loop_exit_impl = func
340
+ def exit_tool_call_loop_handler(
341
+ self, func: ExitToolCallLoopHandler[CtxT]
342
+ ) -> ExitToolCallLoopHandler[CtxT]:
343
+ self._tool_orchestrator.exit_tool_call_loop_impl = func
344
344
 
345
345
  return func
346
346
 
347
347
  # -- Override these methods in subclasses if needed --
348
348
 
349
+ def _register_overridden_handlers(self) -> None:
350
+ cur_cls = type(self)
351
+ base_cls = LLMAgent[Any, Any, Any]
352
+
353
+ if cur_cls._format_sys_args is not base_cls._format_sys_args: # noqa: SLF001
354
+ self._prompt_builder.format_sys_args_impl = self._format_sys_args
355
+
356
+ if cur_cls._format_inp_args is not base_cls._format_inp_args: # noqa: SLF001
357
+ self._prompt_builder.format_inp_args_impl = self._format_inp_args
358
+
359
+ if cur_cls._make_custom_agent_state is not base_cls._make_custom_agent_state: # noqa: SLF001
360
+ self._make_custom_agent_state_impl = self._make_custom_agent_state
361
+
362
+ if (
363
+ cur_cls._tool_call_loop_exit is not base_cls._tool_call_loop_exit # noqa: SLF001
364
+ ):
365
+ self._tool_orchestrator.exit_tool_call_loop_impl = self._tool_call_loop_exit
366
+
367
+ self._parse_output_impl: ParseOutputHandler[OutT, CtxT] | None = None
368
+
349
369
  def _format_sys_args(
350
370
  self,
351
371
  sys_args: LLMPromptArgs,
@@ -367,6 +387,17 @@ class LLMAgent(
367
387
  "LLMAgent._format_inp_args must be overridden by a subclass"
368
388
  )
369
389
 
390
+ def _make_custom_agent_state(
391
+ self,
392
+ cur_state: LLMAgentState | None,
393
+ rcv_state: AgentState | None,
394
+ sys_prompt: LLMPrompt | None,
395
+ ctx: RunContextWrapper[Any] | None,
396
+ ) -> LLMAgentState:
397
+ raise NotImplementedError(
398
+ "LLMAgent._make_custom_agent_state_handler must be overridden by a subclass"
399
+ )
400
+
370
401
  def _tool_call_loop_exit(
371
402
  self,
372
403
  conversation: Conversation,
@@ -14,7 +14,7 @@ class MakeCustomAgentState(Protocol):
14
14
  def __call__(
15
15
  self,
16
16
  cur_state: Optional["LLMAgentState"],
17
- rec_state: Optional["LLMAgentState"],
17
+ rcv_state: AgentState | None,
18
18
  sys_prompt: LLMPrompt | None,
19
19
  ctx: RunContextWrapper[Any] | None,
20
20
  ) -> "LLMAgentState": ...
@@ -31,7 +31,7 @@ class LLMAgentState(AgentState):
31
31
  def from_cur_and_rcv_states(
32
32
  cls,
33
33
  cur_state: Optional["LLMAgentState"] = None,
34
- rcv_state: Optional["LLMAgentState"] = None,
34
+ rcv_state: Optional["AgentState"] = None,
35
35
  sys_prompt: LLMPrompt | None = None,
36
36
  strategy: SetAgentStateStrategy = "from_sender",
37
37
  make_custom_state_impl: MakeCustomAgentState | None = None,
@@ -48,7 +48,11 @@ class LLMAgentState(AgentState):
48
48
  upd_mh.reset(sys_prompt)
49
49
 
50
50
  elif strategy == "from_sender":
51
- rcv_mh = rcv_state.message_history if rcv_state else None
51
+ rcv_mh = (
52
+ rcv_state.message_history
53
+ if rcv_state and isinstance(rcv_state, "LLMAgentState")
54
+ else None
55
+ )
52
56
  if rcv_mh:
53
57
  upd_mh = deepcopy(rcv_mh)
54
58
  else:
@@ -60,12 +64,12 @@ class LLMAgentState(AgentState):
60
64
  )
61
65
  return make_custom_state_impl(
62
66
  cur_state=cur_state,
63
- rec_state=rcv_state,
67
+ rcv_state=rcv_state,
64
68
  sys_prompt=sys_prompt,
65
69
  ctx=ctx,
66
70
  )
67
71
 
68
- return cls(message_history=upd_mh)
72
+ return cls.model_construct(message_history=upd_mh)
69
73
 
70
74
  def __repr__(self) -> str:
71
75
  return f"Message History: {len(self.message_history)}"
@@ -1,11 +1,13 @@
1
1
  from collections.abc import Sequence
2
2
  from copy import deepcopy
3
- from typing import Generic, Protocol
3
+ from typing import ClassVar, Generic, Protocol
4
4
 
5
+ from pydantic import BaseModel, TypeAdapter
6
+
7
+ from .generics_utils import AutoInstanceAttributesMixin
5
8
  from .run_context import CtxT, RunContextWrapper, UserRunArgs
6
9
  from .typing.content import ImageData
7
10
  from .typing.io import (
8
- AgentPayload,
9
11
  InT,
10
12
  LLMFormattedArgs,
11
13
  LLMFormattedSystemArgs,
@@ -15,6 +17,10 @@ from .typing.io import (
15
17
  from .typing.message import UserMessage
16
18
 
17
19
 
20
+ class DummySchema(BaseModel):
21
+ pass
22
+
23
+
18
24
  class FormatSystemArgsHandler(Protocol[CtxT]):
19
25
  def __call__(
20
26
  self,
@@ -33,7 +39,9 @@ class FormatInputArgsHandler(Protocol[InT, CtxT]):
33
39
  ) -> LLMFormattedArgs: ...
34
40
 
35
41
 
36
- class PromptBuilder(Generic[InT, CtxT]):
42
+ class PromptBuilder(AutoInstanceAttributesMixin, Generic[InT, CtxT]):
43
+ _generic_arg_to_instance_attr_map: ClassVar[dict[int, str]] = {0: "_in_type"}
44
+
37
45
  def __init__(
38
46
  self,
39
47
  agent_id: str,
@@ -41,17 +49,20 @@ class PromptBuilder(Generic[InT, CtxT]):
41
49
  inp_prompt: LLMPrompt | None,
42
50
  sys_args_schema: type[LLMPromptArgs],
43
51
  usr_args_schema: type[LLMPromptArgs],
44
- rcv_args_schema: type[InT],
45
52
  ):
53
+ self._in_type: type[InT]
54
+ super().__init__()
55
+
46
56
  self._agent_id = agent_id
47
57
  self.sys_prompt = sys_prompt
48
58
  self.inp_prompt = inp_prompt
49
59
  self.sys_args_schema = sys_args_schema
50
60
  self.usr_args_schema = usr_args_schema
51
- self.rcv_args_schema = rcv_args_schema
52
61
  self.format_sys_args_impl: FormatSystemArgsHandler[CtxT] | None = None
53
62
  self.format_inp_args_impl: FormatInputArgsHandler[InT, CtxT] | None = None
54
63
 
64
+ self._rcv_args_type_adapter: TypeAdapter[InT] = TypeAdapter(self._in_type)
65
+
55
66
  def _format_sys_args(
56
67
  self,
57
68
  sys_args: LLMPromptArgs,
@@ -73,9 +84,18 @@ class PromptBuilder(Generic[InT, CtxT]):
73
84
  usr_args=usr_args, rcv_args=rcv_args, ctx=ctx
74
85
  )
75
86
 
76
- return usr_args.model_dump(exclude_unset=True) | rcv_args.model_dump(
77
- exclude_unset=True, exclude={"selected_recipient_ids"}
78
- )
87
+ if not isinstance(rcv_args, BaseModel) and rcv_args is not None:
88
+ raise TypeError(
89
+ "Cannot apply default formatting to non-BaseModel received arguments."
90
+ )
91
+
92
+ usr_args_ = usr_args
93
+ rcv_args_ = DummySchema() if rcv_args is None else rcv_args
94
+
95
+ usr_args_dump = usr_args_.model_dump(exclude_unset=True)
96
+ rcv_args_dump = rcv_args_.model_dump(exclude={"selected_recipient_ids"})
97
+
98
+ return usr_args_dump | rcv_args_dump
79
99
 
80
100
  def make_sys_prompt(
81
101
  self,
@@ -100,20 +120,18 @@ class PromptBuilder(Generic[InT, CtxT]):
100
120
  def _usr_messages_from_rcv_args(
101
121
  self, rcv_args_batch: Sequence[InT]
102
122
  ) -> list[UserMessage]:
103
- val_rcv_args_batch = [
104
- self.rcv_args_schema.model_validate(rcv) for rcv in rcv_args_batch
105
- ]
106
-
107
123
  return [
108
124
  UserMessage.from_text(
109
- rcv.model_dump_json(
125
+ self._rcv_args_type_adapter.dump_json(
126
+ rcv,
110
127
  exclude_unset=True,
111
128
  indent=2,
112
129
  exclude={"selected_recipient_ids"},
113
- ),
130
+ warnings="error",
131
+ ).decode("utf-8"),
114
132
  model_id=self._agent_id,
115
133
  )
116
- for rcv in val_rcv_args_batch
134
+ for rcv in rcv_args_batch
117
135
  ]
118
136
 
119
137
  def _usr_messages_from_prompt_template(
@@ -123,17 +141,19 @@ class PromptBuilder(Generic[InT, CtxT]):
123
141
  rcv_args_batch: Sequence[InT] | None = None,
124
142
  ctx: RunContextWrapper[CtxT] | None = None,
125
143
  ) -> Sequence[UserMessage]:
126
- usr_args_batch, rcv_args_batch = self._make_batched(usr_args, rcv_args_batch)
127
- val_usr_args_batch = [
128
- self.usr_args_schema.model_validate(u) for u in usr_args_batch
144
+ usr_args_batch_, rcv_args_batch_ = self._make_batched(usr_args, rcv_args_batch)
145
+
146
+ val_usr_args_batch_ = [
147
+ self.usr_args_schema.model_validate(u) for u in usr_args_batch_
129
148
  ]
130
- val_rcv_args_batch = [
131
- self.rcv_args_schema.model_validate(r) for r in rcv_args_batch
149
+ val_rcv_args_batch_ = [
150
+ self._rcv_args_type_adapter.validate_python(rcv) for rcv in rcv_args_batch_
132
151
  ]
152
+
133
153
  formatted_inp_args_batch = [
134
154
  self._format_inp_args(usr_args=val_usr_args, rcv_args=val_rcv_args, ctx=ctx)
135
155
  for val_usr_args, val_rcv_args in zip(
136
- val_usr_args_batch, val_rcv_args_batch, strict=False
156
+ val_usr_args_batch_, val_rcv_args_batch_, strict=False
137
157
  )
138
158
  ]
139
159
 
@@ -187,11 +207,11 @@ class PromptBuilder(Generic[InT, CtxT]):
187
207
  self,
188
208
  usr_args: UserRunArgs | None = None,
189
209
  rcv_args_batch: Sequence[InT] | None = None,
190
- ) -> tuple[Sequence[LLMPromptArgs], Sequence[InT]]:
210
+ ) -> tuple[Sequence[LLMPromptArgs | DummySchema], Sequence[InT | DummySchema]]:
191
211
  usr_args_batch_ = (
192
- usr_args if isinstance(usr_args, list) else [usr_args or LLMPromptArgs()]
212
+ usr_args if isinstance(usr_args, list) else [usr_args or DummySchema()]
193
213
  )
194
- rcv_args_batch_ = rcv_args_batch or [AgentPayload()]
214
+ rcv_args_batch_ = rcv_args_batch or [DummySchema()]
195
215
 
196
216
  # Broadcast singleton → match lengths
197
217
  if len(usr_args_batch_) == 1 and len(rcv_args_batch_) > 1:
@@ -201,4 +221,4 @@ class PromptBuilder(Generic[InT, CtxT]):
201
221
  if len(usr_args_batch_) != len(rcv_args_batch_):
202
222
  raise ValueError("User args and received args must have the same length")
203
223
 
204
- return usr_args_batch_, rcv_args_batch_ # type: ignore
224
+ return usr_args_batch_, rcv_args_batch_
@@ -2,6 +2,7 @@ import asyncio
2
2
  import functools
3
3
  import logging
4
4
  from collections.abc import Callable, Coroutine
5
+ from dataclasses import dataclass
5
6
  from time import monotonic
6
7
  from typing import Any, Generic, overload
7
8
 
@@ -9,21 +10,25 @@ from tqdm.autonotebook import tqdm
9
10
 
10
11
  from ..utils import asyncio_gather_with_pbar
11
12
  from .types import (
12
- QueryP,
13
- QueryR,
14
- QueryT,
15
- RateLimDecoratorWithArgsList,
16
- RateLimDecoratorWithArgsSingle,
17
- RateLimiterState,
18
- RetrievalCallableList,
19
- RetrievalCallableSingle,
13
+ P,
14
+ ProcessorCallableList,
15
+ ProcessorCallableSingle,
16
+ R,
17
+ RateLimWrapperWithArgsList,
18
+ RateLimWrapperWithArgsSingle,
19
+ T,
20
20
  )
21
- from .utils import partial_retrieval_callable, split_pos_args
21
+ from .utils import partial_processor_callable, split_pos_args
22
22
 
23
23
  logger = logging.getLogger(__name__)
24
24
 
25
25
 
26
- class RateLimiterC(Generic[QueryT, QueryR]):
26
+ @dataclass
27
+ class RateLimiterState:
28
+ next_request_time: float = 0.0
29
+
30
+
31
+ class RateLimiterC(Generic[T, R]):
27
32
  def __init__(
28
33
  self,
29
34
  rpm: float,
@@ -40,9 +45,9 @@ class RateLimiterC(Generic[QueryT, QueryR]):
40
45
 
41
46
  async def process_input(
42
47
  self,
43
- func_partial: Callable[[QueryT], Coroutine[Any, Any, QueryR]],
44
- inp: QueryT,
45
- ) -> QueryR:
48
+ func_partial: Callable[[T], Coroutine[Any, Any, R]],
49
+ inp: T,
50
+ ) -> R:
46
51
  async with self._semaphore:
47
52
  async with self._lock:
48
53
  now = monotonic()
@@ -53,11 +58,11 @@ class RateLimiterC(Generic[QueryT, QueryR]):
53
58
 
54
59
  async def process_inputs(
55
60
  self,
56
- func_partial: Callable[[QueryT], Coroutine[Any, Any, QueryR]],
57
- inputs: list[QueryT],
61
+ func_partial: Callable[[T], Coroutine[Any, Any, R]],
62
+ inputs: list[T],
58
63
  no_tqdm: bool = False,
59
- ) -> list[QueryR]:
60
- results: list[QueryR] = []
64
+ ) -> list[R]:
65
+ results: list[R] = []
61
66
  for i in tqdm(
62
67
  range(0, len(inputs), self._chunk_size),
63
68
  disable=no_tqdm,
@@ -95,33 +100,30 @@ class RateLimiterC(Generic[QueryT, QueryR]):
95
100
 
96
101
  @overload
97
102
  def limit_rate(
98
- call: RetrievalCallableSingle[QueryT, QueryP, QueryR],
99
- rate_limiter: RateLimiterC[QueryT, QueryR] | None = None,
100
- ) -> RetrievalCallableSingle[QueryT, QueryP, QueryR]: ...
103
+ call: ProcessorCallableSingle[T, P, R],
104
+ rate_limiter: RateLimiterC[T, R] | None = None,
105
+ ) -> ProcessorCallableSingle[T, P, R]: ...
101
106
 
102
107
 
103
108
  @overload
104
109
  def limit_rate(
105
110
  call: None = None,
106
- rate_limiter: RateLimiterC[QueryT, QueryR] | None = None,
107
- ) -> RateLimDecoratorWithArgsSingle[QueryT, QueryP, QueryR]: ...
111
+ rate_limiter: RateLimiterC[T, R] | None = None,
112
+ ) -> RateLimWrapperWithArgsSingle[T, P, R]: ...
108
113
 
109
114
 
110
115
  def limit_rate(
111
- call: RetrievalCallableSingle[QueryT, QueryP, QueryR] | None = None,
112
- rate_limiter: RateLimiterC[QueryT, QueryR] | None = None,
113
- ) -> (
114
- RetrievalCallableSingle[QueryT, QueryP, QueryR]
115
- | RateLimDecoratorWithArgsSingle[QueryT, QueryP, QueryR]
116
- ):
116
+ call: ProcessorCallableSingle[T, P, R] | None = None,
117
+ rate_limiter: RateLimiterC[T, R] | None = None,
118
+ ) -> ProcessorCallableSingle[T, P, R] | RateLimWrapperWithArgsSingle[T, P, R]:
117
119
  if call is None:
118
120
  return functools.partial(limit_rate, rate_limiter=rate_limiter)
119
121
 
120
122
  @functools.wraps(call) # type: ignore
121
- async def wrapper(*args: Any, **kwargs: Any) -> QueryR:
122
- inp: QueryT
123
- self_obj, inp, other_args = split_pos_args(call, args) # type: ignore
124
- call_partial = partial_retrieval_callable(call, self_obj, *other_args, **kwargs)
123
+ async def wrapper(*args: Any, **kwargs: Any) -> R:
124
+ inp: T
125
+ self_obj, inp, other_args = split_pos_args(call, args)
126
+ call_partial = partial_processor_callable(call, self_obj, *other_args, **kwargs)
125
127
 
126
128
  _rate_limiter = rate_limiter
127
129
  if _rate_limiter is None:
@@ -136,39 +138,36 @@ def limit_rate(
136
138
 
137
139
  @overload
138
140
  def limit_rate_chunked(
139
- call: RetrievalCallableSingle[QueryT, QueryP, QueryR],
140
- rate_limiter: RateLimiterC[QueryT, QueryR] | None = None,
141
+ call: ProcessorCallableList[T, P, R],
142
+ rate_limiter: RateLimiterC[T, R] | None = None,
141
143
  no_tqdm: bool | None = None,
142
- ) -> RetrievalCallableList[QueryT, QueryP, QueryR]: ...
144
+ ) -> ProcessorCallableList[T, P, R]: ...
143
145
 
144
146
 
145
147
  @overload
146
148
  def limit_rate_chunked(
147
149
  call: None = None,
148
- rate_limiter: RateLimiterC[QueryT, QueryR] | None = None,
150
+ rate_limiter: RateLimiterC[T, R] | None = None,
149
151
  no_tqdm: bool | None = None,
150
- ) -> RateLimDecoratorWithArgsList[QueryT, QueryP, QueryR]: ...
152
+ ) -> RateLimWrapperWithArgsList[T, P, R]: ...
151
153
 
152
154
 
153
155
  def limit_rate_chunked(
154
- call: RetrievalCallableSingle[QueryT, QueryP, QueryR] | None = None,
155
- rate_limiter: RateLimiterC[QueryT, QueryR] | None = None,
156
+ call: ProcessorCallableList[T, P, R] | None = None,
157
+ rate_limiter: RateLimiterC[T, R] | None = None,
156
158
  no_tqdm: bool | None = None,
157
- ) -> (
158
- RetrievalCallableList[QueryT, QueryP, QueryR]
159
- | RateLimDecoratorWithArgsList[QueryT, QueryP, QueryR]
160
- ):
159
+ ) -> ProcessorCallableList[T, P, R] | RateLimWrapperWithArgsList[T, P, R]:
161
160
  if call is None:
162
161
  return functools.partial(
163
162
  limit_rate_chunked, rate_limiter=rate_limiter, no_tqdm=no_tqdm
164
- )
163
+ ) # type: ignore
165
164
 
166
165
  @functools.wraps(call) # type: ignore
167
- async def wrapper(*args: Any, **kwargs: Any) -> list[QueryR]:
166
+ async def wrapper(*args: Any, **kwargs: Any) -> list[R]:
168
167
  assert call is not None
169
168
 
170
- self_obj, inputs, other_args = split_pos_args(call, args) # type: ignore
171
- call_partial = partial_retrieval_callable(call, self_obj, *other_args, **kwargs)
169
+ self_obj, inputs, other_args = split_pos_args(call, args)
170
+ call_partial = partial_processor_callable(call, self_obj, *other_args, **kwargs)
172
171
 
173
172
  _no_tqdm = no_tqdm
174
173
  _rate_limiter = rate_limiter
@@ -182,7 +181,9 @@ def limit_rate_chunked(
182
181
  *[call_partial(inp) for inp in inputs], no_tqdm=_no_tqdm
183
182
  )
184
183
  return await _rate_limiter.process_inputs(
185
- func_partial=call_partial, inputs=inputs, no_tqdm=_no_tqdm
184
+ func_partial=call_partial, # type: ignore
185
+ inputs=inputs,
186
+ no_tqdm=_no_tqdm,
186
187
  )
187
188
 
188
189
  return wrapper
@@ -1,57 +1,36 @@
1
1
  from collections.abc import Callable, Coroutine
2
- from dataclasses import dataclass
3
- from typing import (
4
- Any,
5
- Concatenate,
6
- ParamSpec,
7
- TypeAlias,
8
- TypeVar,
9
- )
10
-
11
- MAX_RPM = 1e10
12
-
2
+ from typing import Any, Concatenate, ParamSpec, TypeAlias, TypeVar
13
3
 
14
- @dataclass
15
- class RateLimiterState:
16
- next_request_time: float = 0.0
4
+ T = TypeVar("T")
5
+ R = TypeVar("R")
6
+ P = ParamSpec("P")
17
7
 
18
-
19
- QueryT = TypeVar("QueryT")
20
- QueryR = TypeVar("QueryR")
21
- QueryP = ParamSpec("QueryP")
22
-
23
- RetrievalFuncSingle: TypeAlias = Callable[
24
- Concatenate[QueryT, QueryP], Coroutine[Any, Any, QueryR]
25
- ]
26
- RetrievalFuncList: TypeAlias = Callable[
27
- Concatenate[list[QueryT], QueryP], Coroutine[Any, Any, list[QueryR]]
8
+ ProcessorFuncSingle: TypeAlias = Callable[Concatenate[T, P], Coroutine[Any, Any, R]]
9
+ ProcessorFuncList: TypeAlias = Callable[
10
+ Concatenate[list[T], P], Coroutine[Any, Any, list[R]]
28
11
  ]
29
12
 
30
- RetrievalMethodSingle: TypeAlias = Callable[
31
- Concatenate[Any, QueryT, QueryP], Coroutine[Any, Any, QueryR]
13
+ ProcessorMethodSingle: TypeAlias = Callable[
14
+ Concatenate[Any, T, P], Coroutine[Any, Any, R]
32
15
  ]
33
- RetrievalMethodList: TypeAlias = Callable[
34
- Concatenate[Any, list[QueryT], QueryP], Coroutine[Any, Any, list[QueryR]]
16
+ ProcessorMethodList: TypeAlias = Callable[
17
+ Concatenate[Any, list[T], P], Coroutine[Any, Any, list[R]]
35
18
  ]
36
19
 
37
- RetrievalCallableSingle: TypeAlias = (
38
- RetrievalFuncSingle[QueryT, QueryP, QueryR]
39
- | RetrievalMethodSingle[QueryT, QueryP, QueryR]
20
+ ProcessorCallableSingle: TypeAlias = (
21
+ ProcessorFuncSingle[T, P, R] | ProcessorMethodSingle[T, P, R]
40
22
  )
41
23
 
42
- RetrievalCallableList: TypeAlias = (
43
- RetrievalFuncList[QueryT, QueryP, QueryR]
44
- | RetrievalMethodList[QueryT, QueryP, QueryR]
24
+ ProcessorCallableList: TypeAlias = (
25
+ ProcessorFuncList[T, P, R] | ProcessorMethodList[T, P, R]
45
26
  )
46
27
 
47
28
 
48
- RateLimDecoratorWithArgsSingle = Callable[
49
- [RetrievalCallableSingle[QueryT, QueryP, QueryR]],
50
- RetrievalCallableSingle[QueryT, QueryP, QueryR],
29
+ RateLimWrapperWithArgsSingle = Callable[
30
+ [ProcessorCallableSingle[T, P, R]], ProcessorCallableSingle[T, P, R]
51
31
  ]
52
32
 
53
33
 
54
- RateLimDecoratorWithArgsList = Callable[
55
- [RetrievalCallableList[QueryT, QueryP, QueryR]],
56
- RetrievalCallableList[QueryT, QueryP, QueryR],
34
+ RateLimWrapperWithArgsList = Callable[
35
+ [ProcessorCallableList[T, P, R]], ProcessorCallableList[T, P, R]
57
36
  ]