steindamm 0.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- steindamm/__init__.py +28 -0
- steindamm/base.py +49 -0
- steindamm/exceptions.py +7 -0
- steindamm/py.typed +0 -0
- steindamm/semaphore.lua +45 -0
- steindamm/semaphore.py +127 -0
- steindamm/token_bucket/local_token_bucket.py +118 -0
- steindamm/token_bucket/redis_token_bucket.py +137 -0
- steindamm/token_bucket/token_bucket.lua +89 -0
- steindamm/token_bucket/token_bucket.py +208 -0
- steindamm/token_bucket/token_bucket_base.py +161 -0
- steindamm-0.7.0.dist-info/METADATA +385 -0
- steindamm-0.7.0.dist-info/RECORD +14 -0
- steindamm-0.7.0.dist-info/WHEEL +4 -0
steindamm/__init__.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
"""
Various token bucket and semaphore implementations using a Redis or local backend.

Use SyncTokenBucket or AsyncTokenBucket to automatically select between Redis-based
and local in-memory implementations based on whether a Redis connection is provided.

For explicit control over the implementation, import and use
SyncRedisTokenBucket, AsyncRedisTokenBucket, SyncLocalTokenBucket, or AsyncLocalTokenBucket directly.
"""

# TODO: Add local semaphore implementation and update docs accordingly

# All public names are re-exported from this package's own submodules. The
# distribution only ships the ``steindamm`` package (there is no
# ``redis_limiters`` package in it), so imports must go through ``steindamm``.
from steindamm.exceptions import MaxSleepExceededError
from steindamm.semaphore import AsyncSemaphore, SyncSemaphore
from steindamm.token_bucket.local_token_bucket import AsyncLocalTokenBucket, SyncLocalTokenBucket
from steindamm.token_bucket.redis_token_bucket import AsyncRedisTokenBucket, SyncRedisTokenBucket
from steindamm.token_bucket.token_bucket import AsyncTokenBucket, SyncTokenBucket

__all__ = (
    "AsyncLocalTokenBucket",
    "AsyncRedisTokenBucket",
    "AsyncSemaphore",
    "AsyncTokenBucket",
    "MaxSleepExceededError",
    "SyncLocalTokenBucket",
    "SyncRedisTokenBucket",
    "SyncSemaphore",
    "SyncTokenBucket",
)
|
steindamm/base.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
"""Base classes for Redis Lua script handling."""
|
|
2
|
+
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import Any, ClassVar
|
|
5
|
+
|
|
6
|
+
from pydantic import BaseModel, ConfigDict
|
|
7
|
+
from redis import Redis as SyncRedis
|
|
8
|
+
from redis.asyncio import Redis as AsyncRedis
|
|
9
|
+
from redis.asyncio.cluster import RedisCluster as AsyncRedisCluster
|
|
10
|
+
from redis.cluster import RedisCluster as SyncRedisCluster
|
|
11
|
+
from redis.commands.core import AsyncScript, Script
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class SyncLuaScriptBase(BaseModel):
    """
    Base class for synchronous Redis Lua script handling.

    Subclasses set ``script_name`` to the path of a Lua file relative to this
    package directory. On construction the file is read and registered on
    ``connection`` with ``register_script``. The resulting script object is stored
    on ``self.script`` and is invoked as ``self.script(keys=..., args=...)``.
    """

    model_config: ClassVar[ConfigDict] = ConfigDict(arbitrary_types_allowed=True)

    connection: SyncRedis | SyncRedisCluster
    script_name: ClassVar[str]
    # The None default keeps the field optional at construction time; it is
    # always replaced in __init__.
    script: Script = None  # type: ignore[assignment]

    def __init__(self, **kwargs: Any) -> None:
        """Validate the model fields, then read and register the Lua script named by ``script_name``."""
        super().__init__(**kwargs)

        # https://github.com/redis/redis-py/issues/3712
        # Load script on initialization. An explicit encoding keeps the read
        # independent of the process locale (a bare open() uses the locale).
        source = (Path(__file__).parent / self.script_name).read_text(encoding="utf-8")
        self.script = self.connection.register_script(source)  # type: ignore
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class AsyncLuaScriptBase(BaseModel):
    """
    Base class for asynchronous Redis Lua script handling.

    Subclasses set ``script_name`` to the path of a Lua file relative to this
    package directory. On construction the file is read and registered on
    ``connection`` with ``register_script``. The resulting async script object is
    stored on ``self.script`` and is awaited as ``await self.script(keys=..., args=...)``.
    """

    model_config: ClassVar[ConfigDict] = ConfigDict(arbitrary_types_allowed=True)

    connection: AsyncRedis | AsyncRedisCluster
    script_name: ClassVar[str]
    # The None default keeps the field optional at construction time; it is
    # always replaced in __init__.
    script: AsyncScript = None  # type: ignore[assignment]

    def __init__(self, **kwargs: Any) -> None:
        """Validate the model fields, then read and register the Lua script named by ``script_name``."""
        super().__init__(**kwargs)

        # https://github.com/redis/redis-py/issues/3712
        # Load script on initialization. An explicit encoding keeps the read
        # independent of the process locale (a bare open() uses the locale).
        source = (Path(__file__).parent / self.script_name).read_text(encoding="utf-8")
        self.script = self.connection.register_script(source)  # type: ignore
