grasp_agents 0.1.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44):
  1. grasp_agents/agent_message.py +28 -0
  2. grasp_agents/agent_message_pool.py +94 -0
  3. grasp_agents/base_agent.py +72 -0
  4. grasp_agents/cloud_llm.py +353 -0
  5. grasp_agents/comm_agent.py +230 -0
  6. grasp_agents/costs_dict.yaml +122 -0
  7. grasp_agents/data_retrieval/__init__.py +7 -0
  8. grasp_agents/data_retrieval/rate_limiter_chunked.py +195 -0
  9. grasp_agents/data_retrieval/types.py +57 -0
  10. grasp_agents/data_retrieval/utils.py +57 -0
  11. grasp_agents/grasp_logging.py +36 -0
  12. grasp_agents/http_client.py +24 -0
  13. grasp_agents/llm.py +106 -0
  14. grasp_agents/llm_agent.py +361 -0
  15. grasp_agents/llm_agent_state.py +73 -0
  16. grasp_agents/memory.py +150 -0
  17. grasp_agents/openai/__init__.py +83 -0
  18. grasp_agents/openai/completion_converters.py +49 -0
  19. grasp_agents/openai/content_converters.py +80 -0
  20. grasp_agents/openai/converters.py +170 -0
  21. grasp_agents/openai/message_converters.py +155 -0
  22. grasp_agents/openai/openai_llm.py +179 -0
  23. grasp_agents/openai/tool_converters.py +37 -0
  24. grasp_agents/printer.py +156 -0
  25. grasp_agents/prompt_builder.py +204 -0
  26. grasp_agents/run_context.py +90 -0
  27. grasp_agents/tool_orchestrator.py +181 -0
  28. grasp_agents/typing/__init__.py +0 -0
  29. grasp_agents/typing/completion.py +30 -0
  30. grasp_agents/typing/content.py +116 -0
  31. grasp_agents/typing/converters.py +118 -0
  32. grasp_agents/typing/io.py +32 -0
  33. grasp_agents/typing/message.py +130 -0
  34. grasp_agents/typing/tool.py +52 -0
  35. grasp_agents/usage_tracker.py +99 -0
  36. grasp_agents/utils.py +151 -0
  37. grasp_agents/workflow/__init__.py +0 -0
  38. grasp_agents/workflow/looped_agent.py +113 -0
  39. grasp_agents/workflow/sequential_agent.py +57 -0
  40. grasp_agents/workflow/workflow_agent.py +69 -0
  41. grasp_agents-0.1.5.dist-info/METADATA +14 -0
  42. grasp_agents-0.1.5.dist-info/RECORD +44 -0
  43. grasp_agents-0.1.5.dist-info/WHEEL +4 -0
  44. grasp_agents-0.1.5.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,230 @@
1
+ import logging
2
+ from abc import abstractmethod
3
+ from collections.abc import Sequence
4
+ from typing import Any, Generic, Protocol, TypeVar, cast, final
5
+
6
+ from pydantic import BaseModel
7
+
8
+ from .agent_message import AgentMessage
9
+ from .agent_message_pool import AgentMessagePool
10
+ from .base_agent import BaseAgent
11
+ from .run_context import CtxT, RunContextWrapper
12
+ from .typing.io import AgentID, AgentPayload, AgentState, InT, OutT, StateT
13
+ from .typing.tool import BaseTool
14
+
15
logger = logging.getLogger(__name__)

# Type parameters of ExitHandler. They are declared contravariant because they
# only appear in argument positions of ``__call__``.
_EH_OutT = TypeVar("_EH_OutT", bound=AgentPayload, contravariant=True)  # noqa: PLC0105
_EH_StateT = TypeVar("_EH_StateT", bound=AgentState, contravariant=True)  # noqa: PLC0105


class ExitHandler(Protocol[_EH_OutT, _EH_StateT, CtxT]):
    """Structural type of a user-supplied exit predicate.

    It is called with an agent's output message, the agent's current state
    and the optional run context. It returns a bool that the agent uses to
    decide whether to stop processing.
    """

    def __call__(
        self,
        output_message: AgentMessage[_EH_OutT, _EH_StateT],
        agent_state: _EH_StateT,
        ctx: RunContextWrapper[CtxT] | None,
    ) -> bool:
        """Return True when the run should stop after ``output_message``."""
        ...
