grasp_agents 0.1.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- grasp_agents/agent_message.py +28 -0
- grasp_agents/agent_message_pool.py +94 -0
- grasp_agents/base_agent.py +72 -0
- grasp_agents/cloud_llm.py +353 -0
- grasp_agents/comm_agent.py +230 -0
- grasp_agents/costs_dict.yaml +122 -0
- grasp_agents/data_retrieval/__init__.py +7 -0
- grasp_agents/data_retrieval/rate_limiter_chunked.py +195 -0
- grasp_agents/data_retrieval/types.py +57 -0
- grasp_agents/data_retrieval/utils.py +57 -0
- grasp_agents/grasp_logging.py +36 -0
- grasp_agents/http_client.py +24 -0
- grasp_agents/llm.py +106 -0
- grasp_agents/llm_agent.py +361 -0
- grasp_agents/llm_agent_state.py +73 -0
- grasp_agents/memory.py +150 -0
- grasp_agents/openai/__init__.py +83 -0
- grasp_agents/openai/completion_converters.py +49 -0
- grasp_agents/openai/content_converters.py +80 -0
- grasp_agents/openai/converters.py +170 -0
- grasp_agents/openai/message_converters.py +155 -0
- grasp_agents/openai/openai_llm.py +179 -0
- grasp_agents/openai/tool_converters.py +37 -0
- grasp_agents/printer.py +156 -0
- grasp_agents/prompt_builder.py +204 -0
- grasp_agents/run_context.py +90 -0
- grasp_agents/tool_orchestrator.py +181 -0
- grasp_agents/typing/__init__.py +0 -0
- grasp_agents/typing/completion.py +30 -0
- grasp_agents/typing/content.py +116 -0
- grasp_agents/typing/converters.py +118 -0
- grasp_agents/typing/io.py +32 -0
- grasp_agents/typing/message.py +130 -0
- grasp_agents/typing/tool.py +52 -0
- grasp_agents/usage_tracker.py +99 -0
- grasp_agents/utils.py +151 -0
- grasp_agents/workflow/__init__.py +0 -0
- grasp_agents/workflow/looped_agent.py +113 -0
- grasp_agents/workflow/sequential_agent.py +57 -0
- grasp_agents/workflow/workflow_agent.py +69 -0
- grasp_agents-0.1.5.dist-info/METADATA +14 -0
- grasp_agents-0.1.5.dist-info/RECORD +44 -0
- grasp_agents-0.1.5.dist-info/WHEEL +4 -0
- grasp_agents-0.1.5.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,230 @@
|
|
1
|
+
import logging
|
2
|
+
from abc import abstractmethod
|
3
|
+
from collections.abc import Sequence
|
4
|
+
from typing import Any, Generic, Protocol, TypeVar, cast, final
|
5
|
+
|
6
|
+
from pydantic import BaseModel
|
7
|
+
|
8
|
+
from .agent_message import AgentMessage
|
9
|
+
from .agent_message_pool import AgentMessagePool
|
10
|
+
from .base_agent import BaseAgent
|
11
|
+
from .run_context import CtxT, RunContextWrapper
|
12
|
+
from .typing.io import AgentID, AgentPayload, AgentState, InT, OutT, StateT
|
13
|
+
from .typing.tool import BaseTool
|
14
|
+
|
15
|
+
logger = logging.getLogger(__name__)
|
16
|
+
|
17
|
+
# Type variables used only by the `ExitHandler` protocol below. They are
# contravariant because the protocol accepts them solely as call arguments.
_EH_OutT = TypeVar(  # noqa: PLC0105
    "_EH_OutT",
    bound=AgentPayload,
    contravariant=True,
)
_EH_StateT = TypeVar(  # noqa: PLC0105
    "_EH_StateT",
    bound=AgentState,
    contravariant=True,
)
|
19
|
+
|
20
|
+
|
21
|
+
class ExitHandler(Protocol[_EH_OutT, _EH_StateT, CtxT]):
    """Structural type for a callback that decides whether an agent should exit.

    `CommunicatingAgent` invokes a registered handler after each handled
    message; a truthy result makes the agent stop the whole message pool
    instead of forwarding its output.
    """

    def __call__(
        self,
        output_message: AgentMessage[_EH_OutT, _EH_StateT],
        agent_state: _EH_StateT,
        ctx: RunContextWrapper[CtxT] | None,
    ) -> bool:
        """Return ``True`` to request an exit, ``False`` to continue.

        Args:
            output_message: The message the agent has just produced.
            agent_state: The agent's current state.
            ctx: The run context, or ``None`` when no context is in use.
        """
        ...