|
steindamm/exceptions.py
ADDED
steindamm/py.typed
ADDED
|
File without changes
|
steindamm/semaphore.lua
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
--- Script called from the Semaphore implementation.
---
--- Lua scripts are run atomically by default, and since redis
--- is single threaded, there are no race conditions to worry about.
---
--- The script checks if a list exists for the Semaphore, and
--- creates one of length `capacity` if it doesn't.
---
--- keys:
--- * key: The key to use for the list
--- * exists: The key to use for the string we use to check if the list exists
---
--- args:
--- * capacity: The capacity of the semaphore (i.e., the length of the list)
---
--- returns:
--- * 1 if created, else 0 (the caller only uses it to pick a log message)

redis.replicate_commands()

-- Init config variables
local key = tostring(KEYS[1])
local exists = tostring(KEYS[2])
local capacity = tonumber(ARGV[1])

-- Maximum number of values passed to a single RPUSH. The script builds the
-- command with `unpack`, which fails with "too many results to unpack" for very
-- large tables, so large capacities are pushed in several batches.
local BATCH_SIZE = 1000

-- Check if list exists
-- Note, we cannot use `EXISTS` or `LLEN` directly on the `key` below,
-- as a list in Redis will "stop existing" if it's empty (empty state will occur
-- whenever the Semaphore is fully utilized). Instead, we use a separate
-- key to check whether a list has been created for our `key` or not.
-- The `exists` key is used verbatim: it is the key declared in KEYS, and it must
-- not be used as a format string (a `%` in the semaphore name would otherwise
-- produce a different key or raise an error).
local does_not_exist = redis.call('SETNX', exists, 1)

-- Create the list if none exists
if does_not_exist == 1 then
    -- Push one '1' per unit of capacity. With capacity 5 this issues
    -- `RPUSH key 1 1 1 1 1`.
    local remaining = capacity
    while remaining > 0 do
        local count = math.min(remaining, BATCH_SIZE)
        local args = { 'RPUSH', key }
        for i = 1, count do
            args[i + 2] = 1
        end
        redis.call(unpack(args))
        remaining = remaining - count
    end
    return 1
end

return 0
|
steindamm/semaphore.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
1
|
+
"""Semaphore limiter implementation."""

from datetime import datetime
from logging import getLogger
from types import TracebackType
from typing import Annotated, ClassVar

from pydantic import BaseModel, Field
from redis.asyncio.client import Pipeline
from redis.asyncio.cluster import ClusterPipeline