28
+
29
+
30
class CommunicatingAgent(
    BaseAgent[OutT, StateT, CtxT], Generic[InT, OutT, StateT, CtxT]
):
    """Agent that exchanges ``AgentMessage`` objects through an ``AgentMessagePool``.

    Outgoing messages are routed in one of two modes:

    * statically, to ``recipient_ids``;
    * when ``dynamic_routing`` is enabled, to the ``selected_recipient_ids``
      carried by the output payloads, which must all be in ``recipient_ids``.
    """

    def __init__(
        self,
        agent_id: AgentID,
        *,
        out_schema: type[OutT] = AgentPayload,
        rcv_args_schema: type[InT] = AgentPayload,
        recipient_ids: Sequence[AgentID] | None = None,
        message_pool: AgentMessagePool[CtxT] | None = None,
        dynamic_routing: bool = False,
        **kwargs: Any,
    ) -> None:
        """Create the agent.

        Args:
            agent_id: Identifier of this agent; also used to register its
                message handler with the pool.
            out_schema: Payload type this agent produces.
            rcv_args_schema: Payload type this agent expects to receive.
            recipient_ids: Agents this agent may post to (empty by default).
            message_pool: Pool to post to and listen on; a new
                ``AgentMessagePool`` is created when omitted.
            dynamic_routing: If True, output payloads select their recipients.
            **kwargs: Forwarded to ``BaseAgent.__init__``.
        """
        super().__init__(agent_id=agent_id, out_schema=out_schema, **kwargs)
        # Compare against None instead of relying on truthiness. This way a
        # caller-supplied pool is never silently replaced, even if it is falsy.
        self._message_pool = (
            message_pool if message_pool is not None else AgentMessagePool()
        )

        self._dynamic_routing = dynamic_routing

        self._is_listening = False
        self._exit_impl: ExitHandler[OutT, StateT, CtxT] | None = None

        self._rcv_args_schema = rcv_args_schema
        self.recipient_ids = recipient_ids or []

    @property
    def rcv_args_schema(self) -> type[InT]:  # type: ignore[reportInvalidTypeVarUse]
        """Payload type expected in received messages."""
        return self._rcv_args_schema

    @property
    def dynamic_routing(self) -> bool:
        """Whether output payloads choose their own recipients."""
        return self._dynamic_routing

    def _parse_output(
        self,
        *args: Any,
        rcv_args: InT | None = None,
        ctx: RunContextWrapper[CtxT] | None = None,
        **kwargs: Any,
    ) -> OutT:
        """Build an output payload.

        Delegates to ``self._parse_output_impl`` when it is set. Otherwise it
        returns a default-constructed ``out_schema`` instance.
        """
        if self._parse_output_impl:
            return self._parse_output_impl(*args, rcv_args=rcv_args, ctx=ctx, **kwargs)

        return self._out_schema()

    def _validate_dynamic_routing(self, payloads: Sequence[OutT]) -> Sequence[AgentID]:
        """Check per-payload recipient selection and return the selected IDs.

        Every payload must carry the same non-None set of
        ``selected_recipient_ids``, and each of those IDs must be in
        ``recipient_ids``. An empty ``payloads`` selects no recipients.

        NOTE(review): the checks use ``assert``, so they are skipped when
        Python runs with ``-O``.
        """
        # Without this guard an empty batch passes every check below
        # vacuously and then fails with IndexError on ``payloads[0]``.
        if not payloads:
            return []

        assert all((p.selected_recipient_ids is not None) for p in payloads), (
            "Dynamic routing is enabled, but some payloads have no recipient IDs"
        )

        selected_recipient_ids_per_payload = [
            set(p.selected_recipient_ids or []) for p in payloads
        ]
        assert all(
            x == selected_recipient_ids_per_payload[0]
            for x in selected_recipient_ids_per_payload
        ), "All payloads must have the same recipient IDs for dynamic routing"

        selected_recipient_ids = payloads[0].selected_recipient_ids
        assert selected_recipient_ids is not None

        assert all(rid in self.recipient_ids for rid in selected_recipient_ids), (
            "Dynamic routing is enabled, but recipient IDs are not in "
            "the allowed agent's recipient IDs"
        )

        return selected_recipient_ids

    def _validate_static_routing(self, payloads: Sequence[OutT]) -> Sequence[AgentID]:
        """Check that no payload selects recipients; return ``recipient_ids``."""
        assert all((p.selected_recipient_ids is None) for p in payloads), (
            "Dynamic routing is not enabled, but some payloads have recipient IDs"
        )

        return self.recipient_ids

    async def post_message(self, message: AgentMessage[OutT, StateT]) -> None:
        """Validate ``message`` against the routing mode, then post it to the pool."""
        if self._dynamic_routing:
            self._validate_dynamic_routing(message.payloads)
        else:
            self._validate_static_routing(message.payloads)

        await self._message_pool.post(message)

    @abstractmethod
    async def run(
        self,
        inp_items: Any | None = None,
        *,
        ctx: RunContextWrapper[CtxT] | None = None,
        rcv_message: AgentMessage[InT, StateT] | None = None,
        entry_point: bool = False,
        forbid_state_change: bool = False,
        **kwargs: Any,
    ) -> AgentMessage[OutT, StateT]:
        """Run the agent and return its output message (implemented by subclasses)."""
        pass

    async def run_and_post(
        self, ctx: RunContextWrapper[CtxT] | None = None, **run_kwargs: Any
    ) -> None:
        """Run as the entry point (no received message) and post the output."""
        output_message = await self.run(
            ctx=ctx, rcv_message=None, entry_point=True, **run_kwargs
        )
        await self.post_message(output_message)

    def exit_handler(
        self, func: ExitHandler[OutT, StateT, CtxT]
    ) -> ExitHandler[OutT, StateT, CtxT]:
        """Register ``func`` as the exit predicate; usable as a decorator."""
        self._exit_impl = func

        return func

    def _exit_condition(
        self,
        output_message: AgentMessage[OutT, StateT],
        ctx: RunContextWrapper[CtxT] | None,
    ) -> bool:
        """Evaluate the registered exit handler; False when none is registered."""
        if self._exit_impl:
            return self._exit_impl(
                output_message=output_message, agent_state=self.state, ctx=ctx
            )

        return False

    async def _message_handler(
        self,
        message: AgentMessage[AgentPayload, AgentState],
        ctx: RunContextWrapper[CtxT] | None = None,
        **run_kwargs: Any,
    ) -> None:
        """Handle a message delivered by the pool.

        Runs the agent on ``message``. If the exit condition holds, it calls
        ``stop_all()`` on the pool. Otherwise it posts the output when the
        agent has recipients.
        """
        rcv_message = cast("AgentMessage[InT, StateT]", message)
        out_message = await self.run(ctx=ctx, rcv_message=rcv_message, **run_kwargs)

        if self._exit_condition(output_message=out_message, ctx=ctx):
            await self._message_pool.stop_all()
            return

        if self.recipient_ids:
            await self.post_message(out_message)

    @property
    def is_listening(self) -> bool:
        """Whether ``start_listening`` has registered this agent's handler."""
        return self._is_listening

    async def start_listening(
        self, ctx: RunContextWrapper[CtxT] | None = None, **run_kwargs: Any
    ) -> None:
        """Register this agent's handler with the pool (no-op if already listening).

        ``ctx`` and ``run_kwargs`` are passed through to
        ``register_message_handler``.
        """
        if self._is_listening:
            return

        self._is_listening = True
        self._message_pool.register_message_handler(
            agent_id=self.agent_id,
            handler=self._message_handler,
            ctx=ctx,
            **run_kwargs,
        )

    async def stop_listening(self) -> None:
        """Unregister this agent's handler from the pool."""
        self._is_listening = False
        await self._message_pool.unregister_message_handler(self.agent_id)

    @final
    def as_tool(
        self, tool_name: str, tool_description: str, tool_strict: bool = True
    ) -> BaseTool[BaseModel, BaseModel, CtxT]:
        """Wrap this agent as a ``BaseTool`` that other agents can call.

        The tool validates its input against ``rcv_args_schema``. It then runs
        the agent with ``forbid_state_change=True`` and returns the first
        output payload.
        """
        # NOTE(review): only ``payloads[0]`` is returned, so this presumably
        # assumes a single output payload (batch size 1); confirm for batched
        # agents.
        agent_instance = self

        class AgentTool(BaseTool[BaseModel, BaseModel, Any]):
            name: str = tool_name
            description: str = tool_description
            in_schema: type[BaseModel] = agent_instance.rcv_args_schema
            out_schema: type[BaseModel] = agent_instance.out_schema

            strict: bool | None = tool_strict

            async def run(
                self,
                inp: BaseModel,
                ctx: RunContextWrapper[CtxT] | None = None,
            ) -> OutT:
                rcv_args = agent_instance.rcv_args_schema.model_validate(inp)
                rcv_message = AgentMessage(  # type: ignore[arg-type]
                    payloads=[rcv_args],
                    sender_id="<tool_user>",
                    recipient_ids=[agent_instance.agent_id],
                )

                agent_result = await agent_instance.run(
                    rcv_message=rcv_message,  # type: ignore[arg-type]
                    entry_point=False,
                    forbid_state_change=True,
                    ctx=ctx,
                )

                return agent_result.payloads[0]

        return AgentTool()