|
28
|
+
|
29
|
+
|
30
|
+
class CommunicatingAgent(
    BaseAgent[OutT, StateT, CtxT], Generic[InT, OutT, StateT, CtxT]
):
    """Agent that sends and receives `AgentMessage` objects via a message pool.

    Subclasses implement `run`. This base class adds recipient validation
    (static or dynamic routing), posting to the pool, listening for incoming
    messages, an optional exit handler, and exposure of the agent as a tool.
    """

    def __init__(
        self,
        agent_id: AgentID,
        *,
        out_schema: type[OutT] = AgentPayload,
        rcv_args_schema: type[InT] = AgentPayload,
        recipient_ids: Sequence[AgentID] | None = None,
        message_pool: AgentMessagePool[CtxT] | None = None,
        dynamic_routing: bool = False,
        **kwargs: Any,
    ) -> None:
        """Create the agent.

        Args:
            agent_id: Identifier of this agent.
            out_schema: Payload type this agent produces.
            rcv_args_schema: Payload type this agent accepts.
            recipient_ids: Agents this one may send to; defaults to none.
            message_pool: Pool used for messaging; a fresh pool is created
                when none (or a falsy value) is given.
            dynamic_routing: If true, recipients are taken from the payloads
                (`selected_recipient_ids`) rather than from `recipient_ids`.
            **kwargs: Forwarded to the base agent constructor.
        """
        super().__init__(agent_id=agent_id, out_schema=out_schema, **kwargs)
        self._message_pool = message_pool or AgentMessagePool()

        self._dynamic_routing = dynamic_routing

        self._is_listening = False
        self._exit_impl: ExitHandler[OutT, StateT, CtxT] | None = None

        self._rcv_args_schema = rcv_args_schema
        self.recipient_ids = recipient_ids or []

    @property
    def rcv_args_schema(self) -> type[InT]:  # type: ignore[reportInvalidTypeVarUse]
        """Payload type accepted by this agent."""
        return self._rcv_args_schema

    @property
    def dynamic_routing(self) -> bool:
        """Whether recipients are selected per message by the payloads."""
        return self._dynamic_routing

    def _parse_output(
        self,
        *args: Any,
        rcv_args: InT | None = None,
        ctx: RunContextWrapper[CtxT] | None = None,
        **kwargs: Any,
    ) -> OutT:
        """Build the output payload.

        Delegates to `self._parse_output_impl` when it is set; otherwise
        returns `self._out_schema()` constructed without arguments.
        NOTE(review): both attributes are expected to come from the base
        class, which is not visible here.
        """
        if self._parse_output_impl:
            return self._parse_output_impl(*args, rcv_args=rcv_args, ctx=ctx, **kwargs)

        return self._out_schema()

    def _validate_dynamic_routing(self, payloads: Sequence[OutT]) -> Sequence[AgentID]:
        """Check payload-selected recipients and return them.

        All payloads must carry the same set of `selected_recipient_ids`, and
        every selected ID must be among `self.recipient_ids`.

        Raises:
            AssertionError: If any check fails. The checks are raised
                explicitly rather than written as `assert` statements so they
                are not stripped under `python -O`; the exception type is
                kept for backward compatibility.
            IndexError: If `payloads` is empty.
        """
        if any(p.selected_recipient_ids is None for p in payloads):
            raise AssertionError(
                "Dynamic routing is enabled, but some payloads have no recipient IDs"
            )

        selected_recipient_ids_per_payload = [
            set(p.selected_recipient_ids or []) for p in payloads
        ]
        if any(
            x != selected_recipient_ids_per_payload[0]
            for x in selected_recipient_ids_per_payload
        ):
            raise AssertionError(
                "All payloads must have the same recipient IDs for dynamic routing"
            )

        selected_recipient_ids = payloads[0].selected_recipient_ids
        # Type narrowing only: the first check above already guarantees this.
        assert selected_recipient_ids is not None

        if any(rid not in self.recipient_ids for rid in selected_recipient_ids):
            raise AssertionError(
                "Dynamic routing is enabled, but recipient IDs are not in "
                "the allowed agent's recipient IDs"
            )

        return selected_recipient_ids

    def _validate_static_routing(self, payloads: Sequence[OutT]) -> Sequence[AgentID]:
        """Check that no payload selects recipients; return `self.recipient_ids`.

        Raises:
            AssertionError: If any payload has `selected_recipient_ids` set.
                Raised explicitly so the check survives `python -O`.
        """
        if any(p.selected_recipient_ids is not None for p in payloads):
            raise AssertionError(
                "Dynamic routing is not enabled, but some payloads have recipient IDs"
            )

        return self.recipient_ids

    async def post_message(self, message: AgentMessage[OutT, StateT]) -> None:
        """Validate the message's routing, then post it to the message pool."""
        if self._dynamic_routing:
            self._validate_dynamic_routing(message.payloads)
        else:
            self._validate_static_routing(message.payloads)

        await self._message_pool.post(message)

    @abstractmethod
    async def run(
        self,
        inp_items: Any | None = None,
        *,
        ctx: RunContextWrapper[CtxT] | None = None,
        rcv_message: AgentMessage[InT, StateT] | None = None,
        entry_point: bool = False,
        forbid_state_change: bool = False,
        **kwargs: Any,
    ) -> AgentMessage[OutT, StateT]:
        """Run the agent once and return its output message (subclass hook)."""
        pass

    async def run_and_post(
        self, ctx: RunContextWrapper[CtxT] | None = None, **run_kwargs: Any
    ) -> None:
        """Run as an entry point (no received message) and post the output."""
        output_message = await self.run(
            ctx=ctx, rcv_message=None, entry_point=True, **run_kwargs
        )
        await self.post_message(output_message)

    def exit_handler(
        self, func: ExitHandler[OutT, StateT, CtxT]
    ) -> ExitHandler[OutT, StateT, CtxT]:
        """Register `func` as the exit handler; returns it so it works as a decorator."""
        self._exit_impl = func

        return func

    def _exit_condition(
        self,
        output_message: AgentMessage[OutT, StateT],
        ctx: RunContextWrapper[CtxT] | None,
    ) -> bool:
        """Return the exit handler's verdict, or ``False`` if none is registered."""
        if self._exit_impl:
            return self._exit_impl(
                output_message=output_message, agent_state=self.state, ctx=ctx
            )

        return False

    async def _message_handler(
        self,
        message: AgentMessage[AgentPayload, AgentState],
        ctx: RunContextWrapper[CtxT] | None = None,
        **run_kwargs: Any,
    ) -> None:
        """Handle one incoming message from the pool.

        Runs the agent on the message. If the exit condition holds, all pool
        activity is stopped; otherwise the output is posted when this agent
        has any recipients.
        """
        # The pool delivers loosely typed messages; no runtime check is done here.
        rcv_message = cast("AgentMessage[InT, StateT]", message)
        out_message = await self.run(ctx=ctx, rcv_message=rcv_message, **run_kwargs)

        if self._exit_condition(output_message=out_message, ctx=ctx):
            await self._message_pool.stop_all()
            return

        if self.recipient_ids:
            await self.post_message(out_message)

    @property
    def is_listening(self) -> bool:
        """Whether this agent's handler is registered with the pool."""
        return self._is_listening

    async def start_listening(
        self, ctx: RunContextWrapper[CtxT] | None = None, **run_kwargs: Any
    ) -> None:
        """Register this agent's message handler with the pool (idempotent)."""
        if self._is_listening:
            return

        self._message_pool.register_message_handler(
            agent_id=self.agent_id,
            handler=self._message_handler,
            ctx=ctx,
            **run_kwargs,
        )
        # Set only after registration succeeded, so a failed registration does
        # not leave the agent stuck in a "listening" state it never entered.
        self._is_listening = True

    async def stop_listening(self) -> None:
        """Mark the agent as not listening and unregister its handler."""
        self._is_listening = False
        await self._message_pool.unregister_message_handler(self.agent_id)

    @final
    def as_tool(
        self, tool_name: str, tool_description: str, tool_strict: bool = True
    ) -> BaseTool[BaseModel, BaseModel, CtxT]:
        """Wrap this agent in a `BaseTool` whose `run` invokes the agent.

        The tool takes `rcv_args_schema` as input and `out_schema` as output.
        Each call runs the agent with `forbid_state_change=True` and returns
        the first payload of the resulting message.
        """
        # assert self.state.batch_size == 1, (
        #     "Using agents as tools is only supported for batch size 1"
        # )

        agent_instance = self

        class AgentTool(BaseTool[BaseModel, BaseModel, Any]):
            name: str = tool_name
            description: str = tool_description
            in_schema: type[BaseModel] = agent_instance.rcv_args_schema
            out_schema: type[BaseModel] = agent_instance.out_schema

            strict: bool | None = tool_strict

            async def run(
                self,
                inp: BaseModel,
                ctx: RunContextWrapper[CtxT] | None = None,
            ) -> OutT:
                rcv_args = agent_instance.rcv_args_schema.model_validate(inp)
                rcv_message = AgentMessage(  # type: ignore[arg-type]
                    payloads=[rcv_args],
                    sender_id="<tool_user>",
                    recipient_ids=[agent_instance.agent_id],
                )

                agent_result = await agent_instance.run(
                    rcv_message=rcv_message,  # type: ignore[arg-type]
                    entry_point=False,
                    forbid_state_change=True,
                    ctx=ctx,
                )

                return agent_result.payloads[0]

        return AgentTool()