# Internal imports go through the installed ``steindamm`` package (the wheel has
# no ``redis_limiters`` package). The exception comes from its defining module,
# so this module does not depend on the package ``__init__`` having run first.
from steindamm.base import AsyncLuaScriptBase, SyncLuaScriptBase
from steindamm.exceptions import MaxSleepExceededError

logger = getLogger(__name__)

# Pydantic-constrained aliases used for the semaphore configuration fields.
PositiveInt = Annotated[int, Field(gt=0)]
NonNegativeFloat = Annotated[float, Field(ge=0)]
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
# TODO: Implement local semaphore as done with token bucket.
class SemaphoreBase(BaseModel):  # noqa: D101 TODO: Fix after local semaphore is added
    # Configuration shared by the sync and async Redis semaphores:
    # - name: identifies the semaphore; both Redis keys are derived from it
    # - capacity: number of slots pushed into the list when it is first created
    # - expiry: TTL in seconds applied (via EXPIRE) to both Redis keys
    # - max_sleep: seconds to wait for a free slot; 0 disables the limit
    name: str
    capacity: PositiveInt = 5
    expiry: PositiveInt = 60
    max_sleep: NonNegativeFloat = 30

    @property
    def key(self) -> str:
        """Redis key of the list whose elements are the free semaphore slots."""
        return f"{{limiter}}:semaphore:{self.name}"

    @property
    def exists(self) -> str:
        """Redis key of the marker recording that the slot list has been created."""
        # Same hash tag and prefix as `key`, with a fixed suffix.
        return f"{self.key}-exists"

    def __str__(self) -> str:
        queue = self.key
        return f"Semaphore instance for queue {queue}"
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class SyncSemaphore(SemaphoreBase, SyncLuaScriptBase):  # noqa: D101 TODO: Fix after local semaphore is added
    script_name: ClassVar[str] = "semaphore.lua"

    def __enter__(self) -> None:
        """
        Acquire a semaphore slot, waiting at most ``max_sleep`` seconds.

        The Lua script first creates the slot list (``capacity`` entries) if it does
        not exist yet. BLPOP then pops one slot, blocking until one is released if
        none are free. Afterwards the expiry of both keys is refreshed.

        Raises:
            MaxSleepExceededError: If no slot was obtained within ``max_sleep``
                seconds. With ``max_sleep == 0`` BLPOP blocks indefinitely.
        """
        # To understand what `exists` does, check the Lua script
        created = self.script(keys=[self.key, self.exists], args=[self.capacity])
        if created:
            logger.info("Created new semaphore `%s` with capacity %s", self.name, self.capacity)
        else:
            logger.debug("Skipped creating semaphore, since one exists")

        # BLPOP replies None when its timeout elapses without popping anything.
        # Deciding from the reply (not from measured wall-clock time) means a slot
        # popped close to the deadline is never leaked by raising, and a timed-out
        # call never proceeds (and later releases) without holding a slot.
        popped = self.connection.blpop([self.key], self.max_sleep)

        pipeline = self.connection.pipeline()
        pipeline.expire(self.key, self.expiry)
        pipeline.expire(self.exists, self.expiry)
        pipeline.execute()

        if popped is None:
            raise MaxSleepExceededError("Max sleep exceeded waiting for Semaphore")

        logger.debug("Acquired semaphore %s", self.name)

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Release the slot by pushing it back onto the list, and refresh both keys' expiry."""
        pipeline = self.connection.pipeline()
        pipeline.lpush(self.key, 1)
        pipeline.expire(self.key, self.expiry)
        pipeline.expire(self.exists, self.expiry)
        pipeline.execute()

        logger.debug("Released semaphore %s", self.name)
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
class AsyncSemaphore(SemaphoreBase, AsyncLuaScriptBase):  # noqa: D101 TODO: Fix after local semaphore is added
    script_name: ClassVar[str] = "semaphore.lua"

    async def __aenter__(self) -> None:
        """
        Acquire a semaphore slot, waiting at most ``max_sleep`` seconds.

        The Lua script first creates the slot list (``capacity`` entries) if it does
        not exist yet. BLPOP then pops one slot, blocking until one is released if
        none are free. Afterwards the expiry of both keys is refreshed.

        Raises:
            MaxSleepExceededError: If no slot was obtained within ``max_sleep``
                seconds. With ``max_sleep == 0`` BLPOP blocks indefinitely.
        """
        # To understand what `exists` does, check the Lua script
        created = await self.script(keys=[self.key, self.exists], args=[self.capacity])
        if created:
            logger.info("Created new semaphore `%s` with capacity %s", self.name, self.capacity)
        else:
            logger.debug("Skipped creating semaphore, since one exists")

        # BLPOP replies None when its timeout elapses without popping anything.
        # Deciding from the reply (not from measured wall-clock time) means a slot
        # popped close to the deadline is never leaked by raising, and a timed-out
        # call never proceeds (and later releases) without holding a slot.
        popped = await self.connection.blpop([self.key], self.max_sleep)  # type: ignore[union-attr]

        pipeline: Pipeline | ClusterPipeline = self.connection.pipeline()
        pipeline.expire(self.key, self.expiry)  # type: ignore[union-attr]
        pipeline.expire(self.exists, self.expiry)  # type: ignore[union-attr]
        await pipeline.execute()

        if popped is None:
            raise MaxSleepExceededError(f"Max sleep ({self.max_sleep}s) exceeded waiting for Semaphore")

        logger.debug("Acquired semaphore %s", self.name)

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Release the slot by pushing it back onto the list, and refresh both keys' expiry."""
        pipeline: Pipeline[str] | ClusterPipeline[str] = self.connection.pipeline()
        pipeline.lpush(self.key, 1)  # type: ignore[union-attr]
        pipeline.expire(self.key, self.expiry)  # type: ignore[union-attr]
        pipeline.expire(self.exists, self.expiry)  # type: ignore[union-attr]
        await pipeline.execute()

        logger.debug("Released semaphore %s", self.name)
