queue-max 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- queue_max/__init__.py +62 -0
- queue_max/cli.py +373 -0
- queue_max/contrib/__init__.py +7 -0
- queue_max/contrib/django/__init__.py +61 -0
- queue_max/contrib/django/management/__init__.py +0 -0
- queue_max/contrib/django/management/commands/__init__.py +0 -0
- queue_max/contrib/django/management/commands/queue_purge.py +19 -0
- queue_max/contrib/django/management/commands/queue_stats.py +39 -0
- queue_max/contrib/django/management/commands/queue_worker.py +69 -0
- queue_max/contrib/fastapi/__init__.py +117 -0
- queue_max/contrib/flask/__init__.py +99 -0
- queue_max/core/__init__.py +16 -0
- queue_max/core/circuit_breaker.py +162 -0
- queue_max/core/database.py +253 -0
- queue_max/core/decorator.py +346 -0
- queue_max/core/queue.py +420 -0
- queue_max/core/rate_limiter.py +214 -0
- queue_max/core/worker.py +426 -0
- queue_max/exceptions.py +25 -0
- queue_max/models/__init__.py +5 -0
- queue_max/models/job.py +340 -0
- queue_max/py.typed +0 -0
- queue_max/utils/__init__.py +23 -0
- queue_max/utils/helpers.py +156 -0
- queue_max-0.1.0.dist-info/METADATA +233 -0
- queue_max-0.1.0.dist-info/RECORD +30 -0
- queue_max-0.1.0.dist-info/WHEEL +5 -0
- queue_max-0.1.0.dist-info/entry_points.txt +2 -0
- queue_max-0.1.0.dist-info/licenses/LICENSE +21 -0
- queue_max-0.1.0.dist-info/top_level.txt +1 -0
queue_max/core/queue.py
ADDED
|
@@ -0,0 +1,420 @@
|
|
|
1
|
+
"""Main Queue class for Queue Max.
|
|
2
|
+
|
|
3
|
+
Provides the primary API for enqueuing, processing, and managing jobs
|
|
4
|
+
with support for sharding, rate limiting, circuit breaker, and monitoring.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import logging
|
|
8
|
+
import os
|
|
9
|
+
import random
|
|
10
|
+
import threading
|
|
11
|
+
import time
|
|
12
|
+
from contextlib import contextmanager
|
|
13
|
+
from typing import Any, Callable, Dict, Generator, List, Optional
|
|
14
|
+
|
|
15
|
+
from queue_max.core.circuit_breaker import CircuitBreaker
|
|
16
|
+
from queue_max.core.database import DATA_DIR, NUM_SHARDS, ShardManager
|
|
17
|
+
from queue_max.core.rate_limiter import RateLimiter
|
|
18
|
+
from queue_max.exceptions import QueueError
|
|
19
|
+
from queue_max.models.job import Job
|
|
20
|
+
from queue_max.utils.helpers import (
|
|
21
|
+
determine_shard,
|
|
22
|
+
get_env_int,
|
|
23
|
+
validate_payload,
|
|
24
|
+
validate_priority,
|
|
25
|
+
)
|
|
26
|
+
|
|
27
|
+
logger = logging.getLogger("queue_max")
|
|
28
|
+
|
|
29
|
+
class Queue:
|
|
30
|
+
"""Main queue for managing and processing jobs.
|
|
31
|
+
|
|
32
|
+
Provides a persistent, sharded task queue backed by SQLite
|
|
33
|
+
with rate limiting, circuit breaker, and automatic recovery.
|
|
34
|
+
|
|
35
|
+
Example:
|
|
36
|
+
>>> queue = Queue(rate_limit=100)
|
|
37
|
+
>>> queue.enqueue({"task": "send_email", "to": "user@example.com"})
|
|
38
|
+
>>> stats = queue.get_stats()
|
|
39
|
+
"""
|
|
40
|
+
|
|
41
|
+
def __init__(
    self,
    shards: Optional[int] = None,
    rate_limit: Optional[int] = None,
    max_retries: Optional[int] = None,
    data_dir: Optional[str] = None,
    circuit_breaker_threshold: Optional[int] = None,
    circuit_breaker_timeout: Optional[float] = None,
):
    """Initialize the queue and the components it delegates to.

    Args:
        shards: Number of shards (default: NUM_SHARDS env var, else the
            package-level NUM_SHARDS constant).
        rate_limit: Limit handed to RateLimiter (default: RATE_LIMIT_MAX
            env var or 160).
        max_retries: Default maximum retry attempts per job (default:
            QUEUE_MAX_RETRIES env var or 3). An explicit 0 is honored.
        data_dir: Directory for shard files (default: DATA_DIR env var,
            else the package-level DATA_DIR constant).
        circuit_breaker_threshold: Failures before circuit opens (default: 5).
        circuit_breaker_timeout: Seconds before recovery attempt (default: 60).
    """
    # A falsy shard count or rate limit (None or 0) falls back to the
    # environment/default value.
    self.num_shards = shards or get_env_int("NUM_SHARDS", NUM_SHARDS)
    effective_rate_limit = rate_limit or get_env_int("RATE_LIMIT_MAX", 160)
    # Only None means "not supplied": max_retries=0 is a legitimate request
    # for "never retry" and must not be replaced by the default.
    if max_retries is None:
        max_retries = get_env_int("QUEUE_MAX_RETRIES", 3)
    self.max_retries = max_retries
    self.data_dir = data_dir or os.environ.get("DATA_DIR", DATA_DIR)

    self.shard_manager = ShardManager(self.num_shards, self.data_dir)
    self.rate_limiter = RateLimiter(effective_rate_limit)
    self.circuit_breaker = CircuitBreaker(
        failure_threshold=circuit_breaker_threshold or 5,
        recovery_timeout=circuit_breaker_timeout or 60.0,
    )
    # Wall-clock creation time; get_stats() derives uptime from it.
    self._start_time = time.time()
    # Serializes the shard scan in pop_job() across threads.
    self._pop_lock = threading.Lock()

    # Listener registry: event name -> callbacks registered through on().
    self._events: Dict[str, List[Callable]] = {
        "job_enqueued": [],
        "job_completed": [],
        "job_failed": [],
        "job_retried": [],
        "alert": [],
    }
|
|
81
|
+
|
|
82
|
+
@property
def is_healthy(self) -> bool:
    """Report whether the circuit breaker is in any state other than "open".

    Any error raised while reading the breaker state counts as unhealthy.
    """
    try:
        breaker_state = self.circuit_breaker.state
        healthy = breaker_state.value != "open"
    except Exception:
        healthy = False
    return healthy
|
|
89
|
+
|
|
90
|
+
def on(self, event: str, callback: Callable) -> "Queue":
    """Subscribe a callback to one of the queue's events.

    Args:
        event: Event name; must be a key of the internal registry
            ('job_enqueued', 'job_completed', 'job_failed',
            'job_retried', 'alert').
        callback: Callable invoked with the event data as keyword arguments.

    Returns:
        The queue itself, so registrations can be chained.

    Raises:
        ValueError: If the event name is not registered.
    """
    listeners = self._events.get(event)
    if listeners is None:
        raise ValueError(f"Unknown event: {event}. Valid events: {list(self._events.keys())}")
    listeners.append(callback)
    return self
|
|
104
|
+
|
|
105
|
+
def _emit(self, event: str, **data: Any) -> None:
    """Call every listener registered for ``event`` with ``data`` as kwargs.

    A listener that raises is logged and skipped so the remaining listeners
    still run. An event name missing from the registry raises KeyError.
    """
    handlers = self._events[event]
    for handler in handlers:
        try:
            handler(**data)
        except Exception:
            logger.exception(f"Error in event handler for {event}")
|
|
112
|
+
|
|
113
|
+
@contextmanager
def batch(self) -> Generator[None, None, None]:
    """Context manager that silences all event emission while active.

    The instance's ``_emit`` is swapped for a no-op for the duration of the
    ``with`` body and restored afterwards, even if the body raises.
    """
    saved_emit = self._emit

    def _muted(event, **data):
        return None

    self._emit = _muted
    try:
        yield
    finally:
        self._emit = saved_emit
|
|
122
|
+
|
|
123
|
+
def enqueue(
    self,
    payload: Dict[str, Any],
    pagina_id: Optional[int] = None,
    priority: int = 0,
    max_retries: Optional[int] = None,
) -> Dict[str, Any]:
    """Enqueue a job.

    Args:
        payload: Job payload (JSON-serializable dict); checked by
            validate_payload.
        pagina_id: Optional ID passed to determine_shard for shard selection.
        priority: Priority (0=low, 1=medium, 2=high); checked by
            validate_priority.
        max_retries: Maximum retry attempts for this job. None uses the
            queue default; an explicit 0 is honored.

    Returns:
        Dict with 'id' (job ID) and 'shard_id' (assigned shard).

    Raises:
        ValueError: If payload is not a dict or priority is invalid
            (raised by the validators).
    """
    payload = validate_payload(payload)
    priority = validate_priority(priority)
    # Only None means "use the queue default"; `or` would turn an explicit
    # 0 (never retry) into the default.
    if max_retries is None:
        max_retries = self.max_retries

    shard_id = determine_shard(pagina_id, self.num_shards)
    job_id = self.shard_manager.insert_job(shard_id, payload, pagina_id, priority, max_retries)

    self._emit("job_enqueued", job_id=job_id, shard_id=shard_id)

    # Raise an alert when the target shard's backlog exceeds the threshold.
    # The threshold is re-read from the environment on every call.
    alert_threshold = get_env_int("QUEUE_ALERT_THRESHOLD", 1000)
    pending = self.shard_manager.get_stats(shard_id)["pending"]
    if pending > alert_threshold:
        self._emit("alert", type="QUEUE_SIZE", pending=pending, threshold=alert_threshold)

    return {"id": job_id, "shard_id": shard_id}
|
|
159
|
+
|
|
160
|
+
def enqueue_batch(self, jobs: List[Dict[str, Any]]) -> Dict[str, int]:
    """Enqueue several jobs with event emission suppressed.

    Every entry must provide 'payload'; 'pagina_id', 'priority' and
    'max_retries' are optional and default to None, 0 and None.

    Returns:
        Dict with 'total', the number of jobs enqueued.
    """
    count = 0
    with self.batch():
        for count, spec in enumerate(jobs, start=1):
            kwargs = {
                "payload": spec["payload"],
                "pagina_id": spec.get("pagina_id"),
                "priority": spec.get("priority", 0),
                "max_retries": spec.get("max_retries"),
            }
            self.enqueue(**kwargs)
    return {"total": count}
|
|
180
|
+
|
|
181
|
+
def enqueue_from_file(self, filepath: str, fmt: str = "jsonl") -> Dict[str, int]:
    """Enqueue jobs read from a JSON Lines or CSV file.

    JSON Lines: each non-blank line is a JSON object whose keys are passed
    to enqueue() as keyword arguments. CSV: each row becomes one job whose
    payload is every column except 'priority'; a blank or missing priority
    cell means priority 0.

    Args:
        filepath: Path to the file (read as UTF-8).
        fmt: File format ('jsonl' or 'csv').

    Returns:
        Dict with 'total' enqueued count.

    Raises:
        ValueError: If fmt is unsupported or a priority cell is not an integer.
    """
    import csv
    import json

    if fmt not in ("jsonl", "csv"):
        raise ValueError(f"Unsupported format: {fmt}")

    total = 0
    if fmt == "jsonl":
        # An explicit encoding keeps parsing independent of the host locale.
        with open(filepath, encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if line:
                    self.enqueue(**json.loads(line))
                    total += 1
    else:
        # newline="" is what the csv module requires for correct handling of
        # quoted fields containing line breaks; utf-8-sig also accepts files
        # that start with a byte-order mark.
        with open(filepath, encoding="utf-8-sig", newline="") as f:
            for row in csv.DictReader(f):
                payload = {k: v for k, v in row.items() if k != "priority"}
                # DictReader yields None for cells missing from a short row
                # and "" for empty cells; both mean "default priority".
                raw_priority = row.get("priority")
                if raw_priority is None or not raw_priority.strip():
                    priority = 0
                else:
                    priority = int(raw_priority)
                self.enqueue(payload=payload, priority=priority)
                total += 1
    return {"total": total}
|
|
213
|
+
|
|
214
|
+
def pop_job(self, worker_id: str) -> Optional[Job]:
    """Pop the next available job, scanning shards in random order.

    A rate-limiter token is acquired first (this may block) and the circuit
    breaker is probed with a no-op call; if either raises, no shard is
    touched and None is returned. The shard scan runs under the queue's pop
    lock. A shard that raises while popping is logged and skipped.

    Args:
        worker_id: Unique identifier for the calling worker.

    Returns:
        A Job if one is available, otherwise None.
    """
    # Any failure in either gate means "no job this time".
    # NOTE(review): the exception is swallowed without logging, so a caller
    # cannot tell a throttled or open-circuit pop from an empty queue.
    try:
        self.rate_limiter.acquire()
        self.circuit_breaker.call(lambda: None)
    except Exception:
        return None

    with self._pop_lock:
        candidates = list(range(self.num_shards))
        random.shuffle(candidates)

        for shard_id in candidates:
            try:
                popped = self.shard_manager.pop_job(shard_id, worker_id)
            except Exception:
                logger.exception(f"Error popping from shard {shard_id}")
                continue
            if popped is not None:
                return popped
    return None
|
|
247
|
+
|
|
248
|
+
def complete_job(self, job_id: int, shard_id: int) -> None:
    """Mark a job completed in its shard, then emit 'job_completed'."""
    manager = self.shard_manager
    manager.complete_job(shard_id, job_id)
    event_data = {"job_id": job_id, "shard_id": shard_id}
    self._emit("job_completed", **event_data)
|
|
252
|
+
|
|
253
|
+
def fail_job(
    self,
    job_id: int,
    shard_id: int,
    error: Exception,
    permanent: bool = False,
) -> None:
    """Record a job failure in its shard and emit the matching event.

    Args:
        job_id: ID of the job to fail.
        shard_id: Shard containing the job.
        error: The exception that caused the failure; its str() is sent
            to listeners.
        permanent: If True, emit 'job_failed'; otherwise emit 'job_retried'.
    """
    self.shard_manager.fail_job(shard_id, job_id, error, permanent=permanent)
    # The event name depends only on the caller's flag, not on what the
    # shard manager actually did with the job.
    event_name = "job_failed" if permanent else "job_retried"
    self._emit(event_name, job_id=job_id, shard_id=shard_id, error=str(error))
|
|
273
|
+
|
|
274
|
+
def retry_failed_jobs(self, shard_id: Optional[int] = None) -> int:
    """Ask the shard manager to re-queue failed jobs.

    Args:
        shard_id: Shard to process, or None to process every shard.

    Returns:
        Sum of the counts reported by the shard manager.
    """
    manager = self.shard_manager
    if shard_id is not None:
        return manager.retry_failed_jobs(shard_id)
    return sum(manager.retry_failed_jobs(sid) for sid in range(self.num_shards))
|
|
290
|
+
|
|
291
|
+
def cleanup_old_jobs(self, days: int = 7) -> int:
    """Run the shard manager's old-job cleanup on every shard.

    Args:
        days: Age threshold in days forwarded to each shard (default: 7).

    Returns:
        Sum of the per-shard removal counts.
    """
    manager = self.shard_manager
    return sum(manager.cleanup_old_jobs(sid, days) for sid in range(self.num_shards))
|
|
304
|
+
|
|
305
|
+
def purge_queue(self, status: Optional[str] = None) -> int:
    """Delete jobs from every shard's ``fila`` table.

    Args:
        status: Only delete jobs with exactly this status. None (and only
            None) deletes every job; an empty string is treated as a
            status value, so it can never wipe the whole queue by accident.

    Returns:
        Number of jobs purged across all shards.
    """
    total = 0
    for shard_id in range(self.num_shards):
        conn = self.shard_manager._get_connection(shard_id)
        # Test against None, not truthiness: status="" must not fall
        # through to the unconditional DELETE.
        if status is not None:
            cursor = conn.execute("DELETE FROM fila WHERE status=?", (status,))
        else:
            cursor = conn.execute("DELETE FROM fila")
        total += cursor.rowcount
        conn.commit()
    return total
|
|
324
|
+
|
|
325
|
+
def get_failed_jobs(self, limit: int = 100) -> "List[Job]":
    """Get failed jobs across all shards, highest job ID first.

    Args:
        limit: Maximum number of jobs to return in total. It is also
            forwarded to each shard as that shard's own cap. A negative
            value disables the overall cap.

    Returns:
        At most ``limit`` jobs, sorted by descending ``id``.
    """
    jobs: List[Job] = []
    for shard_id in range(self.num_shards):
        jobs.extend(self.shard_manager.get_failed_jobs(shard_id, limit))
    jobs.sort(key=lambda j: j.id, reverse=True)
    # Each shard may contribute up to `limit` rows, so the merged list has
    # to be trimmed to honor the limit for the queue as a whole.
    if limit >= 0:
        del jobs[limit:]
    return jobs
|
|
332
|
+
|
|
333
|
+
def get_processing_jobs(self) -> List[Job]:
    """Collect the processing jobs reported by every shard, in shard order."""
    manager = self.shard_manager
    return [
        job
        for shard_id in range(self.num_shards)
        for job in manager.get_processing_jobs(shard_id)
    ]
|
|
339
|
+
|
|
340
|
+
def get_pending_count(self) -> int:
    """Return the 'pending' entry of get_stats(), or 0 when it is absent."""
    stats = self.get_stats()
    return stats.get("pending", 0)
|
|
343
|
+
|
|
344
|
+
def is_empty(self) -> bool:
    """Return True when the pending-job count is zero."""
    pending = self.get_pending_count()
    return pending == 0
|
|
347
|
+
|
|
348
|
+
def heartbeat(self, shard_id: int, worker_id: str) -> None:
    """Forward a worker heartbeat for the given shard to the shard manager."""
    manager = self.shard_manager
    manager.heartbeat(shard_id, worker_id)
|
|
351
|
+
|
|
352
|
+
def recover_orphans(self) -> int:
    """Recover orphaned jobs (stuck in processing state) on every shard.

    The stuck threshold is read from the STUCK_TIMEOUT environment variable
    (default 30000) and passed unchanged to the shard manager.
    # NOTE(review): the unit of STUCK_TIMEOUT is not visible here; confirm
    # against ShardManager.recover_orphans.

    Returns:
        Sum of the per-shard recovery counts.
    """
    stuck_timeout = get_env_int("STUCK_TIMEOUT", 30000)
    manager = self.shard_manager
    return sum(
        manager.recover_orphans(shard_id, stuck_timeout)
        for shard_id in range(self.num_shards)
    )
|
|
363
|
+
|
|
364
|
+
def get_stats(self) -> Dict[str, Any]:
    """Build one statistics snapshot for the whole queue.

    Returns:
        The shard manager's aggregate stats, overlaid with rate-limiter,
        circuit-breaker, configuration, uptime and health entries. The
        overlay wins when a key already exists in the shard stats.
    """
    shard_stats = self.shard_manager.get_all_stats()
    rate_stats = self.rate_limiter.get_stats()
    cb_stats = self.circuit_breaker.get_stats()

    summary = {**shard_stats}
    summary["tokens_available"] = rate_stats.get("tokens_remaining", 0)
    summary["circuit_state"] = cb_stats.get("state", "unknown")
    summary["circuit_failures"] = cb_stats.get("failure_count", 0)
    summary["num_shards"] = self.num_shards
    summary["rate_limit"] = rate_stats.get("rate_limit", 0)
    summary["max_retries"] = self.max_retries
    # Seconds since the queue object was created, to two decimals.
    summary["uptime_seconds"] = round(time.time() - self._start_time, 2)
    summary["is_healthy"] = self.is_healthy
    return summary
|
|
384
|
+
|
|
385
|
+
def wait_until_empty(self, timeout: Optional[float] = None) -> bool:
    """Block until the queue has no pending jobs.

    Args:
        timeout: Maximum seconds to wait. None waits indefinitely; 0 checks
            once and returns immediately.

    Returns:
        True if queue became empty, False if timeout.
    """
    # Distinguish "no timeout" (None) from "zero timeout" (0), and use the
    # monotonic clock so wall-clock adjustments cannot stretch the wait.
    deadline = None if timeout is None else time.monotonic() + timeout
    while True:
        if self.is_empty():
            return True
        if deadline is None:
            time.sleep(1)
            continue
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            return False
        # Never sleep past the deadline.
        time.sleep(min(1, remaining))
|
|
398
|
+
|
|
399
|
+
def wait_for_jobs(self, count: int = 1, timeout: Optional[float] = None) -> bool:
    """Block until at least ``count`` jobs are pending.

    Args:
        count: Minimum number of pending jobs to wait for (default: 1).
        timeout: Maximum seconds to wait. None waits indefinitely; 0 checks
            once and returns immediately.

    Returns:
        True once enough jobs are pending, False if the timeout elapsed.
    """
    # Distinguish "no timeout" (None) from "zero timeout" (0), and use the
    # monotonic clock so wall-clock adjustments cannot stretch the wait.
    deadline = None if timeout is None else time.monotonic() + timeout
    while True:
        if self.get_pending_count() >= count:
            return True
        if deadline is None:
            time.sleep(0.5)
            continue
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            return False
        # Never sleep past the deadline.
        time.sleep(min(0.5, remaining))
|
|
408
|
+
|
|
409
|
+
def close(self) -> None:
    """Release the queue's resources by closing everything the shard manager holds."""
    manager = self.shard_manager
    manager.close_all()