@@ -0,0 +1,122 @@
1
# Per-model token prices.
# NOTE(review): values appear to be USD per 1M input/output tokens (they match
# published list prices at that scale); confirm against the code reading this.
# NOTE(review): `cached_discount` is 0.5 for every model; confirm its intended
# meaning, since some providers discount cached input by a different fraction.
costs:
  google/gemini-2.5-flash-preview:
    input: 0.15
    output: 0.60
    cached_discount: 0.5
  gpt-4.1:
    input: 2.0
    output: 8.0
    cached_discount: 0.5
  gpt-4.1-nano-2025-04-14:
    input: 0.1
    output: 0.4
    cached_discount: 0.5
  gpt-4.1-mini-2025-04-14:
    input: 0.4
    output: 1.6
    cached_discount: 0.5
  gpt-4.1-2025-04-14:
    input: 2.0
    output: 8.0
    cached_discount: 0.5
  google/gemini-2.5-pro-preview-03-25:
    input: 1.25
    output: 10
    cached_discount: 0.5
  gemini-2.5-pro-preview-03-25:
    input: 1.25
    output: 10
    cached_discount: 0.5
  llama-3.3-70b-versatile:
    input: 0.59
    output: 0.79
    cached_discount: 0.5
  gemini-1.5-pro:
    input: 1.25
    output: 5
    cached_discount: 0.5
  gemini-2.0-flash:
    input: 0.1
    output: 0.4
    cached_discount: 0.5
  anthropic/claude-3.5-sonnet:
    input: 3.0
    output: 15
    cached_discount: 0.5
  anthropic/claude-3.7-sonnet:
    input: 3.0
    output: 15
    cached_discount: 0.5
  google/gemini-2.0-flash-001:
    input: 0.1
    output: 0.4
    cached_discount: 0.5
  deepseek/deepseek-chat:
    input: 0.4
    output: 0.89
    cached_discount: 0.5
  o3-mini:
    input: 1.1
    output: 4.4
    cached_discount: 0.5
  o1:
    input: 15
    output: 60
    cached_discount: 0.5
  o1-2024-12-17:
    input: 15
    output: 60
    cached_discount: 0.5
  o1-preview:
    input: 15
    output: 60
    cached_discount: 0.5
  o1-preview-2024-09-12:
    input: 15
    output: 60
    cached_discount: 0.5
  o1-mini:
    input: 1.1
    output: 4.4
    cached_discount: 0.5
  # Dated snapshot of `o1-mini`, billed at the same rate. It was previously
  # listed here at the pre-reduction 3 / 12 price, inconsistent with `o1-mini`.
  o1-mini-2024-09-12:
    input: 1.1
    output: 4.4
    cached_discount: 0.5
  gpt-4o:
    input: 2.5
    output: 10
    cached_discount: 0.5
  gpt-4o-2024-05-13:
    input: 5
    output: 15
    cached_discount: 0.5
  gpt-4o-2024-08-06:
    input: 2.5
    output: 10
    cached_discount: 0.5
  gpt-4o-mini:
    input: 0.15
    output: 0.6
    cached_discount: 0.5
  gpt-4:
    input: 30
    output: 60
    cached_discount: 0.5
  gpt-4-turbo:
    input: 10
    output: 30
    cached_discount: 0.5
  gpt-4-turbo-preview:
    input: 10
    output: 30
    cached_discount: 0.5
  gpt-3.5-turbo:
    input: 0.5
    output: 1.5
    cached_discount: 0.5
  gpt-3.5-turbo-0125:
    input: 0.5
    output: 1.5
    cached_discount: 0.5
