strapi-kit 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- strapi_kit/__init__.py +97 -0
- strapi_kit/__version__.py +15 -0
- strapi_kit/_version.py +34 -0
- strapi_kit/auth/__init__.py +7 -0
- strapi_kit/auth/api_token.py +48 -0
- strapi_kit/cache/__init__.py +5 -0
- strapi_kit/cache/schema_cache.py +211 -0
- strapi_kit/client/__init__.py +11 -0
- strapi_kit/client/async_client.py +1032 -0
- strapi_kit/client/base.py +460 -0
- strapi_kit/client/sync_client.py +980 -0
- strapi_kit/config_provider.py +368 -0
- strapi_kit/exceptions/__init__.py +37 -0
- strapi_kit/exceptions/errors.py +205 -0
- strapi_kit/export/__init__.py +10 -0
- strapi_kit/export/exporter.py +384 -0
- strapi_kit/export/importer.py +619 -0
- strapi_kit/export/media_handler.py +322 -0
- strapi_kit/export/relation_resolver.py +172 -0
- strapi_kit/models/__init__.py +104 -0
- strapi_kit/models/bulk.py +69 -0
- strapi_kit/models/config.py +174 -0
- strapi_kit/models/enums.py +97 -0
- strapi_kit/models/export_format.py +166 -0
- strapi_kit/models/import_options.py +142 -0
- strapi_kit/models/request/__init__.py +1 -0
- strapi_kit/models/request/fields.py +65 -0
- strapi_kit/models/request/filters.py +611 -0
- strapi_kit/models/request/pagination.py +168 -0
- strapi_kit/models/request/populate.py +281 -0
- strapi_kit/models/request/query.py +429 -0
- strapi_kit/models/request/sort.py +147 -0
- strapi_kit/models/response/__init__.py +1 -0
- strapi_kit/models/response/base.py +75 -0
- strapi_kit/models/response/component.py +67 -0
- strapi_kit/models/response/media.py +91 -0
- strapi_kit/models/response/meta.py +44 -0
- strapi_kit/models/response/normalized.py +168 -0
- strapi_kit/models/response/relation.py +48 -0
- strapi_kit/models/response/v4.py +70 -0
- strapi_kit/models/response/v5.py +57 -0
- strapi_kit/models/schema.py +93 -0
- strapi_kit/operations/__init__.py +16 -0
- strapi_kit/operations/media.py +226 -0
- strapi_kit/operations/streaming.py +144 -0
- strapi_kit/parsers/__init__.py +5 -0
- strapi_kit/parsers/version_detecting.py +171 -0
- strapi_kit/protocols.py +455 -0
- strapi_kit/utils/__init__.py +15 -0
- strapi_kit/utils/rate_limiter.py +201 -0
- strapi_kit/utils/uid.py +88 -0
- strapi_kit-0.0.1.dist-info/METADATA +1098 -0
- strapi_kit-0.0.1.dist-info/RECORD +55 -0
- strapi_kit-0.0.1.dist-info/WHEEL +4 -0
- strapi_kit-0.0.1.dist-info/licenses/LICENSE +21 -0
strapi_kit/protocols.py
ADDED
|
@@ -0,0 +1,455 @@
|
|
|
1
|
+
"""Protocol definitions for dependency injection.
|
|
2
|
+
|
|
3
|
+
This module defines interfaces for core components, enabling:
|
|
4
|
+
- Dependency injection
|
|
5
|
+
- Easy mocking in tests
|
|
6
|
+
- Loose coupling between components
|
|
7
|
+
- Custom implementations
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from typing import TYPE_CHECKING, Any, Literal, Protocol, runtime_checkable
|
|
11
|
+
|
|
12
|
+
import httpx
|
|
13
|
+
|
|
14
|
+
from .models.response.normalized import (
|
|
15
|
+
NormalizedCollectionResponse,
|
|
16
|
+
NormalizedSingleResponse,
|
|
17
|
+
)
|
|
18
|
+
|
|
19
|
+
if TYPE_CHECKING:
|
|
20
|
+
from .models.schema import ContentTypeSchema
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
@runtime_checkable
class AuthProvider(Protocol):
    """Structural interface for objects that authenticate Strapi requests.

    Any object exposing ``get_headers`` and ``validate_token`` with the
    signatures below satisfies this protocol; no inheritance is required.
    """

    def get_headers(self) -> dict[str, str]:
        """Build the headers that authenticate an outgoing HTTP request.

        Returns:
            Header-name to header-value mapping, for example
            ``{"Authorization": "Bearer <token>"}``.
        """
        ...

    def validate_token(self) -> bool:
        """Report whether the configured credentials are valid.

        Returns:
            ``True`` for valid credentials, ``False`` otherwise.
        """
        ...
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
@runtime_checkable
class HTTPClient(Protocol):
    """Structural interface for a blocking (synchronous) HTTP transport.

    The method set mirrors the subset of ``httpx.Client`` that the SDK uses:
    generic JSON requests, multipart POSTs, streaming, and shutdown.
    """

    def request(
        self,
        method: str,
        url: str,
        *,
        params: dict[str, Any] | None = None,
        json: dict[str, Any] | None = None,
        headers: dict[str, str] | None = None,
    ) -> httpx.Response:
        """Send a single HTTP request and return its response.

        Args:
            method: HTTP verb such as GET, POST, PUT or DELETE.
            url: Absolute URL of the target resource.
            params: Query-string parameters, if any.
            json: Body to serialize as JSON, if any.
            headers: Extra request headers, if any.

        Returns:
            The received ``httpx.Response``.
        """
        ...

    def post(
        self,
        url: str,
        *,
        files: dict[str, Any] | None = None,
        data: dict[str, Any] | None = None,
        headers: dict[str, str] | None = None,
    ) -> httpx.Response:
        """Send a multipart POST request (used for file uploads).

        Args:
            url: Absolute URL of the target resource.
            files: File parts of the multipart body.
            data: Plain form fields of the multipart body.
            headers: Extra request headers, if any.

        Returns:
            The received ``httpx.Response``.
        """
        ...

    def stream(
        self,
        method: str,
        url: str,
        **kwargs: Any,
    ) -> Any:
        """Open a streaming HTTP request.

        Args:
            method: HTTP verb.
            url: Absolute URL of the target resource.
            **kwargs: Any further request options.

        Returns:
            A context manager that yields the streaming response.
        """
        ...

    def close(self) -> None:
        """Release the transport's connections and other resources."""
        ...
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
@runtime_checkable
class AsyncHTTPClient(Protocol):
    """Structural interface for an asyncio-based HTTP transport.

    Same surface as the synchronous transport protocol, except that
    ``request``, ``post`` and ``aclose`` are coroutines.
    """

    async def request(
        self,
        method: str,
        url: str,
        *,
        params: dict[str, Any] | None = None,
        json: dict[str, Any] | None = None,
        headers: dict[str, str] | None = None,
    ) -> httpx.Response:
        """Send a single HTTP request asynchronously.

        Args:
            method: HTTP verb such as GET, POST, PUT or DELETE.
            url: Absolute URL of the target resource.
            params: Query-string parameters, if any.
            json: Body to serialize as JSON, if any.
            headers: Extra request headers, if any.

        Returns:
            The received ``httpx.Response``.
        """
        ...

    async def post(
        self,
        url: str,
        *,
        files: dict[str, Any] | None = None,
        data: dict[str, Any] | None = None,
        headers: dict[str, str] | None = None,
    ) -> httpx.Response:
        """Send a multipart POST request asynchronously (file uploads).

        Args:
            url: Absolute URL of the target resource.
            files: File parts of the multipart body.
            data: Plain form fields of the multipart body.
            headers: Extra request headers, if any.

        Returns:
            The received ``httpx.Response``.
        """
        ...

    def stream(
        self,
        method: str,
        url: str,
        **kwargs: Any,
    ) -> Any:
        """Open a streaming HTTP request.

        This is deliberately a plain ``def``: like ``httpx.AsyncClient.stream``,
        it returns an async context manager directly, which callers enter with
        ``async with`` instead of awaiting the call itself.

        Args:
            method: HTTP verb.
            url: Absolute URL of the target resource.
            **kwargs: Any further request options.

        Returns:
            An async context manager that yields the streaming response.
        """
        ...

    async def aclose(self) -> None:
        """Release the transport's connections and other resources."""
        ...
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
@runtime_checkable
class ResponseParser(Protocol):
    """Structural interface for turning raw Strapi JSON into normalized models.

    Implementations hide the differences between Strapi response shapes
    behind two entry points: one for single entities, one for collections.
    """

    def parse_single(self, response_data: dict[str, Any]) -> NormalizedSingleResponse:
        """Normalize a response that carries a single entity.

        Args:
            response_data: Decoded JSON body returned by Strapi.

        Returns:
            The normalized single-entity response.
        """
        ...

    def parse_collection(self, response_data: dict[str, Any]) -> NormalizedCollectionResponse:
        """Normalize a response that carries a list of entities.

        Args:
            response_data: Decoded JSON body returned by Strapi.

        Returns:
            The normalized collection response.
        """
        ...