|
@@ -0,0 +1,122 @@
|
|
1
|
+
costs:
|
2
|
+
google/gemini-2.5-flash-preview:
|
3
|
+
input: 0.15
|
4
|
+
output: 0.60
|
5
|
+
cached_discount: 0.5
|
6
|
+
gpt-4.1:
|
7
|
+
input: 2.0
|
8
|
+
output: 8.0
|
9
|
+
cached_discount: 0.5
|
10
|
+
gpt-4.1-nano-2025-04-14:
|
11
|
+
input: 0.1
|
12
|
+
output: 0.4
|
13
|
+
cached_discount: 0.5
|
14
|
+
gpt-4.1-mini-2025-04-14:
|
15
|
+
input: 0.4
|
16
|
+
output: 1.6
|
17
|
+
cached_discount: 0.5
|
18
|
+
gpt-4.1-2025-04-14:
|
19
|
+
input: 2.0
|
20
|
+
output: 8.0
|
21
|
+
cached_discount: 0.5
|
22
|
+
google/gemini-2.5-pro-preview-03-25:
|
23
|
+
input: 1.25
|
24
|
+
output: 10
|
25
|
+
cached_discount: 0.5
|
26
|
+
gemini-2.5-pro-preview-03-25:
|
27
|
+
input: 1.25
|
28
|
+
output: 10
|
29
|
+
cached_discount: 0.5
|
30
|
+
llama-3.3-70b-versatile:
|
31
|
+
input: 0.59
|
32
|
+
output: 0.79
|
33
|
+
cached_discount: 0.5
|
34
|
+
gemini-1.5-pro:
|
35
|
+
input: 1.25
|
36
|
+
output: 5
|
37
|
+
cached_discount: 0.5
|
38
|
+
gemini-2.0-flash:
|
39
|
+
input: 0.1
|
40
|
+
output: 0.4
|
41
|
+
cached_discount: 0.5
|
42
|
+
anthropic/claude-3.5-sonnet:
|
43
|
+
input: 3.0
|
44
|
+
output: 15
|
45
|
+
cached_discount: 0.5
|
46
|
+
anthropic/claude-3.7-sonnet:
|
47
|
+
input: 3.0
|
48
|
+
output: 15
|
49
|
+
cached_discount: 0.5
|
50
|
+
google/gemini-2.0-flash-001:
|
51
|
+
input: 0.1
|
52
|
+
output: 0.4
|
53
|
+
cached_discount: 0.5
|
54
|
+
deepseek/deepseek-chat:
|
55
|
+
input: 0.4
|
56
|
+
output: 0.89
|
57
|
+
cached_discount: 0.5
|
58
|
+
o3-mini:
|
59
|
+
input: 1.1
|
60
|
+
output: 4.4
|
61
|
+
cached_discount: 0.5
|
62
|
+
o1:
|
63
|
+
input: 15
|
64
|
+
output: 60
|
65
|
+
cached_discount: 0.5
|
66
|
+
o1-2024-12-17:
|
67
|
+
input: 15
|
68
|
+
output: 60
|
69
|
+
cached_discount: 0.5
|
70
|
+
o1-preview:
|
71
|
+
input: 15
|
72
|
+
output: 60
|
73
|
+
cached_discount: 0.5
|
74
|
+
o1-preview-2024-09-12:
|
75
|
+
input: 15
|
76
|
+
output: 60
|
77
|
+
cached_discount: 0.5
|
78
|
+
o1-mini:
|
79
|
+
input: 1.1
|
80
|
+
output: 4.4
|
81
|
+
cached_discount: 0.5
|
82
|
+
o1-mini-2024-09-12:
|
83
|
+
input: 3
|
84
|
+
output: 12
|
85
|
+
cached_discount: 0.5
|
86
|
+
gpt-4o:
|
87
|
+
input: 2.5
|
88
|
+
output: 10
|
89
|
+
cached_discount: 0.5
|
90
|
+
gpt-4o-2024-05-13:
|
91
|
+
input: 5
|
92
|
+
output: 15
|
93
|
+
cached_discount: 0.5
|
94
|
+
gpt-4o-2024-08-06:
|
95
|
+
input: 2.5
|
96
|
+
output: 10
|
97
|
+
cached_discount: 0.5
|
98
|
+
gpt-4o-mini:
|
99
|
+
input: 0.15
|
100
|
+
output: 0.6
|
101
|
+
cached_discount: 0.5
|
102
|
+
gpt-4:
|
103
|
+
input: 30
|
104
|
+
output: 60
|
105
|
+
cached_discount: 0.5
|
106
|
+
gpt-4-turbo:
|
107
|
+
input: 10
|
108
|
+
output: 30
|
109
|
+
cached_discount: 0.5
|
110
|
+
gpt-4-turbo-preview:
|
111
|
+
input: 10
|
112
|
+
output: 30
|
113
|
+
cached_discount: 0.5
|
114
|
+
gpt-3.5-turbo:
|
115
|
+
input: 0.5
|
116
|
+
output: 1.5
|
117
|
+
cached_discount: 0.5
|
118
|
+
gpt-3.5-turbo-0125:
|
119
|
+
input: 0.5
|
120
|
+
output: 1.5
|
121
|
+
cached_discount: 0.5
|
122
|
+
date_of_last_update: 17-04-25
|
@@ -0,0 +1,195 @@
|
|
1
|
+
import asyncio
|
2
|
+
import functools
|
3
|
+
import logging
|
4
|
+
from collections.abc import Callable, Coroutine
|
5
|
+
from time import monotonic
|
6
|
+
from typing import (
|
7
|
+
Any,
|
8
|
+
Generic,
|
9
|
+
overload,
|
10
|
+
)
|
11
|
+
|
12
|
+
from tqdm.autonotebook import tqdm
|
13
|
+
|
14
|
+
from ..utils import asyncio_gather_with_pbar
|
15
|
+
|
16
|
+
from .types import (
|
17
|
+
QueryP,
|
18
|
+
QueryR,
|
19
|
+
QueryT,
|
20
|
+
RateLimDecoratorWithArgsList,
|
21
|
+
RateLimDecoratorWithArgsSingle,
|
22
|
+
RateLimiterState,
|
23
|
+
RetrievalCallableList,
|
24
|
+
RetrievalCallableSingle,
|
25
|
+
)
|
26
|
+
from .utils import partial_retrieval_callable, split_pos_args
|
27
|
+
|
28
|
+
logger = logging.getLogger(__name__)
|
29
|
+
|
30
|
+
|
31
|
+
class RateLimiterC(Generic[QueryT, QueryR]):
    """Async rate limiter that spaces request starts and processes inputs in chunks.

    Request starts are serialized through a lock and spaced by
    ``1.01 * 60 / rpm`` seconds. A semaphore caps the number of requests in
    flight at ``max_concurrency``.
    """

    def __init__(
        self,
        rpm: float,
        chunk_size: int = 1000,
        max_concurrency: int = 200,
    ) -> None:
        """Create the limiter.

        Args:
            rpm: Target requests per minute; must be positive.
            chunk_size: Number of inputs gathered together by
                `process_inputs`; must be positive.
            max_concurrency: Maximum concurrent requests; must be positive.

        Raises:
            ValueError: If any argument is not positive.
        """
        # Fail early: a zero rpm would raise ZeroDivisionError on the first
        # request, a negative chunk_size would silently yield no results, and
        # a zero max_concurrency would block forever on the semaphore.
        self._require_positive("rpm", rpm)
        self._require_positive("chunk_size", chunk_size)
        self._require_positive("max_concurrency", max_concurrency)

        self._rpm = rpm
        self._max_concurrency = max_concurrency
        self._chunk_size = chunk_size

        self._lock = asyncio.Lock()
        self._state = RateLimiterState(next_request_time=0.0)
        self._semaphore = asyncio.Semaphore(self._max_concurrency)

    @staticmethod
    def _require_positive(name: str, value: float) -> None:
        """Raise `ValueError` unless `value` is strictly positive (NaN is rejected)."""
        if not value > 0:
            raise ValueError(f"{name} must be positive, got {value!r}")

    async def process_input(
        self,
        func_partial: Callable[[QueryT], Coroutine[Any, Any, QueryR]],
        inp: QueryT,
    ) -> QueryR:
        """Await `func_partial(inp)` once the rate limit allows it to start."""
        async with self._semaphore:
            # The lock is held while sleeping so that waiting callers are
            # released one at a time, each after the previous start slot.
            async with self._lock:
                now = monotonic()
                if now < self._state.next_request_time:
                    await asyncio.sleep(self._state.next_request_time - now)
                # Interval between starts, scaled by 1.01 (presumably a small
                # safety margin over the nominal 60 / rpm).
                self._state.next_request_time = monotonic() + 1.01 * 60.0 / self._rpm
            result = await func_partial(inp)

        return result

    async def process_inputs(
        self,
        func_partial: Callable[[QueryT], Coroutine[Any, Any, QueryR]],
        inputs: list[QueryT],
        no_tqdm: bool = False,
    ) -> list[QueryR]:
        """Process `inputs` chunk by chunk and return results in input order.

        Each chunk of `chunk_size` inputs is gathered concurrently (subject to
        the rate limit); the next chunk starts only after the previous one
        has finished. A progress bar over chunks is shown unless `no_tqdm`.
        """
        results: list[QueryR] = []
        for i in tqdm(
            range(0, len(inputs), self._chunk_size),
            disable=no_tqdm,
            desc="Processing chunks",
        ):
            chunk = inputs[i : i + self._chunk_size]
            corouts = [
                self.process_input(func_partial=func_partial, inp=inp) for inp in chunk
            ]
            chunk_results = await asyncio.gather(*corouts)
            results.extend(chunk_results)

        return results

    @property
    def rpm(self) -> float:
        """Target requests per minute."""
        return self._rpm

    @rpm.setter
    def rpm(self, value: float) -> None:
        self._require_positive("rpm", value)
        self._rpm = value

    @property
    def max_concurrency(self) -> int:
        """Maximum number of concurrent requests."""
        return self._max_concurrency

    @property
    def chunk_size(self) -> int:
        """Number of inputs gathered together by `process_inputs`."""
        return self._chunk_size

    @property
    def state(self) -> RateLimiterState:
        """Mutable limiter state (holds the next allowed request time)."""
        return self._state