# Format appears to be DD-MM-YY (17 April 2025).
date_of_last_update: 17-04-25
@@ -0,0 +1,7 @@
1
+ from .rate_limiter_chunked import RateLimiterC, limit_rate, limit_rate_chunked
2
+
3
# Public API of the data_retrieval package: the names re-exported above from
# ``rate_limiter_chunked``.
__all__ = [
    "RateLimiterC",
    "limit_rate",
    "limit_rate_chunked",
]
@@ -0,0 +1,195 @@
1
+ import asyncio
2
+ import functools
3
+ import logging
4
+ from collections.abc import Callable, Coroutine
5
+ from time import monotonic
6
+ from typing import (
7
+ Any,
8
+ Generic,
9
+ overload,
10
+ )
11
+
12
+ from tqdm.autonotebook import tqdm
13
+
14
+ from ..utils import asyncio_gather_with_pbar
15
+
16
+ from .types import (
17
+ QueryP,
18
+ QueryR,
19
+ QueryT,
20
+ RateLimDecoratorWithArgsList,
21
+ RateLimDecoratorWithArgsSingle,
22
+ RateLimiterState,
23
+ RetrievalCallableList,
24
+ RetrievalCallableSingle,
25
+ )
26
+ from .utils import partial_retrieval_callable, split_pos_args
27
+
28
logger = logging.getLogger(__name__)