|
|
227
|
+
|
|
228
|
+
|
|
229
|
+
@runtime_checkable
class ConfigProvider(Protocol):
    """Structural interface for sources of client configuration.

    Decouples the client from where settings come from (environment, files,
    databases, ...) while keeping the accessors typed.
    """

    # --- connection identity -------------------------------------------------

    def get_base_url(self) -> str:
        """Return the Strapi instance's base URL, without a trailing slash."""
        ...

    def get_api_token(self) -> str:
        """Return the API token used to authenticate requests."""
        ...

    # --- protocol / transport tuning -----------------------------------------

    @property
    def api_version(self) -> Literal["v4", "v5", "auto"]:
        """Configured Strapi API version: ``"v4"``, ``"v5"`` or ``"auto"``."""
        ...

    @property
    def timeout(self) -> float:
        """Per-request timeout, in seconds."""
        ...

    @property
    def max_connections(self) -> int:
        """Upper bound on concurrent connections."""
        ...

    @property
    def verify_ssl(self) -> bool:
        """Whether TLS certificates are verified."""
        ...

    @property
    def retry(self) -> Any:
        """Retry configuration object."""
        ...
|
|
298
|
+
|
|
299
|
+
|
|
300
|
+
@runtime_checkable
class StrapiClient(Protocol):
    """Structural interface shared by the synchronous and asynchronous clients.

    Lets the export/import modules accept any client implementation in a
    type-safe way. Signatures are written in their synchronous form; an async
    client exposes the same parameters with awaitable methods.
    """

    @property
    def base_url(self) -> str:
        """Base URL of the Strapi instance."""
        ...

    @property
    def api_version(self) -> str | None:
        """API version, either configured or detected."""
        ...

    # --- reads ---------------------------------------------------------------

    def get_one(
        self,
        endpoint: str,
        query: Any = None,
        headers: dict[str, str] | None = None,
    ) -> NormalizedSingleResponse:
        """Fetch one entity.

        Args:
            endpoint: API endpoint path.
            query: Optional query configuration.
            headers: Extra request headers.

        Returns:
            The normalized single-entity response.
        """
        ...

    def get_many(
        self,
        endpoint: str,
        query: Any = None,
        headers: dict[str, str] | None = None,
    ) -> NormalizedCollectionResponse:
        """Fetch a collection of entities.

        Args:
            endpoint: API endpoint path.
            query: Optional query configuration.
            headers: Extra request headers.

        Returns:
            The normalized collection response.
        """
        ...

    # --- writes --------------------------------------------------------------

    def create(
        self,
        endpoint: str,
        data: dict[str, Any],
        query: Any = None,
        headers: dict[str, str] | None = None,
    ) -> NormalizedSingleResponse:
        """Create an entity.

        Args:
            endpoint: API endpoint path.
            data: Field values for the new entity.
            query: Optional query configuration.
            headers: Extra request headers.

        Returns:
            The normalized single-entity response.
        """
        ...

    def update(
        self,
        endpoint: str,
        data: dict[str, Any],
        query: Any = None,
        headers: dict[str, str] | None = None,
    ) -> NormalizedSingleResponse:
        """Modify an existing entity.

        Args:
            endpoint: API endpoint path.
            data: Field values to change.
            query: Optional query configuration.
            headers: Extra request headers.

        Returns:
            The normalized single-entity response.
        """
        ...

    def remove(
        self,
        endpoint: str,
        headers: dict[str, str] | None = None,
    ) -> NormalizedSingleResponse:
        """Delete an entity.

        Args:
            endpoint: API endpoint path.
            headers: Extra request headers.

        Returns:
            The normalized single-entity response.
        """
        ...