|
101
|
+
|
102
|
+
|
103
|
+
@overload
def limit_rate(
    call: RetrievalCallableSingle[QueryT, QueryP, QueryR],
    rate_limiter: RateLimiterC[QueryT, QueryR] | None = None,
) -> RetrievalCallableSingle[QueryT, QueryP, QueryR]: ...


@overload
def limit_rate(
    call: None = None,
    rate_limiter: RateLimiterC[QueryT, QueryR] | None = None,
) -> RateLimDecoratorWithArgsSingle[QueryT, QueryP, QueryR]: ...


def limit_rate(
    call: RetrievalCallableSingle[QueryT, QueryP, QueryR] | None = None,
    rate_limiter: RateLimiterC[QueryT, QueryR] | None = None,
) -> (
    RetrievalCallableSingle[QueryT, QueryP, QueryR]
    | RateLimDecoratorWithArgsSingle[QueryT, QueryP, QueryR]
):
    """Decorate a single-input async retrieval callable with rate limiting.

    Usable bare (``@limit_rate``) or with arguments
    (``@limit_rate(rate_limiter=...)``). At call time the limiter is the one
    given here or, failing that, the ``rate_limiter`` attribute of the bound
    object (if any). With no limiter the callable is awaited directly.
    """
    if call is None:
        # Called with arguments only: return a decorator awaiting the callable.
        return functools.partial(limit_rate, rate_limiter=rate_limiter)

    @functools.wraps(call)  # type: ignore
    async def wrapper(*args: Any, **kwargs: Any) -> QueryR:
        query: QueryT
        owner, query, extra_args = split_pos_args(call, args)  # type: ignore
        bound_call = partial_retrieval_callable(call, owner, *extra_args, **kwargs)

        limiter = (
            rate_limiter
            if rate_limiter is not None
            else getattr(owner, "rate_limiter", None)
        )
        if limiter is not None:
            return await limiter.process_input(func_partial=bound_call, inp=query)

        return await bound_call(query)

    return wrapper