class RateLimiterC(Generic[QueryT, QueryR]):
    """Async rate limiter that spaces out call starts and caps concurrency.

    Call starts are spaced at least ``1.01 * 60 / rpm`` seconds apart. The
    1.01 factor presumably adds a 1% margin below the nominal rate. At most
    ``max_concurrency`` calls run at once, and ``process_inputs`` submits
    inputs in chunks of ``chunk_size``.
    """

    def __init__(
        self,
        rpm: float,
        chunk_size: int = 1000,
        max_concurrency: int = 200,
    ):
        """Create a rate limiter.

        Args:
            rpm: Maximum calls started per minute; must be positive.
            chunk_size: Number of inputs gathered together per chunk in
                ``process_inputs``; must be at least 1.
            max_concurrency: Maximum number of calls in flight; must be at
                least 1 (a zero-valued semaphore would block forever).

        Raises:
            ValueError: if any argument is out of range.
        """
        if chunk_size < 1:
            raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")
        if max_concurrency < 1:
            raise ValueError(f"max_concurrency must be >= 1, got {max_concurrency}")

        self._rpm = self._validated_rpm(rpm)
        self._max_concurrency = max_concurrency
        self._chunk_size = chunk_size

        self._lock = asyncio.Lock()
        self._state = RateLimiterState(next_request_time=0.0)
        self._semaphore = asyncio.Semaphore(self._max_concurrency)

    @staticmethod
    def _validated_rpm(rpm: float) -> float:
        """Return ``rpm`` if it is positive, else raise ValueError.

        ``rpm == 0`` would divide by zero in ``process_input``. A negative rpm
        would yield a negative interval and silently disable limiting.
        ``not rpm > 0`` also rejects NaN.
        """
        if not rpm > 0:
            raise ValueError(f"rpm must be positive, got {rpm}")
        return rpm

    async def process_input(
        self,
        func_partial: Callable[[QueryT], Coroutine[Any, Any, QueryR]],
        inp: QueryT,
    ) -> QueryR:
        """Await ``func_partial(inp)`` once a concurrency slot and a time slot are free."""
        async with self._semaphore:
            # The lock is held across the sleep, so waiting callers reserve
            # successive start times one at a time.
            async with self._lock:
                now = monotonic()
                if now < self._state.next_request_time:
                    await asyncio.sleep(self._state.next_request_time - now)
                self._state.next_request_time = monotonic() + 1.01 * 60.0 / self._rpm
            result = await func_partial(inp)

        return result

    async def process_inputs(
        self,
        func_partial: Callable[[QueryT], Coroutine[Any, Any, QueryR]],
        inputs: list[QueryT],
        no_tqdm: bool = False,
    ) -> list[QueryR]:
        """Apply ``func_partial`` to every input, one chunk at a time.

        Each chunk of ``chunk_size`` inputs is gathered concurrently (subject to
        the rate limit and concurrency cap). A chunk finishes before the next
        one starts. Results are returned in input order.
        """
        results: list[QueryR] = []
        for i in tqdm(
            range(0, len(inputs), self._chunk_size),
            disable=no_tqdm,
            desc="Processing chunks",
        ):
            chunk = inputs[i : i + self._chunk_size]
            corouts = [
                self.process_input(func_partial=func_partial, inp=inp) for inp in chunk
            ]
            chunk_results = await asyncio.gather(*corouts)
            results.extend(chunk_results)

        return results

    @property
    def rpm(self) -> float:
        """Maximum calls started per minute."""
        return self._rpm

    @rpm.setter
    def rpm(self, value: float) -> None:
        self._rpm = self._validated_rpm(value)

    @property
    def max_concurrency(self) -> int:
        """Maximum number of calls in flight (fixed at construction)."""
        return self._max_concurrency

    @property
    def chunk_size(self) -> int:
        """Number of inputs gathered together per chunk."""
        return self._chunk_size

    @property
    def state(self) -> RateLimiterState:
        """Mutable scheduling state (next allowed start time)."""
        return self._state