|
|
@@ -0,0 +1,118 @@
|
|
|
1
|
+
"""Synchronous and Asynchronous local token bucket implementations."""

import asyncio
import time
from threading import Lock
from types import TracebackType
from typing import ClassVar

# Internal imports go through the installed ``steindamm`` package; the wheel
# contains no ``redis_limiters`` package.
from steindamm.token_bucket.token_bucket_base import TokenBucketBase
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class SyncLocalTokenBucket(TokenBucketBase):
    """
    Thread-safe, in-memory token bucket for synchronous code.

    Args:
        name: Unique identifier for this token bucket.
        capacity: Maximum number of tokens the bucket can hold.
        refill_frequency: Time in seconds between token refills.
        initial_tokens: Starting number of tokens. Defaults to capacity if not specified.
        refill_amount: Number of tokens added per refill.
        max_sleep: Maximum seconds to sleep when rate limited. 0 means no limit.
        expiry: Key expiry time in seconds - currently not implemented for local buckets.
        tokens_to_consume: Number of tokens to consume per operation.

    Example:
        .. code-block:: python

            bucket = SyncLocalTokenBucket(name="api", capacity=10)
            with bucket:
                make_api_call()

    """

    # State lives on the class, so it is shared by every instance. Locks are kept
    # per ``key``; ``_main_lock`` only guards creation of those per-key locks.
    # TODO: Currently there's no cleanup of old buckets.
    #       Consider adding periodic cleanup based on expiry.
    _buckets: ClassVar[dict[str, dict]] = {}
    _locks: ClassVar[dict[str, Lock]] = {}
    _main_lock: ClassVar[Lock] = Lock()

    def _get_lock(self) -> Lock:
        """Return the lock for this bucket's key, creating it on first use."""
        # Fast path skips the main lock for performance on CPython with the GIL.
        # This is not safe in free-threaded Python.
        cached = self._locks.get(self.key)
        if cached is not None:
            return cached

        # Slow path: re-check under the main lock so concurrent first callers
        # all receive the same Lock object.
        with self._main_lock:
            if self.key not in self._locks:
                self._locks[self.key] = Lock()
            return self._locks[self.key]

    def __enter__(self) -> None:
        """Acquire token(s) from the token bucket, then sleep until they are available."""
        # Only the state update is serialized; the sleep happens after the lock
        # is released.
        bucket_lock = self._get_lock()
        with bucket_lock:
            wake_at = self.execute_local_token_bucket_logic(self._buckets)

        delay = self.parse_timestamp(wake_at)
        time.sleep(delay)

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Leave the context; nothing is handed back to the bucket."""
        return None
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
class AsyncLocalTokenBucket(TokenBucketBase):
    """
    In-memory token bucket for asyncio code.

    Args:
        name: Unique identifier for this token bucket.
        capacity: Maximum number of tokens the bucket can hold.
        refill_frequency: Time in seconds between token refills.
        initial_tokens: Starting number of tokens. Defaults to capacity if not specified.
        refill_amount: Number of tokens added per refill.
        max_sleep: Maximum seconds to sleep when rate limited. 0 means no limit.
        expiry: Key expiry time in seconds - currently not implemented for local buckets.
        tokens_to_consume: Number of tokens to consume per operation.

    Example:
        .. code-block:: python

            bucket = AsyncLocalTokenBucket(name="api", capacity=10)
            async with bucket:
                await make_api_call()

    Note: If you need to use this class from multiple threads (multiple event loops),
    consider using SyncLocalTokenBucket instead, which provides proper thread safety.

    """

    # Class-level bucket state, shared by every instance of this class; it is a
    # separate dict from SyncLocalTokenBucket's.
    # TODO: Currently there's no cleanup of old buckets.
    #       Consider adding periodic cleanup based on expiry.
    _buckets: ClassVar[dict[str, dict]] = {}

    async def __aenter__(self) -> None:
        """Acquire token(s) from the token bucket, then sleep until they are available."""
        # No lock needed: the state update contains no await point, so no other
        # coroutine on the same event loop can interleave with it.
        wake_at = self.execute_local_token_bucket_logic(self._buckets)

        delay = self.parse_timestamp(wake_at)
        await asyncio.sleep(delay)

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Leave the context; nothing is handed back to the bucket."""
        return None
