krons 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kronos/__init__.py +0 -0
- kronos/core/__init__.py +145 -0
- kronos/core/broadcaster.py +116 -0
- kronos/core/element.py +225 -0
- kronos/core/event.py +316 -0
- kronos/core/eventbus.py +116 -0
- kronos/core/flow.py +356 -0
- kronos/core/graph.py +442 -0
- kronos/core/node.py +982 -0
- kronos/core/pile.py +575 -0
- kronos/core/processor.py +494 -0
- kronos/core/progression.py +296 -0
- kronos/enforcement/__init__.py +57 -0
- kronos/enforcement/common/__init__.py +34 -0
- kronos/enforcement/common/boolean.py +85 -0
- kronos/enforcement/common/choice.py +97 -0
- kronos/enforcement/common/mapping.py +118 -0
- kronos/enforcement/common/model.py +102 -0
- kronos/enforcement/common/number.py +98 -0
- kronos/enforcement/common/string.py +140 -0
- kronos/enforcement/context.py +129 -0
- kronos/enforcement/policy.py +80 -0
- kronos/enforcement/registry.py +153 -0
- kronos/enforcement/rule.py +312 -0
- kronos/enforcement/service.py +370 -0
- kronos/enforcement/validator.py +198 -0
- kronos/errors.py +146 -0
- kronos/operations/__init__.py +32 -0
- kronos/operations/builder.py +228 -0
- kronos/operations/flow.py +398 -0
- kronos/operations/node.py +101 -0
- kronos/operations/registry.py +92 -0
- kronos/protocols.py +414 -0
- kronos/py.typed +0 -0
- kronos/services/__init__.py +81 -0
- kronos/services/backend.py +286 -0
- kronos/services/endpoint.py +608 -0
- kronos/services/hook.py +471 -0
- kronos/services/imodel.py +465 -0
- kronos/services/registry.py +115 -0
- kronos/services/utilities/__init__.py +36 -0
- kronos/services/utilities/header_factory.py +87 -0
- kronos/services/utilities/rate_limited_executor.py +271 -0
- kronos/services/utilities/rate_limiter.py +180 -0
- kronos/services/utilities/resilience.py +414 -0
- kronos/session/__init__.py +41 -0
- kronos/session/exchange.py +258 -0
- kronos/session/message.py +60 -0
- kronos/session/session.py +411 -0
- kronos/specs/__init__.py +25 -0
- kronos/specs/adapters/__init__.py +0 -0
- kronos/specs/adapters/_utils.py +45 -0
- kronos/specs/adapters/dataclass_field.py +246 -0
- kronos/specs/adapters/factory.py +56 -0
- kronos/specs/adapters/pydantic_adapter.py +309 -0
- kronos/specs/adapters/sql_ddl.py +946 -0
- kronos/specs/catalog/__init__.py +36 -0
- kronos/specs/catalog/_audit.py +39 -0
- kronos/specs/catalog/_common.py +43 -0
- kronos/specs/catalog/_content.py +59 -0
- kronos/specs/catalog/_enforcement.py +70 -0
- kronos/specs/factory.py +120 -0
- kronos/specs/operable.py +314 -0
- kronos/specs/phrase.py +405 -0
- kronos/specs/protocol.py +140 -0
- kronos/specs/spec.py +506 -0
- kronos/types/__init__.py +60 -0
- kronos/types/_sentinel.py +311 -0
- kronos/types/base.py +369 -0
- kronos/types/db_types.py +260 -0
- kronos/types/identity.py +66 -0
- kronos/utils/__init__.py +40 -0
- kronos/utils/_hash.py +234 -0
- kronos/utils/_json_dump.py +392 -0
- kronos/utils/_lazy_init.py +63 -0
- kronos/utils/_to_list.py +165 -0
- kronos/utils/_to_num.py +85 -0
- kronos/utils/_utils.py +375 -0
- kronos/utils/concurrency/__init__.py +205 -0
- kronos/utils/concurrency/_async_call.py +333 -0
- kronos/utils/concurrency/_cancel.py +122 -0
- kronos/utils/concurrency/_errors.py +96 -0
- kronos/utils/concurrency/_patterns.py +363 -0
- kronos/utils/concurrency/_primitives.py +328 -0
- kronos/utils/concurrency/_priority_queue.py +135 -0
- kronos/utils/concurrency/_resource_tracker.py +110 -0
- kronos/utils/concurrency/_run_async.py +67 -0
- kronos/utils/concurrency/_task.py +95 -0
- kronos/utils/concurrency/_utils.py +79 -0
- kronos/utils/fuzzy/__init__.py +14 -0
- kronos/utils/fuzzy/_extract_json.py +90 -0
- kronos/utils/fuzzy/_fuzzy_json.py +288 -0
- kronos/utils/fuzzy/_fuzzy_match.py +149 -0
- kronos/utils/fuzzy/_string_similarity.py +187 -0
- kronos/utils/fuzzy/_to_dict.py +396 -0
- kronos/utils/sql/__init__.py +13 -0
- kronos/utils/sql/_sql_validation.py +142 -0
- krons-0.1.0.dist-info/METADATA +70 -0
- krons-0.1.0.dist-info/RECORD +101 -0
- krons-0.1.0.dist-info/WHEEL +4 -0
- krons-0.1.0.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,87 @@
|
|
|
1
|
+
# Copyright (c) 2025 - 2026, HaiyangLi <quantocean.li at gmail dot com>
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
"""HTTP header construction utilities for API authentication."""
|
|
5
|
+
|
|
6
|
+
from typing import Literal
|
|
7
|
+
|
|
8
|
+
from pydantic import SecretStr
|
|
9
|
+
|
|
10
|
+
# Authentication schemes accepted by HeaderFactory.get_header().
AUTH_TYPES = Literal["bearer", "x-api-key", "none"]
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class HeaderFactory:
    """Factory for constructing HTTP headers with various auth schemes.

    Supports Bearer token, x-api-key, and no-auth patterns.
    Handles SecretStr unwrapping and validation automatically.

    Example:
        >>> headers = HeaderFactory.get_header("bearer", api_key="sk-xxx")
        >>> headers
        {'Content-Type': 'application/json', 'Authorization': 'Bearer sk-xxx'}
    """

    @staticmethod
    def get_content_type_header(
        content_type: str = "application/json",
    ) -> dict[str, str]:
        """Return a dict carrying only the Content-Type header."""
        return {"Content-Type": content_type}

    @staticmethod
    def get_bearer_auth_header(api_key: str) -> dict[str, str]:
        """Return an Authorization header using the Bearer scheme."""
        return {"Authorization": f"Bearer {api_key}"}

    @staticmethod
    def get_x_api_key_header(api_key: str) -> dict[str, str]:
        """Return an x-api-key header for providers requiring that scheme."""
        return {"x-api-key": api_key}

    @staticmethod
    def get_header(
        auth_type: AUTH_TYPES,
        content_type: str | None = "application/json",
        api_key: str | SecretStr | None = None,
        default_headers: dict[str, str] | None = None,
    ) -> dict[str, str]:
        """Construct complete HTTP headers for API requests.

        Args:
            auth_type: Authentication scheme ("bearer", "x-api-key", "none").
            content_type: Content-Type value (None to omit).
            api_key: API key (str or SecretStr, required unless auth_type="none").
            default_headers: Additional headers merged in last (may override).

        Returns:
            Complete header dict ready for an HTTP client.

        Raises:
            ValueError: If api_key is missing/empty when auth is required,
                or if auth_type is not a supported scheme.
        """
        headers: dict[str, str] = (
            {}
            if content_type is None
            else HeaderFactory.get_content_type_header(content_type)
        )

        if auth_type != "none":
            # Unwrap SecretStr transparently so callers can pass either form.
            if isinstance(api_key, SecretStr):
                api_key = api_key.get_secret_value()

            if not api_key or not str(api_key).strip():
                raise ValueError("API key is required for authentication")

            api_key = api_key.strip()

            if auth_type == "bearer":
                headers.update(HeaderFactory.get_bearer_auth_header(api_key))
            elif auth_type == "x-api-key":
                headers.update(HeaderFactory.get_x_api_key_header(api_key))
            else:
                raise ValueError(f"Unsupported auth type: {auth_type}")

        if default_headers:
            headers.update(default_headers)
        return headers
