grasp_agents 0.1.18__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- grasp_agents/agent_message.py +2 -2
- grasp_agents/agent_message_pool.py +6 -8
- grasp_agents/base_agent.py +15 -36
- grasp_agents/cloud_llm.py +9 -6
- grasp_agents/comm_agent.py +39 -43
- grasp_agents/generics_utils.py +159 -0
- grasp_agents/llm.py +4 -0
- grasp_agents/llm_agent.py +68 -37
- grasp_agents/llm_agent_state.py +9 -5
- grasp_agents/prompt_builder.py +45 -25
- grasp_agents/rate_limiting/rate_limiter_chunked.py +49 -48
- grasp_agents/rate_limiting/types.py +19 -40
- grasp_agents/rate_limiting/utils.py +24 -27
- grasp_agents/run_context.py +2 -15
- grasp_agents/tool_orchestrator.py +30 -8
- grasp_agents/typing/converters.py +3 -1
- grasp_agents/typing/io.py +4 -9
- grasp_agents/typing/tool.py +26 -7
- grasp_agents/utils.py +26 -39
- grasp_agents/workflow/looped_agent.py +12 -9
- grasp_agents/workflow/sequential_agent.py +9 -6
- grasp_agents/workflow/workflow_agent.py +16 -11
- {grasp_agents-0.1.18.dist-info → grasp_agents-0.2.0.dist-info}/METADATA +37 -33
- grasp_agents-0.2.0.dist-info/RECORD +45 -0
- grasp_agents-0.1.18.dist-info/RECORD +0 -44
- {grasp_agents-0.1.18.dist-info → grasp_agents-0.2.0.dist-info}/WHEEL +0 -0
- {grasp_agents-0.1.18.dist-info → grasp_agents-0.2.0.dist-info}/licenses/LICENSE.md +0 -0
grasp_agents/llm_agent.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
from collections.abc import Sequence
|
2
2
|
from pathlib import Path
|
3
|
-
from typing import Any, Generic, cast, final
|
3
|
+
from typing import Any, ClassVar, Generic, Protocol, cast, final
|
4
4
|
|
5
5
|
from pydantic import BaseModel
|
6
6
|
|
@@ -25,12 +25,11 @@ from .run_context import (
|
|
25
25
|
SystemRunArgs,
|
26
26
|
UserRunArgs,
|
27
27
|
)
|
28
|
-
from .tool_orchestrator import
|
28
|
+
from .tool_orchestrator import ExitToolCallLoopHandler, ToolOrchestrator
|
29
29
|
from .typing.content import ImageData
|
30
30
|
from .typing.converters import Converters
|
31
31
|
from .typing.io import (
|
32
32
|
AgentID,
|
33
|
-
AgentPayload,
|
34
33
|
AgentState,
|
35
34
|
InT,
|
36
35
|
LLMFormattedArgs,
|
@@ -41,13 +40,24 @@ from .typing.io import (
|
|
41
40
|
)
|
42
41
|
from .typing.message import Conversation, Message, SystemMessage
|
43
42
|
from .typing.tool import BaseTool
|
44
|
-
from .utils import get_prompt
|
43
|
+
from .utils import get_prompt, validate_obj_from_json_or_py_string
|
44
|
+
|
45
|
+
|
46
|
+
class ParseOutputHandler(Protocol[OutT, CtxT]):
|
47
|
+
def __call__(
|
48
|
+
self, *args: Any, ctx: RunContextWrapper[CtxT] | None, **kwargs: Any
|
49
|
+
) -> OutT: ...
|
45
50
|
|
46
51
|
|
47
52
|
class LLMAgent(
|
48
53
|
CommunicatingAgent[InT, OutT, LLMAgentState, CtxT],
|
49
54
|
Generic[InT, OutT, CtxT],
|
50
55
|
):
|
56
|
+
_generic_arg_to_instance_attr_map: ClassVar[dict[int, str]] = {
|
57
|
+
0: "_in_type",
|
58
|
+
1: "_out_type",
|
59
|
+
}
|
60
|
+
|
51
61
|
def __init__(
|
52
62
|
self,
|
53
63
|
agent_id: AgentID,
|
@@ -64,10 +74,6 @@ class LLMAgent(
|
|
64
74
|
sys_args_schema: type[LLMPromptArgs] = LLMPromptArgs,
|
65
75
|
# User args (static args provided via RunContextWrapper)
|
66
76
|
usr_args_schema: type[LLMPromptArgs] = LLMPromptArgs,
|
67
|
-
# Received args (args from another agent)
|
68
|
-
rcv_args_schema: type[InT] = cast("type[InT]", AgentPayload),
|
69
|
-
# Output schema
|
70
|
-
out_schema: type[OutT] = cast("type[OutT]", AgentPayload),
|
71
77
|
# Tools
|
72
78
|
tools: list[BaseTool[Any, Any, CtxT]] | None = None,
|
73
79
|
max_turns: int = 1000,
|
@@ -79,11 +85,7 @@ class LLMAgent(
|
|
79
85
|
recipient_ids: list[AgentID] | None = None,
|
80
86
|
) -> None:
|
81
87
|
super().__init__(
|
82
|
-
agent_id=agent_id,
|
83
|
-
out_schema=out_schema,
|
84
|
-
rcv_args_schema=rcv_args_schema,
|
85
|
-
message_pool=message_pool,
|
86
|
-
recipient_ids=recipient_ids,
|
88
|
+
agent_id=agent_id, message_pool=message_pool, recipient_ids=recipient_ids
|
87
89
|
)
|
88
90
|
|
89
91
|
# Agent state
|
@@ -103,28 +105,19 @@ class LLMAgent(
|
|
103
105
|
# Prompt builder
|
104
106
|
sys_prompt = get_prompt(prompt_text=sys_prompt, prompt_path=sys_prompt_path)
|
105
107
|
inp_prompt = get_prompt(prompt_text=inp_prompt, prompt_path=inp_prompt_path)
|
106
|
-
self._prompt_builder: PromptBuilder[InT, CtxT] = PromptBuilder[
|
108
|
+
self._prompt_builder: PromptBuilder[InT, CtxT] = PromptBuilder[
|
109
|
+
self.in_type, CtxT
|
110
|
+
](
|
107
111
|
agent_id=self._agent_id,
|
108
112
|
sys_prompt=sys_prompt,
|
109
113
|
inp_prompt=inp_prompt,
|
110
114
|
sys_args_schema=sys_args_schema,
|
111
115
|
usr_args_schema=usr_args_schema,
|
112
|
-
rcv_args_schema=rcv_args_schema,
|
113
116
|
)
|
114
117
|
|
115
118
|
self.no_tqdm = getattr(llm, "no_tqdm", False)
|
116
119
|
|
117
|
-
|
118
|
-
self._prompt_builder.format_sys_args_impl = self._format_sys_args
|
119
|
-
|
120
|
-
if type(self)._format_inp_args is not LLMAgent[Any, Any, Any]._format_inp_args: # noqa: SLF001
|
121
|
-
self._prompt_builder.format_inp_args_impl = self._format_inp_args
|
122
|
-
|
123
|
-
if (
|
124
|
-
type(self)._tool_call_loop_exit # noqa: SLF001
|
125
|
-
is not LLMAgent[Any, Any, Any]._tool_call_loop_exit # noqa: SLF001
|
126
|
-
):
|
127
|
-
self._tool_orchestrator.tool_call_loop_exit_impl = self._tool_call_loop_exit
|
120
|
+
self._register_overridden_handlers()
|
128
121
|
|
129
122
|
@property
|
130
123
|
def llm(self) -> LLM[LLMSettings, Converters]:
|
@@ -166,10 +159,10 @@ class LLMAgent(
|
|
166
159
|
return self._parse_output_impl(
|
167
160
|
conversation=conversation, rcv_args=rcv_args, ctx=ctx, **kwargs
|
168
161
|
)
|
169
|
-
|
170
|
-
|
171
|
-
|
172
|
-
|
162
|
+
|
163
|
+
return validate_obj_from_json_or_py_string(
|
164
|
+
str(conversation[-1].content), self._out_type_adapter
|
165
|
+
)
|
173
166
|
|
174
167
|
@final
|
175
168
|
async def run(
|
@@ -177,7 +170,7 @@ class LLMAgent(
|
|
177
170
|
inp_items: LLMPrompt | list[str | ImageData] | None = None,
|
178
171
|
*,
|
179
172
|
ctx: RunContextWrapper[CtxT] | None = None,
|
180
|
-
rcv_message: AgentMessage[InT,
|
173
|
+
rcv_message: AgentMessage[InT, AgentState] | None = None,
|
181
174
|
entry_point: bool = False,
|
182
175
|
forbid_state_change: bool = False,
|
183
176
|
**gen_kwargs: Any, # noqa: ARG002
|
@@ -247,7 +240,7 @@ class LLMAgent(
|
|
247
240
|
batch_size = state.message_history.batch_size
|
248
241
|
rcv_args_batch = rcv_message.payloads if rcv_message else batch_size * [None]
|
249
242
|
val_output_batch = [
|
250
|
-
self.
|
243
|
+
self._out_type_adapter.validate_python(
|
251
244
|
self._parse_output(conversation=conv, rcv_args=rcv_args, ctx=ctx)
|
252
245
|
)
|
253
246
|
for conv, rcv_args in zip(
|
@@ -276,7 +269,7 @@ class LLMAgent(
|
|
276
269
|
)
|
277
270
|
ctx.interaction_history.append(
|
278
271
|
cast(
|
279
|
-
"InteractionRecord[
|
272
|
+
"InteractionRecord[Any, Any, AgentState]",
|
280
273
|
interaction_record,
|
281
274
|
)
|
282
275
|
)
|
@@ -330,6 +323,13 @@ class LLMAgent(
|
|
330
323
|
|
331
324
|
return func
|
332
325
|
|
326
|
+
def parse_output_handler(
|
327
|
+
self, func: ParseOutputHandler[OutT, CtxT]
|
328
|
+
) -> ParseOutputHandler[OutT, CtxT]:
|
329
|
+
self._parse_output_impl = func
|
330
|
+
|
331
|
+
return func
|
332
|
+
|
333
333
|
def make_custom_agent_state_handler(
|
334
334
|
self, func: MakeCustomAgentState
|
335
335
|
) -> MakeCustomAgentState:
|
@@ -337,15 +337,35 @@ class LLMAgent(
|
|
337
337
|
|
338
338
|
return func
|
339
339
|
|
340
|
-
def
|
341
|
-
self, func:
|
342
|
-
) ->
|
343
|
-
self._tool_orchestrator.
|
340
|
+
def exit_tool_call_loop_handler(
|
341
|
+
self, func: ExitToolCallLoopHandler[CtxT]
|
342
|
+
) -> ExitToolCallLoopHandler[CtxT]:
|
343
|
+
self._tool_orchestrator.exit_tool_call_loop_impl = func
|
344
344
|
|
345
345
|
return func
|
346
346
|
|
347
347
|
# -- Override these methods in subclasses if needed --
|
348
348
|
|
349
|
+
def _register_overridden_handlers(self) -> None:
|
350
|
+
cur_cls = type(self)
|
351
|
+
base_cls = LLMAgent[Any, Any, Any]
|
352
|
+
|
353
|
+
if cur_cls._format_sys_args is not base_cls._format_sys_args: # noqa: SLF001
|
354
|
+
self._prompt_builder.format_sys_args_impl = self._format_sys_args
|
355
|
+
|
356
|
+
if cur_cls._format_inp_args is not base_cls._format_inp_args: # noqa: SLF001
|
357
|
+
self._prompt_builder.format_inp_args_impl = self._format_inp_args
|
358
|
+
|
359
|
+
if cur_cls._make_custom_agent_state is not base_cls._make_custom_agent_state: # noqa: SLF001
|
360
|
+
self._make_custom_agent_state_impl = self._make_custom_agent_state
|
361
|
+
|
362
|
+
if (
|
363
|
+
cur_cls._tool_call_loop_exit is not base_cls._tool_call_loop_exit # noqa: SLF001
|
364
|
+
):
|
365
|
+
self._tool_orchestrator.exit_tool_call_loop_impl = self._tool_call_loop_exit
|
366
|
+
|
367
|
+
self._parse_output_impl: ParseOutputHandler[OutT, CtxT] | None = None
|
368
|
+
|
349
369
|
def _format_sys_args(
|
350
370
|
self,
|
351
371
|
sys_args: LLMPromptArgs,
|
@@ -367,6 +387,17 @@ class LLMAgent(
|
|
367
387
|
"LLMAgent._format_inp_args must be overridden by a subclass"
|
368
388
|
)
|
369
389
|
|
390
|
+
def _make_custom_agent_state(
|
391
|
+
self,
|
392
|
+
cur_state: LLMAgentState | None,
|
393
|
+
rcv_state: AgentState | None,
|
394
|
+
sys_prompt: LLMPrompt | None,
|
395
|
+
ctx: RunContextWrapper[Any] | None,
|
396
|
+
) -> LLMAgentState:
|
397
|
+
raise NotImplementedError(
|
398
|
+
"LLMAgent._make_custom_agent_state_handler must be overridden by a subclass"
|
399
|
+
)
|
400
|
+
|
370
401
|
def _tool_call_loop_exit(
|
371
402
|
self,
|
372
403
|
conversation: Conversation,
|
grasp_agents/llm_agent_state.py
CHANGED
@@ -14,7 +14,7 @@ class MakeCustomAgentState(Protocol):
|
|
14
14
|
def __call__(
|
15
15
|
self,
|
16
16
|
cur_state: Optional["LLMAgentState"],
|
17
|
-
|
17
|
+
rcv_state: AgentState | None,
|
18
18
|
sys_prompt: LLMPrompt | None,
|
19
19
|
ctx: RunContextWrapper[Any] | None,
|
20
20
|
) -> "LLMAgentState": ...
|
@@ -31,7 +31,7 @@ class LLMAgentState(AgentState):
|
|
31
31
|
def from_cur_and_rcv_states(
|
32
32
|
cls,
|
33
33
|
cur_state: Optional["LLMAgentState"] = None,
|
34
|
-
rcv_state: Optional["
|
34
|
+
rcv_state: Optional["AgentState"] = None,
|
35
35
|
sys_prompt: LLMPrompt | None = None,
|
36
36
|
strategy: SetAgentStateStrategy = "from_sender",
|
37
37
|
make_custom_state_impl: MakeCustomAgentState | None = None,
|
@@ -48,7 +48,11 @@ class LLMAgentState(AgentState):
|
|
48
48
|
upd_mh.reset(sys_prompt)
|
49
49
|
|
50
50
|
elif strategy == "from_sender":
|
51
|
-
rcv_mh =
|
51
|
+
rcv_mh = (
|
52
|
+
rcv_state.message_history
|
53
|
+
if rcv_state and isinstance(rcv_state, "LLMAgentState")
|
54
|
+
else None
|
55
|
+
)
|
52
56
|
if rcv_mh:
|
53
57
|
upd_mh = deepcopy(rcv_mh)
|
54
58
|
else:
|
@@ -60,12 +64,12 @@ class LLMAgentState(AgentState):
|
|
60
64
|
)
|
61
65
|
return make_custom_state_impl(
|
62
66
|
cur_state=cur_state,
|
63
|
-
|
67
|
+
rcv_state=rcv_state,
|
64
68
|
sys_prompt=sys_prompt,
|
65
69
|
ctx=ctx,
|
66
70
|
)
|
67
71
|
|
68
|
-
return cls(message_history=upd_mh)
|
72
|
+
return cls.model_construct(message_history=upd_mh)
|
69
73
|
|
70
74
|
def __repr__(self) -> str:
|
71
75
|
return f"Message History: {len(self.message_history)}"
|
grasp_agents/prompt_builder.py
CHANGED
@@ -1,11 +1,13 @@
|
|
1
1
|
from collections.abc import Sequence
|
2
2
|
from copy import deepcopy
|
3
|
-
from typing import Generic, Protocol
|
3
|
+
from typing import ClassVar, Generic, Protocol
|
4
4
|
|
5
|
+
from pydantic import BaseModel, TypeAdapter
|
6
|
+
|
7
|
+
from .generics_utils import AutoInstanceAttributesMixin
|
5
8
|
from .run_context import CtxT, RunContextWrapper, UserRunArgs
|
6
9
|
from .typing.content import ImageData
|
7
10
|
from .typing.io import (
|
8
|
-
AgentPayload,
|
9
11
|
InT,
|
10
12
|
LLMFormattedArgs,
|
11
13
|
LLMFormattedSystemArgs,
|
@@ -15,6 +17,10 @@ from .typing.io import (
|
|
15
17
|
from .typing.message import UserMessage
|
16
18
|
|
17
19
|
|
20
|
+
class DummySchema(BaseModel):
|
21
|
+
pass
|
22
|
+
|
23
|
+
|
18
24
|
class FormatSystemArgsHandler(Protocol[CtxT]):
|
19
25
|
def __call__(
|
20
26
|
self,
|
@@ -33,7 +39,9 @@ class FormatInputArgsHandler(Protocol[InT, CtxT]):
|
|
33
39
|
) -> LLMFormattedArgs: ...
|
34
40
|
|
35
41
|
|
36
|
-
class PromptBuilder(Generic[InT, CtxT]):
|
42
|
+
class PromptBuilder(AutoInstanceAttributesMixin, Generic[InT, CtxT]):
|
43
|
+
_generic_arg_to_instance_attr_map: ClassVar[dict[int, str]] = {0: "_in_type"}
|
44
|
+
|
37
45
|
def __init__(
|
38
46
|
self,
|
39
47
|
agent_id: str,
|
@@ -41,17 +49,20 @@ class PromptBuilder(Generic[InT, CtxT]):
|
|
41
49
|
inp_prompt: LLMPrompt | None,
|
42
50
|
sys_args_schema: type[LLMPromptArgs],
|
43
51
|
usr_args_schema: type[LLMPromptArgs],
|
44
|
-
rcv_args_schema: type[InT],
|
45
52
|
):
|
53
|
+
self._in_type: type[InT]
|
54
|
+
super().__init__()
|
55
|
+
|
46
56
|
self._agent_id = agent_id
|
47
57
|
self.sys_prompt = sys_prompt
|
48
58
|
self.inp_prompt = inp_prompt
|
49
59
|
self.sys_args_schema = sys_args_schema
|
50
60
|
self.usr_args_schema = usr_args_schema
|
51
|
-
self.rcv_args_schema = rcv_args_schema
|
52
61
|
self.format_sys_args_impl: FormatSystemArgsHandler[CtxT] | None = None
|
53
62
|
self.format_inp_args_impl: FormatInputArgsHandler[InT, CtxT] | None = None
|
54
63
|
|
64
|
+
self._rcv_args_type_adapter: TypeAdapter[InT] = TypeAdapter(self._in_type)
|
65
|
+
|
55
66
|
def _format_sys_args(
|
56
67
|
self,
|
57
68
|
sys_args: LLMPromptArgs,
|
@@ -73,9 +84,18 @@ class PromptBuilder(Generic[InT, CtxT]):
|
|
73
84
|
usr_args=usr_args, rcv_args=rcv_args, ctx=ctx
|
74
85
|
)
|
75
86
|
|
76
|
-
|
77
|
-
|
78
|
-
|
87
|
+
if not isinstance(rcv_args, BaseModel) and rcv_args is not None:
|
88
|
+
raise TypeError(
|
89
|
+
"Cannot apply default formatting to non-BaseModel received arguments."
|
90
|
+
)
|
91
|
+
|
92
|
+
usr_args_ = usr_args
|
93
|
+
rcv_args_ = DummySchema() if rcv_args is None else rcv_args
|
94
|
+
|
95
|
+
usr_args_dump = usr_args_.model_dump(exclude_unset=True)
|
96
|
+
rcv_args_dump = rcv_args_.model_dump(exclude={"selected_recipient_ids"})
|
97
|
+
|
98
|
+
return usr_args_dump | rcv_args_dump
|
79
99
|
|
80
100
|
def make_sys_prompt(
|
81
101
|
self,
|
@@ -100,20 +120,18 @@ class PromptBuilder(Generic[InT, CtxT]):
|
|
100
120
|
def _usr_messages_from_rcv_args(
|
101
121
|
self, rcv_args_batch: Sequence[InT]
|
102
122
|
) -> list[UserMessage]:
|
103
|
-
val_rcv_args_batch = [
|
104
|
-
self.rcv_args_schema.model_validate(rcv) for rcv in rcv_args_batch
|
105
|
-
]
|
106
|
-
|
107
123
|
return [
|
108
124
|
UserMessage.from_text(
|
109
|
-
|
125
|
+
self._rcv_args_type_adapter.dump_json(
|
126
|
+
rcv,
|
110
127
|
exclude_unset=True,
|
111
128
|
indent=2,
|
112
129
|
exclude={"selected_recipient_ids"},
|
113
|
-
|
130
|
+
warnings="error",
|
131
|
+
).decode("utf-8"),
|
114
132
|
model_id=self._agent_id,
|
115
133
|
)
|
116
|
-
for rcv in
|
134
|
+
for rcv in rcv_args_batch
|
117
135
|
]
|
118
136
|
|
119
137
|
def _usr_messages_from_prompt_template(
|
@@ -123,17 +141,19 @@ class PromptBuilder(Generic[InT, CtxT]):
|
|
123
141
|
rcv_args_batch: Sequence[InT] | None = None,
|
124
142
|
ctx: RunContextWrapper[CtxT] | None = None,
|
125
143
|
) -> Sequence[UserMessage]:
|
126
|
-
|
127
|
-
|
128
|
-
|
144
|
+
usr_args_batch_, rcv_args_batch_ = self._make_batched(usr_args, rcv_args_batch)
|
145
|
+
|
146
|
+
val_usr_args_batch_ = [
|
147
|
+
self.usr_args_schema.model_validate(u) for u in usr_args_batch_
|
129
148
|
]
|
130
|
-
|
131
|
-
self.
|
149
|
+
val_rcv_args_batch_ = [
|
150
|
+
self._rcv_args_type_adapter.validate_python(rcv) for rcv in rcv_args_batch_
|
132
151
|
]
|
152
|
+
|
133
153
|
formatted_inp_args_batch = [
|
134
154
|
self._format_inp_args(usr_args=val_usr_args, rcv_args=val_rcv_args, ctx=ctx)
|
135
155
|
for val_usr_args, val_rcv_args in zip(
|
136
|
-
|
156
|
+
val_usr_args_batch_, val_rcv_args_batch_, strict=False
|
137
157
|
)
|
138
158
|
]
|
139
159
|
|
@@ -187,11 +207,11 @@ class PromptBuilder(Generic[InT, CtxT]):
|
|
187
207
|
self,
|
188
208
|
usr_args: UserRunArgs | None = None,
|
189
209
|
rcv_args_batch: Sequence[InT] | None = None,
|
190
|
-
) -> tuple[Sequence[LLMPromptArgs], Sequence[InT]]:
|
210
|
+
) -> tuple[Sequence[LLMPromptArgs | DummySchema], Sequence[InT | DummySchema]]:
|
191
211
|
usr_args_batch_ = (
|
192
|
-
usr_args if isinstance(usr_args, list) else [usr_args or
|
212
|
+
usr_args if isinstance(usr_args, list) else [usr_args or DummySchema()]
|
193
213
|
)
|
194
|
-
rcv_args_batch_ = rcv_args_batch or [
|
214
|
+
rcv_args_batch_ = rcv_args_batch or [DummySchema()]
|
195
215
|
|
196
216
|
# Broadcast singleton → match lengths
|
197
217
|
if len(usr_args_batch_) == 1 and len(rcv_args_batch_) > 1:
|
@@ -201,4 +221,4 @@ class PromptBuilder(Generic[InT, CtxT]):
|
|
201
221
|
if len(usr_args_batch_) != len(rcv_args_batch_):
|
202
222
|
raise ValueError("User args and received args must have the same length")
|
203
223
|
|
204
|
-
return usr_args_batch_, rcv_args_batch_
|
224
|
+
return usr_args_batch_, rcv_args_batch_
|
@@ -2,6 +2,7 @@ import asyncio
|
|
2
2
|
import functools
|
3
3
|
import logging
|
4
4
|
from collections.abc import Callable, Coroutine
|
5
|
+
from dataclasses import dataclass
|
5
6
|
from time import monotonic
|
6
7
|
from typing import Any, Generic, overload
|
7
8
|
|
@@ -9,21 +10,25 @@ from tqdm.autonotebook import tqdm
|
|
9
10
|
|
10
11
|
from ..utils import asyncio_gather_with_pbar
|
11
12
|
from .types import (
|
12
|
-
|
13
|
-
|
14
|
-
|
15
|
-
|
16
|
-
|
17
|
-
|
18
|
-
|
19
|
-
RetrievalCallableSingle,
|
13
|
+
P,
|
14
|
+
ProcessorCallableList,
|
15
|
+
ProcessorCallableSingle,
|
16
|
+
R,
|
17
|
+
RateLimWrapperWithArgsList,
|
18
|
+
RateLimWrapperWithArgsSingle,
|
19
|
+
T,
|
20
20
|
)
|
21
|
-
from .utils import
|
21
|
+
from .utils import partial_processor_callable, split_pos_args
|
22
22
|
|
23
23
|
logger = logging.getLogger(__name__)
|
24
24
|
|
25
25
|
|
26
|
-
|
26
|
+
@dataclass
|
27
|
+
class RateLimiterState:
|
28
|
+
next_request_time: float = 0.0
|
29
|
+
|
30
|
+
|
31
|
+
class RateLimiterC(Generic[T, R]):
|
27
32
|
def __init__(
|
28
33
|
self,
|
29
34
|
rpm: float,
|
@@ -40,9 +45,9 @@ class RateLimiterC(Generic[QueryT, QueryR]):
|
|
40
45
|
|
41
46
|
async def process_input(
|
42
47
|
self,
|
43
|
-
func_partial: Callable[[
|
44
|
-
inp:
|
45
|
-
) ->
|
48
|
+
func_partial: Callable[[T], Coroutine[Any, Any, R]],
|
49
|
+
inp: T,
|
50
|
+
) -> R:
|
46
51
|
async with self._semaphore:
|
47
52
|
async with self._lock:
|
48
53
|
now = monotonic()
|
@@ -53,11 +58,11 @@ class RateLimiterC(Generic[QueryT, QueryR]):
|
|
53
58
|
|
54
59
|
async def process_inputs(
|
55
60
|
self,
|
56
|
-
func_partial: Callable[[
|
57
|
-
inputs: list[
|
61
|
+
func_partial: Callable[[T], Coroutine[Any, Any, R]],
|
62
|
+
inputs: list[T],
|
58
63
|
no_tqdm: bool = False,
|
59
|
-
) -> list[
|
60
|
-
results: list[
|
64
|
+
) -> list[R]:
|
65
|
+
results: list[R] = []
|
61
66
|
for i in tqdm(
|
62
67
|
range(0, len(inputs), self._chunk_size),
|
63
68
|
disable=no_tqdm,
|
@@ -95,33 +100,30 @@ class RateLimiterC(Generic[QueryT, QueryR]):
|
|
95
100
|
|
96
101
|
@overload
|
97
102
|
def limit_rate(
|
98
|
-
call:
|
99
|
-
rate_limiter: RateLimiterC[
|
100
|
-
) ->
|
103
|
+
call: ProcessorCallableSingle[T, P, R],
|
104
|
+
rate_limiter: RateLimiterC[T, R] | None = None,
|
105
|
+
) -> ProcessorCallableSingle[T, P, R]: ...
|
101
106
|
|
102
107
|
|
103
108
|
@overload
|
104
109
|
def limit_rate(
|
105
110
|
call: None = None,
|
106
|
-
rate_limiter: RateLimiterC[
|
107
|
-
) ->
|
111
|
+
rate_limiter: RateLimiterC[T, R] | None = None,
|
112
|
+
) -> RateLimWrapperWithArgsSingle[T, P, R]: ...
|
108
113
|
|
109
114
|
|
110
115
|
def limit_rate(
|
111
|
-
call:
|
112
|
-
rate_limiter: RateLimiterC[
|
113
|
-
) ->
|
114
|
-
RetrievalCallableSingle[QueryT, QueryP, QueryR]
|
115
|
-
| RateLimDecoratorWithArgsSingle[QueryT, QueryP, QueryR]
|
116
|
-
):
|
116
|
+
call: ProcessorCallableSingle[T, P, R] | None = None,
|
117
|
+
rate_limiter: RateLimiterC[T, R] | None = None,
|
118
|
+
) -> ProcessorCallableSingle[T, P, R] | RateLimWrapperWithArgsSingle[T, P, R]:
|
117
119
|
if call is None:
|
118
120
|
return functools.partial(limit_rate, rate_limiter=rate_limiter)
|
119
121
|
|
120
122
|
@functools.wraps(call) # type: ignore
|
121
|
-
async def wrapper(*args: Any, **kwargs: Any) ->
|
122
|
-
inp:
|
123
|
-
self_obj, inp, other_args = split_pos_args(call, args)
|
124
|
-
call_partial =
|
123
|
+
async def wrapper(*args: Any, **kwargs: Any) -> R:
|
124
|
+
inp: T
|
125
|
+
self_obj, inp, other_args = split_pos_args(call, args)
|
126
|
+
call_partial = partial_processor_callable(call, self_obj, *other_args, **kwargs)
|
125
127
|
|
126
128
|
_rate_limiter = rate_limiter
|
127
129
|
if _rate_limiter is None:
|
@@ -136,39 +138,36 @@ def limit_rate(
|
|
136
138
|
|
137
139
|
@overload
|
138
140
|
def limit_rate_chunked(
|
139
|
-
call:
|
140
|
-
rate_limiter: RateLimiterC[
|
141
|
+
call: ProcessorCallableList[T, P, R],
|
142
|
+
rate_limiter: RateLimiterC[T, R] | None = None,
|
141
143
|
no_tqdm: bool | None = None,
|
142
|
-
) ->
|
144
|
+
) -> ProcessorCallableList[T, P, R]: ...
|
143
145
|
|
144
146
|
|
145
147
|
@overload
|
146
148
|
def limit_rate_chunked(
|
147
149
|
call: None = None,
|
148
|
-
rate_limiter: RateLimiterC[
|
150
|
+
rate_limiter: RateLimiterC[T, R] | None = None,
|
149
151
|
no_tqdm: bool | None = None,
|
150
|
-
) ->
|
152
|
+
) -> RateLimWrapperWithArgsList[T, P, R]: ...
|
151
153
|
|
152
154
|
|
153
155
|
def limit_rate_chunked(
|
154
|
-
call:
|
155
|
-
rate_limiter: RateLimiterC[
|
156
|
+
call: ProcessorCallableList[T, P, R] | None = None,
|
157
|
+
rate_limiter: RateLimiterC[T, R] | None = None,
|
156
158
|
no_tqdm: bool | None = None,
|
157
|
-
) ->
|
158
|
-
RetrievalCallableList[QueryT, QueryP, QueryR]
|
159
|
-
| RateLimDecoratorWithArgsList[QueryT, QueryP, QueryR]
|
160
|
-
):
|
159
|
+
) -> ProcessorCallableList[T, P, R] | RateLimWrapperWithArgsList[T, P, R]:
|
161
160
|
if call is None:
|
162
161
|
return functools.partial(
|
163
162
|
limit_rate_chunked, rate_limiter=rate_limiter, no_tqdm=no_tqdm
|
164
|
-
)
|
163
|
+
) # type: ignore
|
165
164
|
|
166
165
|
@functools.wraps(call) # type: ignore
|
167
|
-
async def wrapper(*args: Any, **kwargs: Any) -> list[
|
166
|
+
async def wrapper(*args: Any, **kwargs: Any) -> list[R]:
|
168
167
|
assert call is not None
|
169
168
|
|
170
|
-
self_obj, inputs, other_args = split_pos_args(call, args)
|
171
|
-
call_partial =
|
169
|
+
self_obj, inputs, other_args = split_pos_args(call, args)
|
170
|
+
call_partial = partial_processor_callable(call, self_obj, *other_args, **kwargs)
|
172
171
|
|
173
172
|
_no_tqdm = no_tqdm
|
174
173
|
_rate_limiter = rate_limiter
|
@@ -182,7 +181,9 @@ def limit_rate_chunked(
|
|
182
181
|
*[call_partial(inp) for inp in inputs], no_tqdm=_no_tqdm
|
183
182
|
)
|
184
183
|
return await _rate_limiter.process_inputs(
|
185
|
-
func_partial=call_partial,
|
184
|
+
func_partial=call_partial, # type: ignore
|
185
|
+
inputs=inputs,
|
186
|
+
no_tqdm=_no_tqdm,
|
186
187
|
)
|
187
188
|
|
188
189
|
return wrapper
|
@@ -1,57 +1,36 @@
|
|
1
1
|
from collections.abc import Callable, Coroutine
|
2
|
-
from
|
3
|
-
from typing import (
|
4
|
-
Any,
|
5
|
-
Concatenate,
|
6
|
-
ParamSpec,
|
7
|
-
TypeAlias,
|
8
|
-
TypeVar,
|
9
|
-
)
|
10
|
-
|
11
|
-
MAX_RPM = 1e10
|
12
|
-
|
2
|
+
from typing import Any, Concatenate, ParamSpec, TypeAlias, TypeVar
|
13
3
|
|
14
|
-
|
15
|
-
|
16
|
-
|
4
|
+
T = TypeVar("T")
|
5
|
+
R = TypeVar("R")
|
6
|
+
P = ParamSpec("P")
|
17
7
|
|
18
|
-
|
19
|
-
|
20
|
-
|
21
|
-
QueryP = ParamSpec("QueryP")
|
22
|
-
|
23
|
-
RetrievalFuncSingle: TypeAlias = Callable[
|
24
|
-
Concatenate[QueryT, QueryP], Coroutine[Any, Any, QueryR]
|
25
|
-
]
|
26
|
-
RetrievalFuncList: TypeAlias = Callable[
|
27
|
-
Concatenate[list[QueryT], QueryP], Coroutine[Any, Any, list[QueryR]]
|
8
|
+
ProcessorFuncSingle: TypeAlias = Callable[Concatenate[T, P], Coroutine[Any, Any, R]]
|
9
|
+
ProcessorFuncList: TypeAlias = Callable[
|
10
|
+
Concatenate[list[T], P], Coroutine[Any, Any, list[R]]
|
28
11
|
]
|
29
12
|
|
30
|
-
|
31
|
-
Concatenate[Any,
|
13
|
+
ProcessorMethodSingle: TypeAlias = Callable[
|
14
|
+
Concatenate[Any, T, P], Coroutine[Any, Any, R]
|
32
15
|
]
|
33
|
-
|
34
|
-
Concatenate[Any, list[
|
16
|
+
ProcessorMethodList: TypeAlias = Callable[
|
17
|
+
Concatenate[Any, list[T], P], Coroutine[Any, Any, list[R]]
|
35
18
|
]
|
36
19
|
|
37
|
-
|
38
|
-
|
39
|
-
| RetrievalMethodSingle[QueryT, QueryP, QueryR]
|
20
|
+
ProcessorCallableSingle: TypeAlias = (
|
21
|
+
ProcessorFuncSingle[T, P, R] | ProcessorMethodSingle[T, P, R]
|
40
22
|
)
|
41
23
|
|
42
|
-
|
43
|
-
|
44
|
-
| RetrievalMethodList[QueryT, QueryP, QueryR]
|
24
|
+
ProcessorCallableList: TypeAlias = (
|
25
|
+
ProcessorFuncList[T, P, R] | ProcessorMethodList[T, P, R]
|
45
26
|
)
|
46
27
|
|
47
28
|
|
48
|
-
|
49
|
-
[
|
50
|
-
RetrievalCallableSingle[QueryT, QueryP, QueryR],
|
29
|
+
RateLimWrapperWithArgsSingle = Callable[
|
30
|
+
[ProcessorCallableSingle[T, P, R]], ProcessorCallableSingle[T, P, R]
|
51
31
|
]
|
52
32
|
|
53
33
|
|
54
|
-
|
55
|
-
[
|
56
|
-
RetrievalCallableList[QueryT, QueryP, QueryR],
|
34
|
+
RateLimWrapperWithArgsList = Callable[
|
35
|
+
[ProcessorCallableList[T, P, R]], ProcessorCallableList[T, P, R]
|
57
36
|
]
|