|
|
@@ -0,0 +1,137 @@
|
|
|
1
|
+
"""Synchronous and Asynchronous Redis-backed (Standalone or Cluster) token bucket implementations."""

import asyncio
import time
from types import TracebackType
from typing import ClassVar, cast

# Internal imports go through the installed ``steindamm`` package; the wheel
# contains no ``redis_limiters`` package.
from steindamm.base import AsyncLuaScriptBase, SyncLuaScriptBase
from steindamm.token_bucket.token_bucket_base import TokenBucketBase, get_current_time_ms
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class SyncRedisTokenBucket(TokenBucketBase, SyncLuaScriptBase):
    """
    Synchronous Redis-backed (Standalone or Cluster) token bucket.

    Args:
        name: Unique identifier for this token bucket.
        connection: Redis connection (SyncRedis or SyncRedisCluster).
        capacity: Maximum number of tokens the bucket can hold.
        refill_frequency: Time in seconds between token refills.
        initial_tokens: Starting number of tokens. Defaults to capacity if not specified.
        refill_amount: Number of tokens added per refill.
        max_sleep: Maximum seconds to sleep when rate limited. 0 means no limit.
        expiry: Key expiry time in seconds.
        tokens_to_consume: Number of tokens to consume per operation.

    Example:
        .. code-block:: python

            from redis import Redis  # or from redis.cluster import RedisCluster
            redis_conn = Redis(host='localhost', port=6379)
            bucket = SyncRedisTokenBucket(connection=redis_conn, name="api", capacity=10)
            with bucket:
                make_api_call()

    """

    script_name: ClassVar[str] = "token_bucket/token_bucket.lua"

    def __enter__(self) -> float:
        """
        Acquire token(s) from the token bucket and sleep until they are available.

        Returns:
            The number of seconds slept (the value returned by ``parse_timestamp``).

        Raises:
            redis.exceptions.ResponseError: If the Lua script returns an error reply,
                e.g. when ``tokens_to_consume`` exceeds ``capacity``.
        """
        # Fall back to ``capacity`` only when ``initial_tokens`` was not given.
        # ``initial_tokens or capacity`` would also replace an explicit 0.
        initial_tokens = self.capacity if self.initial_tokens is None else self.initial_tokens

        # Retrieve the slot (millisecond timestamp) to wake up at from the Lua script
        milliseconds = get_current_time_ms()
        timestamp: int = cast(
            int,
            self.script(
                keys=[self.key],
                args=[
                    self.capacity,
                    self.refill_amount,
                    initial_tokens,
                    self.refill_frequency,
                    milliseconds,
                    self.expiry,
                    self.tokens_to_consume,
                ],
            ),
        )

        # Estimate sleep time, then sleep before returning
        sleep_time = self.parse_timestamp(timestamp)
        time.sleep(sleep_time)

        return sleep_time

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Leave the context; consumed tokens are not returned to the bucket."""
        return
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
class AsyncRedisTokenBucket(TokenBucketBase, AsyncLuaScriptBase):
    """
    Asynchronous Redis-backed (Standalone or Cluster) token bucket.

    Args:
        name: Unique identifier for this token bucket.
        connection: Redis connection (AsyncRedis or AsyncRedisCluster).
        capacity: Maximum number of tokens the bucket can hold.
        refill_frequency: Time in seconds between token refills.
        initial_tokens: Starting number of tokens. Defaults to capacity if not specified.
        refill_amount: Number of tokens added per refill.
        max_sleep: Maximum seconds to sleep when rate limited. 0 means no limit.
        expiry: Key expiry time in seconds.
        tokens_to_consume: Number of tokens to consume per operation.

    Example:
        .. code-block:: python

            from redis.asyncio import Redis  # or from redis.asyncio.cluster import RedisCluster
            redis_conn = Redis(host='localhost', port=6379)
            bucket = AsyncRedisTokenBucket(connection=redis_conn, name="api", capacity=10)
            async with bucket:
                await make_api_call()

    """

    script_name: ClassVar[str] = "token_bucket/token_bucket.lua"

    async def __aenter__(self) -> float:
        """
        Acquire token(s) from the token bucket and sleep until they are available.

        Returns:
            The number of seconds slept (the value returned by ``parse_timestamp``),
            matching ``SyncRedisTokenBucket.__enter__``.

        Raises:
            redis.exceptions.ResponseError: If the Lua script returns an error reply,
                e.g. when ``tokens_to_consume`` exceeds ``capacity``.
        """
        # Fall back to ``capacity`` only when ``initial_tokens`` was not given.
        # ``initial_tokens or capacity`` would also replace an explicit 0.
        initial_tokens = self.capacity if self.initial_tokens is None else self.initial_tokens

        # Retrieve the slot (millisecond timestamp) to wake up at from the Lua script
        milliseconds = get_current_time_ms()
        timestamp: int = cast(
            int,
            await self.script(
                keys=[self.key],
                args=[
                    self.capacity,
                    self.refill_amount,
                    initial_tokens,
                    self.refill_frequency,
                    milliseconds,
                    self.expiry,
                    self.tokens_to_consume,
                ],
            ),
        )

        # Estimate sleep time, then sleep before returning
        sleep_time = self.parse_timestamp(timestamp)
        await asyncio.sleep(sleep_time)

        return sleep_time

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Leave the context; consumed tokens are not returned to the bucket."""
        return