|
|
412
|
+
|
|
413
|
+
|
|
414
|
+
@runtime_checkable
class SchemaProvider(Protocol):
    """Structural interface for looking up and caching content-type schemas.

    Export/import code relies on these schemas to resolve relations
    correctly.
    """

    def get_schema(self, content_type: str) -> "ContentTypeSchema":
        """Return the schema of a content type.

        Args:
            content_type: Content-type UID, e.g. ``"api::article.article"``.

        Returns:
            The content type's schema.
        """
        ...

    def cache_schema(self, content_type: str, schema: "ContentTypeSchema") -> None:
        """Store a schema under its content-type UID.

        Args:
            content_type: Content-type UID.
            schema: The schema to store.
        """
        ...

    def has_schema(self, content_type: str) -> bool:
        """Tell whether a schema is already cached.

        Args:
            content_type: Content-type UID.

        Returns:
            ``True`` when a schema is cached for the UID, ``False`` otherwise.
        """
        ...

    def clear_cache(self) -> None:
        """Drop every cached schema."""
        ...
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
"""Utility modules for strapi-kit.
|
|
2
|
+
|
|
3
|
+
This package contains helper utilities including:
|
|
4
|
+
- Rate limiting
|
|
5
|
+
- UID handling
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from strapi_kit.utils.rate_limiter import AsyncTokenBucketRateLimiter, TokenBucketRateLimiter
|
|
9
|
+
from strapi_kit.utils.uid import uid_to_endpoint
|
|
10
|
+
|
|
11
|
+
__all__ = [
|
|
12
|
+
"TokenBucketRateLimiter",
|
|
13
|
+
"AsyncTokenBucketRateLimiter",
|
|
14
|
+
"uid_to_endpoint",
|
|
15
|
+
]
|
|
@@ -0,0 +1,201 @@
|
|
|
1
|
+
"""Rate limiting utilities using token bucket algorithm.
|
|
2
|
+
|
|
3
|
+
Provides both synchronous and asynchronous rate limiters to control
|
|
4
|
+
request rates when communicating with the Strapi API.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import asyncio
|
|
8
|
+
import logging
|
|
9
|
+
import threading
|
|
10
|
+
import time
|
|
11
|
+
|
|
12
|
+
logger = logging.getLogger(__name__)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class TokenBucketRateLimiter:
|
|
16
|
+
"""Synchronous rate limiter using token bucket algorithm.
|
|
17
|
+
|
|
18
|
+
The token bucket algorithm allows bursting up to the bucket capacity
|
|
19
|
+
while maintaining the specified average rate over time.
|
|
20
|
+
|
|
21
|
+
Example:
|
|
22
|
+
>>> limiter = TokenBucketRateLimiter(rate=5.0) # 5 requests per second
|
|
23
|
+
>>> for _ in range(10):
|
|
24
|
+
... limiter.acquire()
|
|
25
|
+
... # make_request()
|
|
26
|
+
"""
|
|
27
|
+
|
|
28
|
+
def __init__(
|
|
29
|
+
self,
|
|
30
|
+
rate: float,
|
|
31
|
+
capacity: float | None = None,
|
|
32
|
+
) -> None:
|
|
33
|
+
"""Initialize the rate limiter.
|
|
34
|
+
|
|
35
|
+
Args:
|
|
36
|
+
rate: Maximum requests per second
|
|
37
|
+
capacity: Bucket capacity (defaults to rate for 1 second burst)
|
|
38
|
+
"""
|
|
39
|
+
if rate <= 0:
|
|
40
|
+
raise ValueError("Rate must be positive")
|
|
41
|
+
|
|
42
|
+
self._rate = rate
|
|
43
|
+
self._capacity = capacity if capacity is not None else rate
|
|
44
|
+
self._tokens = self._capacity
|
|
45
|
+
self._last_update = time.monotonic()
|
|
46
|
+
self._lock = threading.Lock()
|
|
47
|
+
|
|
48
|
+
logger.debug(f"Rate limiter initialized: {rate}/s, capacity: {self._capacity}")
|
|
49
|
+
|
|
50
|
+
def acquire(self, tokens: float = 1.0, blocking: bool = True) -> bool:
|
|
51
|
+
"""Acquire tokens from the bucket.
|
|
52
|
+
|
|
53
|
+
Args:
|
|
54
|
+
tokens: Number of tokens to acquire (default: 1)
|
|
55
|
+
blocking: Whether to block until tokens are available (default: True)
|
|
56
|
+
|
|
57
|
+
Returns:
|
|
58
|
+
True if tokens were acquired, False if non-blocking and not available
|
|
59
|
+
"""
|
|
60
|
+
with self._lock:
|
|
61
|
+
while True:
|
|
62
|
+
self._refill()
|
|
63
|
+
|
|
64
|
+
if self._tokens >= tokens:
|
|
65
|
+
self._tokens -= tokens
|
|
66
|
+
return True
|
|
67
|
+
|
|
68
|
+
if not blocking:
|
|
69
|
+
return False
|
|
70
|
+
|
|
71
|
+
# Calculate wait time until enough tokens are available
|
|
72
|
+
tokens_needed = tokens - self._tokens
|
|
73
|
+
wait_time = tokens_needed / self._rate
|
|
74
|
+
|
|
75
|
+
# Release lock while sleeping
|
|
76
|
+
self._lock.release()
|
|
77
|
+
try:
|
|
78
|
+
time.sleep(wait_time)
|
|
79
|
+
finally:
|
|
80
|
+
self._lock.acquire()
|
|
81
|
+
|
|
82
|
+
def _refill(self) -> None:
|
|
83
|
+
"""Refill tokens based on elapsed time."""
|
|
84
|
+
now = time.monotonic()
|
|
85
|
+
elapsed = now - self._last_update
|
|
86
|
+
self._last_update = now
|
|
87
|
+
|
|
88
|
+
# Add tokens based on elapsed time
|
|
89
|
+
self._tokens = min(self._capacity, self._tokens + elapsed * self._rate)
|
|
90
|
+
|
|
91
|
+
@property
|
|
92
|
+
def available_tokens(self) -> float:
|
|
93
|
+
"""Get current available tokens."""
|
|
94
|
+
with self._lock:
|
|
95
|
+
self._refill()
|
|
96
|
+
return self._tokens
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
class AsyncTokenBucketRateLimiter:
|
|
100
|
+
"""Asynchronous rate limiter using token bucket algorithm.
|
|
101
|
+
|
|
102
|
+
The token bucket algorithm allows bursting up to the bucket capacity
|
|
103
|
+
while maintaining the specified average rate over time.
|
|
104
|
+
|
|
105
|
+
Example:
|
|
106
|
+
>>> limiter = AsyncTokenBucketRateLimiter(rate=5.0) # 5 requests per second
|
|
107
|
+
>>> for _ in range(10):
|
|
108
|
+
... await limiter.acquire()
|
|
109
|
+
... # await make_request()
|
|
110
|
+
"""
|
|
111
|
+
|
|
112
|
+
def __init__(
|
|
113
|
+
self,
|
|
114
|
+
rate: float,
|
|
115
|
+
capacity: float | None = None,
|
|
116
|
+
) -> None:
|
|
117
|
+
"""Initialize the rate limiter.
|
|
118
|
+
|
|
119
|
+
Args:
|
|
120
|
+
rate: Maximum requests per second
|
|
121
|
+
capacity: Bucket capacity (defaults to rate for 1 second burst)
|
|
122
|
+
"""
|
|
123
|
+
if rate <= 0:
|
|
124
|
+
raise ValueError("Rate must be positive")
|
|
125
|
+
|
|
126
|
+
self._rate = rate
|
|
127
|
+
self._capacity = capacity if capacity is not None else rate
|
|
128
|
+
self._tokens = self._capacity
|
|
129
|
+
self._last_update = time.monotonic()
|
|
130
|
+
self._lock = asyncio.Lock()
|
|
131
|
+
|
|
132
|
+
logger.debug(f"Async rate limiter initialized: {rate}/s, capacity: {self._capacity}")
|
|
133
|
+
|
|
134
|
+
async def acquire(self, tokens: float = 1.0, blocking: bool = True) -> bool:
|
|
135
|
+
"""Acquire tokens from the bucket.
|
|
136
|
+
|
|
137
|
+
Args:
|
|
138
|
+
tokens: Number of tokens to acquire (default: 1)
|
|
139
|
+
blocking: Whether to block until tokens are available (default: True)
|
|
140
|
+
|
|
141
|
+
Returns:
|
|
142
|
+
True if tokens were acquired, False if non-blocking and not available
|
|
143
|
+
"""
|
|
144
|
+
async with self._lock:
|
|
145
|
+
while True:
|
|
146
|
+
self._refill()
|
|
147
|
+
|
|
148
|
+
if self._tokens >= tokens:
|
|
149
|
+
self._tokens -= tokens
|
|
150
|
+
return True
|
|
151
|
+
|
|
152
|
+
if not blocking:
|
|
153
|
+
return False
|
|
154
|
+
|
|
155
|
+
# Calculate wait time until enough tokens are available
|
|
156
|
+
tokens_needed = tokens - self._tokens
|
|
157
|
+
wait_time = tokens_needed / self._rate
|
|
158
|
+
|
|
159
|
+
# Release lock while sleeping
|
|
160
|
+
self._lock.release()
|
|
161
|
+
try:
|
|
162
|
+
await asyncio.sleep(wait_time)
|
|
163
|
+
finally:
|
|
164
|
+
await self._lock.acquire()
|
|
165
|
+
|
|
166
|
+
def _refill(self) -> None:
|
|
167
|
+
"""Refill tokens based on elapsed time."""
|
|
168
|
+
now = time.monotonic()
|
|
169
|
+
elapsed = now - self._last_update
|
|
170
|
+
self._last_update = now
|
|
171
|
+
|
|
172
|
+
# Add tokens based on elapsed time
|
|
173
|
+
self._tokens = min(self._capacity, self._tokens + elapsed * self._rate)
|
|
174
|
+
|
|
175
|
+
@property
|
|
176
|
+
def available_tokens(self) -> float:
|
|
177
|
+
"""Get current available tokens (requires holding the lock)."""
|
|
178
|
+
# Note: This is not thread-safe without holding the lock
|
|
179
|
+
self._refill()
|
|
180
|
+
return self._tokens
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
def create_rate_limiter(
    rate_per_second: float | None,
    async_mode: bool = False,
) -> TokenBucketRateLimiter | AsyncTokenBucketRateLimiter | None:
    """Build the rate limiter that matches the client's execution model.

    Args:
        rate_per_second: Requests-per-second limit; ``None`` disables limiting.
        async_mode: Pick the asyncio-based limiter instead of the threaded one.

    Returns:
        A new limiter configured with ``rate_per_second``, or ``None`` when
        rate limiting is disabled.
    """
    if rate_per_second is not None:
        limiter_cls = AsyncTokenBucketRateLimiter if async_mode else TokenBucketRateLimiter
        return limiter_cls(rate=rate_per_second)
    return None
|