|
|
412
|
+
|
|
413
|
+
def __enter__(self) -> "Queue":
|
|
414
|
+
return self
|
|
415
|
+
|
|
416
|
+
def __exit__(self, *args) -> None:
|
|
417
|
+
self.close()
|
|
418
|
+
|
|
419
|
+
def __repr__(self) -> str:
|
|
420
|
+
return f"Queue(shards={self.num_shards}, rate_limit={self.rate_limiter.rate_limit})"
|
|
@@ -0,0 +1,214 @@
|
|
|
1
|
+
"""Token bucket rate limiter for Queue Max.
|
|
2
|
+
|
|
3
|
+
Thread-safe rate limiter with configurable rate limits, burst capacity,
|
|
4
|
+
and adaptive rate adjustment.
|
|
5
|
+
|
|
6
|
+
Features:
|
|
7
|
+
- Per-minute, per-second, and per-hour limits
|
|
8
|
+
- Burst capacity control
|
|
9
|
+
- Configurable jitter
|
|
10
|
+
- Dynamic rate adjustment
|
|
11
|
+
- Pre-configured limiters for popular APIs
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
import threading
|
|
15
|
+
import time
|
|
16
|
+
from datetime import datetime, timezone
|
|
17
|
+
from enum import Enum
|
|
18
|
+
from typing import Dict, Optional
|
|
19
|
+
|
|
20
|
+
from queue_max.exceptions import RateLimitError
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class RateLimitUnit(Enum):
    """Time unit for rate limiting.

    RateLimiter compares its ``unit`` against these members to decide the
    length of the window that ``rate_limit`` is counted over.
    """

    # rate_limit is counted per 60 seconds; also the limiter's fallback branch.
    PER_MINUTE = "per_minute"
    # rate_limit is counted per second.
    PER_SECOND = "per_second"
    # rate_limit is counted per 3600 seconds.
    PER_HOUR = "per_hour"
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class RateLimiter:
|
|
31
|
+
"""Token bucket rate limiter.
|
|
32
|
+
|
|
33
|
+
Thread-safe rate limiter that distributes requests uniformly over time.
|
|
34
|
+
|
|
35
|
+
Attributes:
|
|
36
|
+
rate_limit: Maximum requests per unit.
|
|
37
|
+
unit: Time unit for rate limiting.
|
|
38
|
+
burst_capacity: Maximum tokens that can accumulate.
|
|
39
|
+
"""
|
|
40
|
+
|
|
41
|
+
def __init__(
    self,
    rate_limit: int = 160,
    unit: RateLimitUnit = RateLimitUnit.PER_MINUTE,
    burst_capacity: Optional[int] = None,
    enable_jitter: bool = True,
):
    """Initialize the rate limiter with a full bucket.

    Args:
        rate_limit: Maximum requests per unit (default: 160). Must be
            greater than zero.
        unit: Time unit (PER_MINUTE, PER_SECOND, PER_HOUR). Any other
            value is treated as per-minute.
        burst_capacity: Maximum tokens that can accumulate (default:
            rate_limit; a falsy value also falls back to rate_limit).
        enable_jitter: Add jitter to refill (default: True).

    Raises:
        ValueError: If rate_limit is not greater than zero.
    """
    # Fail early with a clear message: a zero rate would raise
    # ZeroDivisionError below and a negative one would produce a limiter
    # with a negative interval.
    if rate_limit <= 0:
        raise ValueError(f"rate_limit must be greater than zero, got {rate_limit!r}")

    self.rate_limit = rate_limit
    self.unit = unit
    self.burst_capacity = burst_capacity or rate_limit
    self.enable_jitter = enable_jitter

    # Average spacing in seconds between two requests at the configured rate.
    if unit == RateLimitUnit.PER_SECOND:
        self.interval = 1.0 / rate_limit
    elif unit == RateLimitUnit.PER_HOUR:
        self.interval = 3600.0 / rate_limit
    else:
        self.interval = 60.0 / rate_limit

    self._tokens = float(self.burst_capacity)
    self._last_refill = time.monotonic()
    # Guards _tokens and _last_refill.
    self._mutex = threading.Lock()
|
|
71
|
+
|
|
72
|
+
@classmethod
def per_second(cls, rate_limit: int, **kwargs) -> "RateLimiter":
    """Alternate constructor: limiter counting ``rate_limit`` per second.

    Extra keyword arguments are forwarded to the constructor.
    """
    unit = RateLimitUnit.PER_SECOND
    return cls(rate_limit, unit=unit, **kwargs)
|
|
76
|
+
|
|
77
|
+
@classmethod
def per_minute(cls, rate_limit: int, **kwargs) -> "RateLimiter":
    """Alternate constructor: limiter counting ``rate_limit`` per minute.

    Extra keyword arguments are forwarded to the constructor.
    """
    unit = RateLimitUnit.PER_MINUTE
    return cls(rate_limit, unit=unit, **kwargs)
|
|
81
|
+
|
|
82
|
+
@classmethod
def per_hour(cls, rate_limit: int, **kwargs) -> "RateLimiter":
    """Alternate constructor: limiter counting ``rate_limit`` per hour.

    Extra keyword arguments are forwarded to the constructor.
    """
    unit = RateLimitUnit.PER_HOUR
    return cls(rate_limit, unit=unit, **kwargs)
|
|
86
|
+
|
|
87
|
+
def acquire(self, timeout: float = 30.0) -> bool:
    """Acquire a token, blocking until one is available.

    At least one acquisition attempt is always made, so ``timeout=0``
    behaves as a non-blocking attempt that raises on failure.

    Args:
        timeout: Maximum time to wait in seconds (default: 30.0).

    Returns:
        True if a token was acquired.

    Raises:
        RateLimitError: If unable to acquire a token within the timeout.
    """
    deadline = time.monotonic() + timeout
    # Try first, then check the clock: testing the deadline before the first
    # attempt would reject a zero timeout even when a token is available.
    while True:
        if self._try_acquire():
            return True
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break
        # Poll at half the token interval, shortening the naps as the
        # deadline approaches.
        time.sleep(min(self.interval * 0.5, remaining * 0.1))
    raise RateLimitError(
        f"Rate limit exceeded: could not acquire token within {timeout}s "
        f"(limit: {self.rate_limit}/{self.unit.value})"
    )
|
|
111
|
+
|
|
112
|
+
def try_acquire(self) -> bool:
    """Attempt to take one token immediately, never waiting.

    Returns:
        True if a token was taken, False if none was available.
    """
    acquired = self._try_acquire()
    return acquired
|
|
119
|
+
|
|
120
|
+
def _try_acquire(self) -> bool:
|
|
121
|
+
"""Try to acquire a token without blocking."""
|
|
122
|
+
with self._mutex:
|
|
123
|
+
self._refill()
|
|
124
|
+
if self._tokens >= 1.0:
|
|
125
|
+
self._tokens -= 1.0
|
|
126
|
+
return True
|
|
127
|
+
return False
|
|
128
|
+
|
|
129
|
+
def _refill(self) -> None:
    """Credit the tokens earned since the last refill, capped at burst capacity.

    Callers in this class invoke it while holding the mutex.
    """
    now = time.monotonic()
    elapsed = now - self._last_refill

    # Tokens earned per second for the configured unit.
    unit = self.unit
    if unit == RateLimitUnit.PER_SECOND:
        per_second = self.rate_limit
    elif unit == RateLimitUnit.PER_HOUR:
        per_second = self.rate_limit / 3600.0
    else:
        per_second = self.rate_limit / 60.0

    earned = elapsed * per_second
    if self.enable_jitter and earned > 0:
        import random
        # NOTE(review): the jitter is always non-negative (up to +10%), so
        # with jitter enabled the long-run rate can exceed rate_limit.
        # Confirm this is intended.
        earned += earned * 0.1 * random.random()

    self._tokens = min(self.burst_capacity, self._tokens + earned)
    self._last_refill = now
|
|
146
|
+
|
|
147
|
+
def update_rate_limit(self, new_rate_limit: int) -> None:
    """Dynamically update the rate limit.

    The current token balance is scaled by new/old rate. The burst capacity
    is lowered to the new rate when it exceeds it, but never raised.

    Args:
        new_rate_limit: New requests per unit limit. Must be greater than
            zero.

    Raises:
        ValueError: If new_rate_limit is not greater than zero; the limiter
            is left unchanged in that case.
    """
    # Validate before touching any state: otherwise a zero rate would update
    # rate_limit/burst_capacity and then fail with ZeroDivisionError on the
    # interval, leaving the limiter half-updated.
    if new_rate_limit <= 0:
        raise ValueError(f"new_rate_limit must be greater than zero, got {new_rate_limit!r}")

    with self._mutex:
        ratio = new_rate_limit / self.rate_limit if self.rate_limit > 0 else 1
        self._tokens = min(self.burst_capacity, self._tokens * ratio)
        self.rate_limit = new_rate_limit
        if self.burst_capacity > new_rate_limit:
            self.burst_capacity = new_rate_limit
            self._tokens = min(self._tokens, self.burst_capacity)
        # Recompute the average spacing between requests for the new rate.
        if self.unit == RateLimitUnit.PER_SECOND:
            self.interval = 1.0 / new_rate_limit
        elif self.unit == RateLimitUnit.PER_HOUR:
            self.interval = 3600.0 / new_rate_limit
        else:
            self.interval = 60.0 / new_rate_limit
|
|
166
|
+
|
|
167
|
+
def reset(self) -> None:
    """Restore a full bucket and restart the refill clock, under the mutex."""
    with self._mutex:
        full_bucket = float(self.burst_capacity)
        self._tokens = full_bucket
        self._last_refill = time.monotonic()
|
|
172
|
+
|
|
173
|
+
def get_remaining_tokens(self) -> float:
    """Refill the bucket, then report how many tokens it currently holds."""
    with self._mutex:
        self._refill()
        available = self._tokens
    return available
|
|
178
|
+
|
|
179
|
+
def get_retry_after(self) -> float:
    """Get recommended retry delay in seconds.

    The bucket is refilled first, so the answer reflects tokens earned
    since the last acquisition attempt rather than a stale balance.

    Returns:
        Seconds to wait before next attempt (0.0 if a token is available).
    """
    with self._mutex:
        # Without this refill, the balance would be whatever the last
        # acquire left behind, however long ago that was.
        self._refill()
        if self._tokens >= 1:
            return 0.0
        # Tokens earned per second for the configured unit.
        if self.unit == RateLimitUnit.PER_SECOND:
            tps = self.rate_limit
        elif self.unit == RateLimitUnit.PER_HOUR:
            tps = self.rate_limit / 3600.0
        else:
            tps = self.rate_limit / 60.0
        # Time needed to earn the missing fraction of one token.
        return (1 - self._tokens) / tps if tps > 0 else self.interval
|
|
195
|
+
|
|
196
|
+
def get_stats(self) -> Dict[str, float]:
    """Return a read-only snapshot of the limiter's state.

    The token count is projected from the time elapsed since the last
    refill, without jitter and without modifying the bucket.

    Returns:
        Dict with 'tokens_remaining', 'rate_limit', 'burst_capacity',
        'interval' and 'unit'. The 'unit' entry is the unit's string value,
        not a number.
    """
    with self._mutex:
        idle = time.monotonic() - self._last_refill

        # Tokens earned per second for the configured unit.
        unit = self.unit
        if unit == RateLimitUnit.PER_SECOND:
            per_second = self.rate_limit
        elif unit == RateLimitUnit.PER_HOUR:
            per_second = self.rate_limit / 3600.0
        else:
            per_second = self.rate_limit / 60.0

        projected = min(self.burst_capacity, self._tokens + idle * per_second)
        snapshot = {"tokens_remaining": round(projected, 2)}
        snapshot["rate_limit"] = self.rate_limit
        snapshot["burst_capacity"] = self.burst_capacity
        snapshot["interval"] = round(self.interval, 4)
        snapshot["unit"] = self.unit.value
        return snapshot
|