|
|
@@ -0,0 +1,89 @@
|
|
|
1
|
+
--- Lua scripts are run atomically by default, and since redis
--- is single threaded, there are no race conditions to worry about.
---
--- This script does three things, in order:
--- 1. Retrieves token bucket state, which means the last slot assigned,
---    and how many tokens are left to be assigned for that slot
--- 2. Works out whether we need to move to the next slot(s), or consume
---    tokens from the current one.
--- 3. Saves the token bucket state and returns the slot. The state is a
---    combination of the last slot assigned (timestamp) and the number of tokens left.
---
--- The token bucket implementation is forward looking, so we're really just handing
--- out the next time there would be tokens in the bucket, and letting the client
--- decide to wait until that time.
---
--- keys:
--- * KEYS[1]: key holding the bucket state, stored as "<slot> <tokens>"
---
--- args (in order):
--- * capacity, refill_amount, initial_tokens, refill frequency (seconds),
---   current time (milliseconds), expiry (seconds), tokens to consume
---
--- returns:
--- * The assigned slot, as a millisecond timestamp

redis.replicate_commands()

-- Arguments
local capacity = tonumber(ARGV[1])
local refill_amount = tonumber(ARGV[2])
local initial_tokens = tonumber(ARGV[3])
local time_between_slots = tonumber(ARGV[4]) * 1000 -- Convert to milliseconds
local milliseconds = tonumber(ARGV[5])
local expiry = tonumber(ARGV[6])
local tokens_to_consume = tonumber(ARGV[7]) -- Number of tokens to consume