101
+
102
+
103
@overload
def limit_rate(
    call: RetrievalCallableSingle[QueryT, QueryP, QueryR],
    rate_limiter: RateLimiterC[QueryT, QueryR] | None = None,
) -> RetrievalCallableSingle[QueryT, QueryP, QueryR]: ...


@overload
def limit_rate(
    call: None = None,
    rate_limiter: RateLimiterC[QueryT, QueryR] | None = None,
) -> RateLimDecoratorWithArgsSingle[QueryT, QueryP, QueryR]: ...


def limit_rate(
    call: RetrievalCallableSingle[QueryT, QueryP, QueryR] | None = None,
    rate_limiter: RateLimiterC[QueryT, QueryR] | None = None,
) -> (
    RetrievalCallableSingle[QueryT, QueryP, QueryR]
    | RateLimDecoratorWithArgsSingle[QueryT, QueryP, QueryR]
):
    """Rate-limit an async retrieval function or method that takes one input.

    Works both as ``@limit_rate`` and as ``@limit_rate(rate_limiter=...)``.
    The limiter is chosen per call: the explicit ``rate_limiter`` if given,
    otherwise the ``rate_limiter`` attribute of the method's owner (if any).
    Without a limiter the call runs directly.
    """
    if call is None:
        # Used with arguments only: return a decorator still waiting for `call`.
        return functools.partial(limit_rate, rate_limiter=rate_limiter)

    @functools.wraps(call)  # type: ignore
    async def wrapper(*args: Any, **kwargs: Any) -> QueryR:
        owner, single_input, extra_args = split_pos_args(call, args)  # type: ignore
        bound_call = partial_retrieval_callable(call, owner, *extra_args, **kwargs)

        limiter = (
            rate_limiter
            if rate_limiter is not None
            else getattr(owner, "rate_limiter", None)
        )

        if limiter is not None:
            return await limiter.process_input(func_partial=bound_call, inp=single_input)
        return await bound_call(single_input)

    return wrapper
142
+
143
+
144
@overload
def limit_rate_chunked(
    call: RetrievalCallableSingle[QueryT, QueryP, QueryR],
    rate_limiter: RateLimiterC[QueryT, QueryR] | None = None,
    no_tqdm: bool | None = None,
) -> RetrievalCallableList[QueryT, QueryP, QueryR]: ...


@overload
def limit_rate_chunked(
    call: None = None,
    rate_limiter: RateLimiterC[QueryT, QueryR] | None = None,
    no_tqdm: bool | None = None,
) -> RateLimDecoratorWithArgsList[QueryT, QueryP, QueryR]: ...


def limit_rate_chunked(
    call: RetrievalCallableSingle[QueryT, QueryP, QueryR] | None = None,
    rate_limiter: RateLimiterC[QueryT, QueryR] | None = None,
    no_tqdm: bool | None = None,
) -> (
    RetrievalCallableList[QueryT, QueryP, QueryR]
    | RateLimDecoratorWithArgsList[QueryT, QueryP, QueryR]
):
    """Turn a single-input async retrieval callable into one over many inputs.

    Works both as ``@limit_rate_chunked`` and with keyword arguments. At call
    time, ``rate_limiter`` and ``no_tqdm`` fall back to the same-named
    attributes of the method's owner (``None`` and ``False`` respectively when
    absent).

    * With a limiter, inputs go through ``RateLimiterC.process_inputs``.
    * Without one, they are all gathered with a progress bar.
    """
    if call is None:
        # Used with arguments only: return a decorator still waiting for `call`.
        return functools.partial(
            limit_rate_chunked, rate_limiter=rate_limiter, no_tqdm=no_tqdm
        )

    @functools.wraps(call)  # type: ignore
    async def wrapper(*args: Any, **kwargs: Any) -> list[QueryR]:
        assert call is not None

        owner, batch, extra_args = split_pos_args(call, args)  # type: ignore
        per_item_call = partial_retrieval_callable(call, owner, *extra_args, **kwargs)

        hide_progress = (
            no_tqdm if no_tqdm is not None else getattr(owner, "no_tqdm", False)
        )
        limiter = (
            rate_limiter
            if rate_limiter is not None
            else getattr(owner, "rate_limiter", None)
        )

        if limiter is not None:
            return await limiter.process_inputs(
                func_partial=per_item_call, inputs=batch, no_tqdm=hide_progress
            )
        return await asyncio_gather_with_pbar(
            *[per_item_call(item) for item in batch], no_tqdm=hide_progress
        )

    return wrapper