|
|
@@ -0,0 +1,271 @@
|
|
|
1
|
+
# Copyright (c) 2025 - 2026, HaiyangLi <quantocean.li at gmail dot com>
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
"""Rate-limited execution infrastructure with dual token bucket support.
|
|
5
|
+
|
|
6
|
+
Provides permission-based rate limiting for API calls with separate
|
|
7
|
+
request count and token usage limits, plus atomic rollback on partial acquire.
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from __future__ import annotations
|
|
11
|
+
|
|
12
|
+
import logging
|
|
13
|
+
from typing import TYPE_CHECKING, Any, Self
|
|
14
|
+
|
|
15
|
+
from typing_extensions import override
|
|
16
|
+
|
|
17
|
+
from kronos.core import Event, Executor, Processor
|
|
18
|
+
from kronos.services.endpoint import APICalling
|
|
19
|
+
from kronos.utils.concurrency import get_cancelled_exc_class, sleep
|
|
20
|
+
|
|
21
|
+
from .rate_limiter import TokenBucket
|
|
22
|
+
|
|
23
|
+
if TYPE_CHECKING:
|
|
24
|
+
import asyncio
|
|
25
|
+
|
|
26
|
+
from kronos.core import Pile
|
|
27
|
+
|
|
28
|
+
# Public names exported by this module.
__all__ = ("RateLimitedExecutor", "RateLimitedProcessor")

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class RateLimitedProcessor(Processor):
    """Processor with dual token bucket rate limiting (requests + tokens).

    Enforces both request count and token usage limits atomically.
    Automatically rolls back the request bucket if the token bucket
    acquire fails, so a denied event never consumes request quota.

    Example:
        >>> req_bucket = TokenBucket(RateLimitConfig(capacity=100, refill_rate=1.67))
        >>> tok_bucket = TokenBucket(RateLimitConfig(capacity=100000, refill_rate=1667))
        >>> processor = await RateLimitedProcessor.create(
        ...     queue_capacity=50, capacity_refresh_time=60.0,
        ...     request_bucket=req_bucket, token_bucket=tok_bucket
        ... )
    """

    event_type = APICalling

    def __init__(
        self,
        queue_capacity: int,
        capacity_refresh_time: float,
        pile: Pile[Event] | None = None,
        executor: Executor | None = None,
        request_bucket: TokenBucket | None = None,
        token_bucket: TokenBucket | None = None,
        replenishment_interval: float = 60.0,
        concurrency_limit: int = 100,
        max_queue_size: int = 1000,
        max_denial_tracking: int = 10000,
    ) -> None:
        """Initialize rate-limited processor.

        Args:
            queue_capacity: Max events per batch.
            capacity_refresh_time: Batch refresh interval (seconds).
            pile: Reference to executor's Flow.items (set by executor).
            executor: Reference to executor for progression updates.
            request_bucket: TokenBucket for request rate limiting (None = unlimited).
            token_bucket: TokenBucket for token rate limiting (None = unlimited).
            replenishment_interval: Seconds between full bucket resets.
            concurrency_limit: Max concurrent executions.
            max_queue_size: Max queue size.
            max_denial_tracking: Max denial entries to track.
        """
        super().__init__(  # type: ignore[arg-type]
            queue_capacity=queue_capacity,
            capacity_refresh_time=capacity_refresh_time,
            pile=pile,  # type: ignore[arg-type]
            executor=executor,
            concurrency_limit=concurrency_limit,
            max_queue_size=max_queue_size,
            max_denial_tracking=max_denial_tracking,
        )

        self.request_bucket = request_bucket
        self.token_bucket = token_bucket
        self.replenishment_interval = replenishment_interval
        self.concurrency_limit = concurrency_limit
        # Handle to the background replenisher; also read/written by
        # RateLimitedExecutor.start() to avoid spawning duplicates.
        self._replenisher_task: asyncio.Task[None] | None = None

    async def start_replenishing(self) -> None:
        """Background task: periodically reset rate limit buckets to full capacity."""
        await self.start()

        try:
            while not self.is_stopped():
                await sleep(self.replenishment_interval)

                if self.request_bucket:
                    await self.request_bucket.reset()
                    logger.debug(
                        "Request bucket replenished: %d requests",
                        self.request_bucket.capacity,
                    )

                if self.token_bucket:
                    await self.token_bucket.reset()
                    logger.debug(
                        "Token bucket replenished: %d tokens",
                        self.token_bucket.capacity,
                    )

        except get_cancelled_exc_class():
            logger.info("Rate limit replenisher task cancelled.")

    @override
    @classmethod
    async def create(  # type: ignore[override]
        cls,
        queue_capacity: int,
        capacity_refresh_time: float,
        pile: Pile[Event] | None = None,
        executor: Executor | None = None,
        request_bucket: TokenBucket | None = None,
        token_bucket: TokenBucket | None = None,
        replenishment_interval: float = 60.0,
        concurrency_limit: int = 100,
        max_queue_size: int = 1000,
        max_denial_tracking: int = 10000,
    ) -> Self:
        """Factory: create processor and start the background replenishment task."""
        self = cls(
            queue_capacity=queue_capacity,
            capacity_refresh_time=capacity_refresh_time,
            pile=pile,
            executor=executor,
            request_bucket=request_bucket,
            token_bucket=token_bucket,
            replenishment_interval=replenishment_interval,
            concurrency_limit=concurrency_limit,
            max_queue_size=max_queue_size,
            max_denial_tracking=max_denial_tracking,
        )

        # Runtime import: the module-level asyncio import is TYPE_CHECKING-only.
        import asyncio

        # NOTE(review): spawns via asyncio directly while the rest of the
        # package routes through kronos.utils.concurrency -- confirm asyncio
        # is the only supported backend.
        self._replenisher_task = asyncio.create_task(self.start_replenishing())

        return self

    @override
    async def stop(self) -> None:
        """Stop processor and cancel the background replenishment task."""
        if self._replenisher_task:
            self._replenisher_task.cancel()
            try:
                await self._replenisher_task
            except get_cancelled_exc_class():
                pass
            # Clear the handle so a later start() (see RateLimitedExecutor.start,
            # which checks `not processor._replenisher_task`) can spawn a fresh
            # replenisher instead of seeing a stale, cancelled task.
            self._replenisher_task = None

        await super().stop()

    @override
    async def request_permission(
        self,
        required_tokens: int | None = None,
        **kwargs: Any,
    ) -> bool:
        """Check rate limits and acquire capacity atomically.

        Acquires from the request bucket first, then the token bucket. If the
        token bucket acquire fails, the request-bucket acquisition is rolled
        back so a denied call consumes no quota.

        Args:
            required_tokens: Token count for this request (None = skip token check).
            **kwargs: Ignored (accepted for interface compatibility).

        Returns:
            True if permitted, False if rate limited.
        """
        # No buckets configured: rate limiting is effectively disabled.
        if self.request_bucket is None and self.token_bucket is None:
            return True

        request_acquired = False
        if self.request_bucket:
            request_acquired = await self.request_bucket.try_acquire(tokens=1)
            if not request_acquired:
                logger.debug("Request rate limit exceeded")
                return False

        if self.token_bucket and required_tokens:
            token_acquired = await self.token_bucket.try_acquire(tokens=required_tokens)
            if not token_acquired:
                # Roll back the request slot so this denial is quota-neutral.
                if request_acquired and self.request_bucket:
                    await self.request_bucket.release(tokens=1)

                logger.debug(
                    f"Token rate limit exceeded (required: {required_tokens}, "
                    f"available: {self.token_bucket.tokens:.0f})"
                )
                return False

        return True

    def to_dict(self) -> dict[str, Any]:
        """Serialize processor config to dict (excludes runtime state)."""
        return {
            "queue_capacity": self.queue_capacity,
            "capacity_refresh_time": self.capacity_refresh_time,
            "replenishment_interval": self.replenishment_interval,
            "concurrency_limit": self.concurrency_limit,
            "max_queue_size": self.max_queue_size,
            "max_denial_tracking": self.max_denial_tracking,
            "request_bucket": (self.request_bucket.to_dict() if self.request_bucket else None),
            "token_bucket": (self.token_bucket.to_dict() if self.token_bucket else None),
        }
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
class RateLimitedExecutor(Executor):
    """Executor with integrated rate limiting via RateLimitedProcessor.

    Manages processor lifecycle and forwards events for permission checking.

    Example:
        >>> executor = RateLimitedExecutor(processor_config={
        ...     "queue_capacity": 50,
        ...     "capacity_refresh_time": 60.0,
        ...     "request_bucket": req_bucket,
        ...     "token_bucket": tok_bucket,
        ... })
        >>> await executor.start()
    """

    processor_type = RateLimitedProcessor

    def __init__(
        self,
        processor_config: dict[str, Any] | None = None,
        strict_event_type: bool = False,
        name: str | None = None,
    ) -> None:
        """Initialize rate-limited executor.

        Args:
            processor_config: Config dict for RateLimitedProcessor.create().
            strict_event_type: If True, Flow enforces exact type matching.
            name: Optional name for the executor Flow (defaults to
                "rate_limited_executor").
        """
        super().__init__(
            processor_config=processor_config,
            strict_event_type=strict_event_type,
            name=name or "rate_limited_executor",
        )

    @override
    async def start(self) -> None:
        """Start executor and spawn the replenishment task if not already running."""
        await super().start()

        processor = self.processor
        if not processor or not isinstance(processor, RateLimitedProcessor):
            return
        if processor._replenisher_task:
            # A replenisher is already attached; do not spawn a duplicate.
            return

        import asyncio

        processor._replenisher_task = asyncio.create_task(
            processor.start_replenishing()
        )
|
|
@@ -0,0 +1,180 @@
|
|
|
1
|
+
# Copyright (c) 2025 - 2026, HaiyangLi <quantocean.li at gmail dot com>
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
"""Token bucket rate limiter for controlling request/token throughput."""
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
import logging
|
|
9
|
+
from dataclasses import dataclass
|
|
10
|
+
|
|
11
|
+
from kronos.utils.concurrency import Lock, current_time, sleep
|
|
12
|
+
|
|
13
|
+
# Public names exported by this module.
__all__ = ("RateLimitConfig", "TokenBucket")

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@dataclass(frozen=True, slots=True)
|
|
19
|
+
class RateLimitConfig:
|
|
20
|
+
"""Immutable configuration for TokenBucket rate limiter.
|
|
21
|
+
|
|
22
|
+
Args:
|
|
23
|
+
capacity: Maximum tokens the bucket can hold.
|
|
24
|
+
refill_rate: Tokens added per second.
|
|
25
|
+
initial_tokens: Starting tokens (defaults to capacity).
|
|
26
|
+
|
|
27
|
+
Example:
|
|
28
|
+
>>> config = RateLimitConfig(capacity=100, refill_rate=10.0)
|
|
29
|
+
>>> bucket = TokenBucket(config)
|
|
30
|
+
"""
|
|
31
|
+
|
|
32
|
+
capacity: int
|
|
33
|
+
refill_rate: float
|
|
34
|
+
initial_tokens: int | None = None
|
|
35
|
+
|
|
36
|
+
def __post_init__(self):
|
|
37
|
+
"""Validate configuration parameters."""
|
|
38
|
+
if self.capacity <= 0:
|
|
39
|
+
raise ValueError("capacity must be > 0")
|
|
40
|
+
if self.refill_rate <= 0:
|
|
41
|
+
raise ValueError("refill_rate must be > 0")
|
|
42
|
+
if self.initial_tokens is None:
|
|
43
|
+
object.__setattr__(self, "initial_tokens", self.capacity)
|
|
44
|
+
elif self.initial_tokens < 0:
|
|
45
|
+
raise ValueError("initial_tokens must be >= 0")
|
|
46
|
+
elif self.initial_tokens > self.capacity:
|
|
47
|
+
raise ValueError(
|
|
48
|
+
f"initial_tokens ({self.initial_tokens}) cannot exceed capacity ({self.capacity})"
|
|
49
|
+
)
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
class TokenBucket:
    """Token bucket rate limiter with continuous, time-based refill.

    Tokens are consumed by acquire()/try_acquire() and restored at
    refill_rate per second based on elapsed wall-clock time. All state
    mutation happens under an async lock.

    Example:
        >>> config = RateLimitConfig(capacity=100, refill_rate=10.0)
        >>> bucket = TokenBucket(config)
        >>> if await bucket.try_acquire(5):
        ...     # proceed with rate-limited operation
        ...     pass

    Attributes:
        capacity: Maximum tokens the bucket can hold.
        refill_rate: Tokens added per second.
        tokens: Currently available tokens (float to allow partial refills).
    """

    def __init__(self, config: RateLimitConfig):
        """Copy config values into mutable runtime state."""
        self.capacity = config.capacity
        self.refill_rate = config.refill_rate
        # RateLimitConfig.__post_init__ guarantees this is populated.
        assert config.initial_tokens is not None
        self.tokens = float(config.initial_tokens)
        self.last_refill = current_time()
        self._lock = Lock()

    def _refill(self) -> None:
        """Refill tokens based on elapsed time (must be called under the lock)."""
        now = current_time()
        gained = (now - self.last_refill) * self.refill_rate
        self.tokens = min(self.capacity, self.tokens + gained)
        self.last_refill = now

    async def acquire(self, tokens: int = 1, *, timeout: float | None = None) -> bool:
        """Acquire N tokens, sleeping until they become available.

        Args:
            tokens: Number of tokens to acquire.
            timeout: Max wait time in seconds (None = wait forever).

        Returns:
            True if acquired, False if the timeout elapsed first.

        Raises:
            ValueError: If tokens <= 0 or tokens > capacity.
        """
        if tokens <= 0:
            raise ValueError("tokens must be > 0")
        if tokens > self.capacity:
            raise ValueError(
                f"Cannot acquire {tokens} tokens: exceeds bucket capacity {self.capacity}"
            )

        start_time = current_time()

        while True:
            async with self._lock:
                self._refill()
                if self.tokens >= tokens:
                    self.tokens -= tokens
                    logger.debug(f"Acquired {tokens} tokens, {self.tokens:.2f} remaining")
                    return True
                # Not enough yet: estimate how long the refill will take.
                deficit = tokens - self.tokens
                wait_time = deficit / self.refill_rate

            # Sleep outside the lock so other callers can make progress.
            if timeout is not None:
                elapsed = current_time() - start_time
                if elapsed + wait_time > timeout:
                    logger.warning(f"Rate limit timeout after {elapsed:.2f}s")
                    return False
                wait_time = min(wait_time, timeout - elapsed)

            logger.debug(f"Waiting {wait_time:.2f}s for {deficit:.2f} tokens")
            await sleep(wait_time)

    async def try_acquire(self, tokens: int = 1) -> bool:
        """Try to acquire tokens without waiting.

        Returns:
            True if acquired immediately, False if insufficient tokens.

        Raises:
            ValueError: If tokens <= 0.
        """
        if tokens <= 0:
            raise ValueError("tokens must be > 0")
        async with self._lock:
            self._refill()
            if self.tokens < tokens:
                return False
            self.tokens -= tokens
            return True

    async def reset(self) -> None:
        """Reset bucket to full capacity (thread-safe).

        Used by RateLimitedProcessor for interval-based replenishment.
        """
        async with self._lock:
            self.tokens = float(self.capacity)
            self.last_refill = current_time()

    async def release(self, tokens: int = 1) -> None:
        """Return tokens to the bucket, clamped at capacity (thread-safe).

        Used for rollback when a dual-bucket acquire fails partway through.

        Args:
            tokens: Number of tokens to release back.

        Raises:
            ValueError: If tokens <= 0.
        """
        if tokens <= 0:
            raise ValueError("tokens must be > 0")
        async with self._lock:
            self.tokens = min(self.capacity, self.tokens + tokens)

    def to_dict(self) -> dict[str, float]:
        """Serialize configuration to a dict (runtime state excluded)."""
        return {"capacity": self.capacity, "refill_rate": self.refill_rate}
|