|
142
|
+
|
143
|
+
|
144
|
+
@overload
def limit_rate_chunked(
    call: RetrievalCallableSingle[QueryT, QueryP, QueryR],
    rate_limiter: RateLimiterC[QueryT, QueryR] | None = None,
    no_tqdm: bool | None = None,
) -> RetrievalCallableList[QueryT, QueryP, QueryR]: ...


@overload
def limit_rate_chunked(
    call: None = None,
    rate_limiter: RateLimiterC[QueryT, QueryR] | None = None,
    no_tqdm: bool | None = None,
) -> RateLimDecoratorWithArgsList[QueryT, QueryP, QueryR]: ...


def limit_rate_chunked(
    call: RetrievalCallableSingle[QueryT, QueryP, QueryR] | None = None,
    rate_limiter: RateLimiterC[QueryT, QueryR] | None = None,
    no_tqdm: bool | None = None,
) -> (
    RetrievalCallableList[QueryT, QueryP, QueryR]
    | RateLimDecoratorWithArgsList[QueryT, QueryP, QueryR]
):
    """Turn a single-input async callable into a rate-limited list-input one.

    Usable bare or with arguments. At call time, `no_tqdm` and the rate
    limiter fall back to the ``no_tqdm`` / ``rate_limiter`` attributes of the
    bound object when not given here. Without a limiter all inputs are
    gathered at once; with one they go through its `process_inputs`.
    """
    if call is None:
        # Called with arguments only: return a decorator awaiting the callable.
        return functools.partial(
            limit_rate_chunked, rate_limiter=rate_limiter, no_tqdm=no_tqdm
        )

    @functools.wraps(call)  # type: ignore
    async def wrapper(*args: Any, **kwargs: Any) -> list[QueryR]:
        assert call is not None

        owner, queries, extra_args = split_pos_args(call, args)  # type: ignore
        bound_call = partial_retrieval_callable(call, owner, *extra_args, **kwargs)

        hide_progress = (
            no_tqdm if no_tqdm is not None else getattr(owner, "no_tqdm", False)
        )
        limiter = (
            rate_limiter
            if rate_limiter is not None
            else getattr(owner, "rate_limiter", None)
        )

        if limiter is not None:
            return await limiter.process_inputs(
                func_partial=bound_call, inputs=queries, no_tqdm=hide_progress
            )

        pending = [bound_call(inp) for inp in queries]
        return await asyncio_gather_with_pbar(*pending, no_tqdm=hide_progress)

    return wrapper