@@ -0,0 +1,57 @@
1
+ from collections.abc import Callable, Coroutine
2
+ from dataclasses import dataclass
3
+ from typing import (
4
+ Any,
5
+ Concatenate,
6
+ ParamSpec,
7
+ TypeAlias,
8
+ TypeVar,
9
+ )
10
+
11
# Upper bound constant for requests per minute (same value as 1e10).
MAX_RPM = 10_000_000_000.0


@dataclass
class RateLimiterState:
    """Mutable scheduling state of a rate limiter."""

    # Earliest time at which the next request may start. RateLimiterC compares
    # it with ``time.monotonic()``.
    next_request_time: float = 0.0
17
+
18
+
19
+ QueryT = TypeVar("QueryT")
20
+ QueryR = TypeVar("QueryR")
21
+ QueryP = ParamSpec("QueryP")
22
+
23
+ RetrievalFuncSingle: TypeAlias = Callable[
24
+ Concatenate[QueryT, QueryP], Coroutine[Any, Any, QueryR]
25
+ ]
26
+ RetrievalFuncList: TypeAlias = Callable[
27
+ Concatenate[list[QueryT], QueryP], Coroutine[Any, Any, list[QueryR]]
28
+ ]
29
+
30
+ RetrievalMethodSingle: TypeAlias = Callable[
31
+ Concatenate[Any, QueryT, QueryP], Coroutine[Any, Any, QueryR]
32
+ ]
33
+ RetrievalMethodList: TypeAlias = Callable[
34
+ Concatenate[Any, list[QueryT], QueryP], Coroutine[Any, Any, list[QueryR]]
35
+ ]
36
+
37
+ RetrievalCallableSingle: TypeAlias = (
38
+ RetrievalFuncSingle[QueryT, QueryP, QueryR]
39
+ | RetrievalMethodSingle[QueryT, QueryP, QueryR]
40
+ )
41
+
42
+ RetrievalCallableList: TypeAlias = (
43
+ RetrievalFuncList[QueryT, QueryP, QueryR]
44
+ | RetrievalMethodList[QueryT, QueryP, QueryR]
45
+ )
46
+
47
+
48
+ RateLimDecoratorWithArgsSingle = Callable[
49
+ [RetrievalCallableSingle[QueryT, QueryP, QueryR]],
50
+ RetrievalCallableSingle[QueryT, QueryP, QueryR],
51
+ ]
52
+
53
+
54
+ RateLimDecoratorWithArgsList = Callable[
55
+ [RetrievalCallableList[QueryT, QueryP, QueryR]],
56
+ RetrievalCallableList[QueryT, QueryP, QueryR],
57
+ ]
@@ -0,0 +1,57 @@
1
+ import inspect
2
+ from collections.abc import Callable, Coroutine, Sequence
3
+ from typing import (
4
+ Any,
5
+ )
6
+
7
+ from .types import (
8
+ QueryP,
9
+ QueryR,
10
+ QueryT,
11
+ RetrievalCallableList,
12
+ RetrievalCallableSingle,
13
+ )
14
+
15
+
16
def is_bound_method(func: Callable[..., Any], self_candidate: Any) -> bool:
    """Return True if ``self_candidate`` should be treated as the ``self`` of ``func``.

    This holds in either of two cases:

    * ``func`` is already a method bound to ``self_candidate``;
    * ``func`` is defined under its ``__name__`` on ``type(self_candidate)`` or
      one of its base classes. Decorator layers that set ``__wrapped__`` (as
      ``functools.wraps`` does) are followed.

    A plain ``hasattr(self_candidate, func.__name__)`` check is deliberately
    avoided. It misfires for standalone functions whose first argument merely
    has a same-named attribute, e.g. a function ``count`` called with a list.
    """
    if inspect.ismethod(func) and func.__self__ is self_candidate:
        return True

    name = getattr(func, "__name__", None)
    if name is None:
        return False

    for klass in type(self_candidate).__mro__:
        attr = klass.__dict__.get(name)
        # Walk the decorator chain so a class attribute that wraps `func`
        # (such as a rate-limiting wrapper) still counts as a match.
        seen: set[int] = set()
        while attr is not None and id(attr) not in seen:
            if attr is func:
                return True
            seen.add(id(attr))
            attr = getattr(attr, "__wrapped__", None)
    return False