-- Validate that tokens_to_consume doesn't exceed capacity
if tokens_to_consume > capacity then
    return redis.error_reply("Requested tokens exceed bucket capacity")
end

-- Validate that tokens_to_consume is positive
if tokens_to_consume <= 0 then
    return redis.error_reply("Must consume at least 1 token")
end

-- Keys
local data_key = KEYS[1]

-- Current time in milliseconds, as supplied by the caller
local now = milliseconds

-- Default bucket values (used if no bucket exists yet)
local tokens = math.min(initial_tokens, capacity)
local slot = now

-- Retrieve stored state, if any
local data = redis.call('GET', data_key)
if data then
    local last_slot, stored_tokens = data:match('(%S+) (%S+)')
    slot = tonumber(last_slot)
    tokens = tonumber(stored_tokens)

    -- Calculate the number of slots that have passed since the last update
    local slots_passed = math.floor((now - slot) / time_between_slots)
    if slots_passed > 0 then
        -- Refill the tokens based on the number of slots passed, capped by capacity
        tokens = math.min(tokens + slots_passed * refill_amount, capacity)
        -- Update the slot to this run
        -- The previously added +20 ms execution penalty was removed as it was not needed
        -- and all it did was add additional latency to all requests and in our case,
        -- timing is handled gracefully with the condition used (wake_up_time < now)
        slot = now
    end
end

-- If not enough tokens are available, move to the next slot(s) and refill accordingly
if tokens < tokens_to_consume then
    -- Calculate how many additional tokens we need
    local needed_tokens = tokens_to_consume - tokens
    -- Calculate how many slots we need to move forward to get enough tokens
    local needed_slots = math.ceil(needed_tokens / refill_amount)
    slot = slot + needed_slots * time_between_slots
    tokens = tokens + needed_slots * refill_amount
    -- Clamp tokens to capacity
    tokens = math.min(tokens, capacity)
end

-- Consume tokens
tokens = tokens - tokens_to_consume

-- Save updated state and set expiry. The state must outlive the slot just handed
-- out. If it expired while clients still wait for future slots, the next call
-- would start from a fresh bucket and hand out tokens that are already reserved.
local ttl = math.max(expiry, math.ceil((slot - now) / 1000))
redis.call('SETEX', data_key, ttl, string.format('%d %d', slot, tokens))

-- Return the slot when the next token(s) will be available
return slot
|