|
@@ -0,0 +1,57 @@
|
|
1
|
+
from collections.abc import Callable, Coroutine
|
2
|
+
from dataclasses import dataclass
|
3
|
+
from typing import (
|
4
|
+
Any,
|
5
|
+
Concatenate,
|
6
|
+
ParamSpec,
|
7
|
+
TypeAlias,
|
8
|
+
TypeVar,
|
9
|
+
)
|
10
|
+
|
11
|
+
# Very large requests-per-minute value.
# NOTE(review): presumably stands for "effectively unlimited"; no usage is
# visible in this file - confirm against callers.
MAX_RPM: float = 1e10
|
12
|
+
|
13
|
+
|
14
|
+
@dataclass
class RateLimiterState:
    """Mutable state shared by the users of one rate limiter."""

    # Earliest time at which the next request may start; defaults to 0.0.
    next_request_time: float = 0.0
|
17
|
+
|
18
|
+
|
19
|
+
QueryT = TypeVar("QueryT")
|
20
|
+
QueryR = TypeVar("QueryR")
|
21
|
+
QueryP = ParamSpec("QueryP")
|
22
|
+
|
23
|
+
RetrievalFuncSingle: TypeAlias = Callable[
|
24
|
+
Concatenate[QueryT, QueryP], Coroutine[Any, Any, QueryR]
|
25
|
+
]
|
26
|
+
RetrievalFuncList: TypeAlias = Callable[
|
27
|
+
Concatenate[list[QueryT], QueryP], Coroutine[Any, Any, list[QueryR]]
|
28
|
+
]
|
29
|
+
|
30
|
+
RetrievalMethodSingle: TypeAlias = Callable[
|
31
|
+
Concatenate[Any, QueryT, QueryP], Coroutine[Any, Any, QueryR]
|
32
|
+
]
|
33
|
+
RetrievalMethodList: TypeAlias = Callable[
|
34
|
+
Concatenate[Any, list[QueryT], QueryP], Coroutine[Any, Any, list[QueryR]]
|
35
|
+
]
|
36
|
+
|
37
|
+
RetrievalCallableSingle: TypeAlias = (
|
38
|
+
RetrievalFuncSingle[QueryT, QueryP, QueryR]
|
39
|
+
| RetrievalMethodSingle[QueryT, QueryP, QueryR]
|
40
|
+
)
|
41
|
+
|
42
|
+
RetrievalCallableList: TypeAlias = (
|
43
|
+
RetrievalFuncList[QueryT, QueryP, QueryR]
|
44
|
+
| RetrievalMethodList[QueryT, QueryP, QueryR]
|
45
|
+
)
|
46
|
+
|
47
|
+
|
48
|
+
RateLimDecoratorWithArgsSingle = Callable[
|
49
|
+
[RetrievalCallableSingle[QueryT, QueryP, QueryR]],
|
50
|
+
RetrievalCallableSingle[QueryT, QueryP, QueryR],
|
51
|
+
]
|
52
|
+
|
53
|
+
|
54
|
+
RateLimDecoratorWithArgsList = Callable[
|
55
|
+
[RetrievalCallableList[QueryT, QueryP, QueryR]],
|
56
|
+
RetrievalCallableList[QueryT, QueryP, QueryR],
|
57
|
+
]
|
@@ -0,0 +1,57 @@
|
|
1
|
+
import inspect
|
2
|
+
from collections.abc import Callable, Coroutine, Sequence
|
3
|
+
from typing import (
|
4
|
+
Any,
|
5
|
+
)
|
6
|
+
|
7
|
+
from .types import (
|
8
|
+
QueryP,
|
9
|
+
QueryR,
|
10
|
+
QueryT,
|
11
|
+
RetrievalCallableList,
|
12
|
+
RetrievalCallableSingle,
|
13
|
+
)
|
14
|
+
|
15
|
+
|
16
|
+
def is_bound_method(func: Callable[..., Any], self_candidate: Any) -> bool:
    """Guess whether `func` is a method that belongs to `self_candidate`.

    Returns ``True`` when `func` is a bound method of `self_candidate`, or
    when `self_candidate` exposes a non-builtin attribute named like `func`
    (the case of an undecorated method function being called with its
    instance as the first positional argument).

    Builtin attributes are excluded because they can never wrap `func`.
    Without this, a standalone function such as ``index(query)`` called with
    a ``str`` would be mistaken for a method, since ``str`` has ``.index``.
    Callables without a ``__name__`` (e.g. ``functools.partial`` objects) are
    treated as plain functions instead of raising ``AttributeError``.
    """
    if inspect.ismethod(func) and func.__self__ is self_candidate:
        return True

    name = getattr(func, "__name__", None)
    if name is None:
        return False

    missing = object()
    attr = getattr(self_candidate, name, missing)
    if attr is missing:
        return False

    # Name-based heuristic, kept for compatibility, minus builtin look-alikes.
    return not inspect.isbuiltin(attr)