18
+
19
+
20
def split_pos_args(
    call: (RetrievalCallableSingle[QueryT, QueryP, QueryR] | RetrievalCallableList[QueryT, QueryP, QueryR]),
    args: Sequence[Any],
) -> tuple[Any | None, QueryT | list[QueryT], Sequence[Any]]:
    """Split positional arguments into ``(self_or_None, input, remaining_args)``.

    If ``is_bound_method(call, args[0])`` holds, ``args`` is treated as
    ``(self, inp, *rest)``. Otherwise it is treated as ``(inp, *rest)`` and
    ``None`` is returned in place of ``self``.

    Raises:
        ValueError: if ``args`` is empty, or if a method call lacks an input
            after ``self``.
    """
    if not args:
        raise ValueError("No positional arguments passed.")
    maybe_self = args[0]
    if is_bound_method(call, maybe_self):
        # Case: bound instance method with signature (self, inp, *rest)
        if len(args) < 2:
            raise ValueError(
                "Must pass at least `self` and an input (or a list of inputs) "
                "for a bound instance method."
            )
        return maybe_self, args[1], args[2:]
    # Case: standalone function with signature (inp, *rest). `args` is known
    # to be non-empty here, so no further length check is needed.
    return None, args[0], args[1:]
38
+
39
+
40
def partial_retrieval_callable(
    call: Callable[..., Coroutine[Any, Any, QueryR]],
    self_obj: Any,
    *args: QueryP.args,
    **kwargs: QueryP.kwargs,
) -> Callable[[QueryT], Coroutine[Any, Any, QueryR]]:
    """Fix every argument of ``call`` except the single input.

    The returned coroutine function awaits one of:

    * ``call(self_obj, inp, *args, **kwargs)`` when ``self_obj`` is not None
      (``call`` is a method);
    * ``call(inp, *args, **kwargs)`` otherwise (``call`` is a plain function).
    """
    leading: tuple[Any, ...] = () if self_obj is None else (self_obj,)

    async def _bound_call(inp: QueryT) -> QueryR:
        return await call(*leading, inp, *args, **kwargs)

    return _bound_call
54
+
55
+
56
def expected_exec_time_from_max_concurrency_and_rpm(rpm: float, max_concurrency: int) -> float:
    """Return ``60 / (rpm / max_concurrency)``.

    This is presumably the number of seconds each of ``max_concurrency``
    concurrent slots has per request at ``rpm`` requests per minute; confirm
    against callers. Raises ZeroDivisionError if ``rpm`` or
    ``max_concurrency`` is zero.
    """
    rpm_per_slot = rpm / max_concurrency
    return 60.0 / rpm_per_slot
@@ -0,0 +1,36 @@
1
+ import logging
2
+ from logging import Formatter, LogRecord
3
+ from pathlib import Path
4
+
5
+ import yaml
6
+ from termcolor import colored
7
+ from termcolor._types import Color # type: ignore[import]
8
+
9
+
10
class ColorFormatter(Formatter):
    """Formatter that wraps the formatted text with ``termcolor.colored``.

    Coloring applies only when the record carries a truthy ``color``
    attribute (for example, one supplied through ``extra=`` when logging).
    """

    def format(self, record: LogRecord) -> str:
        text = super().format(record)
        color: Color | None = getattr(record, "color", None)
        return colored(text, color) if color else text
17
+
18
+
19
def setup_logging(logs_file_path: str | Path, logs_config_path: str | Path) -> None:
    """Configure logging from a YAML ``dictConfig`` file and colorize output.

    Args:
        logs_file_path: Destination of the log file. Its parent directories are
            created, and the path is written into the config as
            ``handlers.fileHandler.filename``.
        logs_config_path: YAML file in ``logging.config.dictConfig`` format. It
            must define a handler named ``fileHandler``.

    Raises:
        KeyError: if the config has no ``handlers.fileHandler`` entry.

    Every root handler that has a formatter gets a ``ColorFormatter`` with the
    same format string and date format.
    """
    # `import logging` does not load the `logging.config` submodule, so import
    # it explicitly to guarantee `logging.config.dictConfig` exists.
    import logging.config

    logs_file_path = Path(logs_file_path)
    logs_file_path.parent.mkdir(exist_ok=True, parents=True)
    with Path(logs_config_path).open(encoding="utf-8") as f:
        config = yaml.safe_load(f)

    config["handlers"]["fileHandler"]["filename"] = logs_file_path

    logging.config.dictConfig(config)

    root = logging.getLogger()
    for handler in root.handlers:
        formatter = handler.formatter
        if formatter is not None:
            fmt_str = formatter._fmt  # noqa: SLF001
            # Carry over the configured date format as well; previously it was
            # dropped when the formatter was replaced.
            handler.setFormatter(ColorFormatter(fmt_str, datefmt=formatter.datefmt))


logger = logging.getLogger(__name__)