|
18
|
+
|
19
|
+
|
20
|
+
def split_pos_args(
    call: (RetrievalCallableSingle[QueryT, QueryP, QueryR] | RetrievalCallableList[QueryT, QueryP, QueryR]),
    args: Sequence[Any],
) -> tuple[Any | None, QueryT | list[QueryT], Sequence[Any]]:
    """Split positional arguments into ``(self_obj, input, remaining_args)``.

    If the first argument looks like the instance `call` is a method of (see
    `is_bound_method`), it is returned as ``self_obj`` and the second
    argument is the input. Otherwise ``self_obj`` is ``None`` and the first
    argument is the input.

    Raises:
        ValueError: If `args` is empty, or if a method call carries no input
            after `self`.
    """
    if not args:
        raise ValueError("No positional arguments passed.")

    maybe_self = args[0]
    if is_bound_method(call, maybe_self):
        # Case: Bound instance method with signature (self, inp, *rest)
        if len(args) < 2:
            raise ValueError(
                "Must pass at least `self` and an input (or a list of inputs) for a bound instance method."
            )
        return maybe_self, args[1], args[2:]

    # Case: Standalone function with signature (inp, *rest).
    # `args` is known to be non-empty here, so no further check is needed.
    return None, args[0], args[1:]
|
38
|
+
|
39
|
+
|
40
|
+
def partial_retrieval_callable(
    call: Callable[..., Coroutine[Any, Any, QueryR]],
    self_obj: Any,
    *args: QueryP.args,
    **kwargs: QueryP.kwargs,
) -> Callable[[QueryT], Coroutine[Any, Any, QueryR]]:
    """Bind everything but the query input, returning a one-argument coroutine function.

    When `self_obj` is not ``None`` it is passed first (method call);
    otherwise the input itself is the first argument (plain function call).
    `args` and `kwargs` follow the input in both cases.
    """

    async def wrapper(inp: QueryT) -> QueryR:
        # Leading arguments: (inp,) for a function, (self_obj, inp) for a method.
        leading = (inp,) if self_obj is None else (self_obj, inp)
        return await call(*leading, *args, **kwargs)

    return wrapper
|
54
|
+
|
55
|
+
|
56
|
+
def expected_exec_time_from_max_concurrency_and_rpm(rpm: float, max_concurrency: int) -> float:
|
57
|
+
return 60.0 / (rpm / max_concurrency)
|
@@ -0,0 +1,36 @@
|
|
1
|
+
import logging
|
2
|
+
from logging import Formatter, LogRecord
|
3
|
+
from pathlib import Path
|
4
|
+
|
5
|
+
import yaml
|
6
|
+
from termcolor import colored
|
7
|
+
from termcolor._types import Color # type: ignore[import]
|
8
|
+
|
9
|
+
|
10
|
+
class ColorFormatter(Formatter):
    """Formatter that colors a message when the record carries a `color` attribute."""

    def format(self, record: LogRecord) -> str:
        """Format `record` as usual, wrapping the result in `record.color` if set."""
        text = super().format(record)
        color: Color | None = getattr(record, "color", None)
        # Records without a (truthy) color are returned unchanged.
        return colored(text, color) if color else text
|
17
|
+
|
18
|
+
|
19
|
+
def setup_logging(logs_file_path: str | Path, logs_config_path: str | Path) -> None:
    """Configure logging from a YAML dict-config and enable colored output.

    Creates the log file's parent directory, loads the YAML config, points
    its ``fileHandler`` at `logs_file_path`, applies the config, and then
    replaces the formatter of every root handler with a `ColorFormatter`
    using the same format string and date format.

    Args:
        logs_file_path: Destination of the ``fileHandler`` log file.
        logs_config_path: YAML file containing a `logging` dict-config; it
            must define ``handlers.fileHandler``.
    """
    # `import logging` alone does not load the `logging.config` submodule,
    # so import the function explicitly instead of relying on another module
    # having imported it first.
    from logging.config import dictConfig

    logs_file_path = Path(logs_file_path)
    logs_file_path.parent.mkdir(exist_ok=True, parents=True)
    with Path(logs_config_path).open() as f:
        config = yaml.safe_load(f)

    config["handlers"]["fileHandler"]["filename"] = logs_file_path

    dictConfig(config)

    root = logging.getLogger()
    for handler in root.handlers:
        current = handler.formatter
        if current is not None:
            fmt_str = current._fmt  # noqa: SLF001
            # Carry over the date format too; otherwise a configured
            # `datefmt` would be lost when the formatter is swapped.
            handler.setFormatter(ColorFormatter(fmt_str, datefmt=current.datefmt))
|
34
|
+
|
35
|
+
|
36
|
+
logger = logging.getLogger(__name__)
|