steindamm 0.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
steindamm/__init__.py ADDED
@@ -0,0 +1,28 @@
1
+ """
2
+ Token bucket implementations using a Redis or local in-memory backend, plus a Redis-backed semaphore.
3
+
4
+ Use SyncTokenBucket or AsyncTokenBucket to automatically select between Redis-based
5
+ and local in-memory implementations based on whether a Redis connection is provided.
6
+
7
+ For explicit control over the implementation, import and use
8
+ SyncRedisTokenBucket, AsyncRedisTokenBucket, SyncLocalTokenBucket, or AsyncLocalTokenBucket directly.
9
+ """
10
+
11
+ # TODO: Add local semaphore implementation and update docs accordingly
12
+ from redis_limiters.exceptions import MaxSleepExceededError
13
+ from redis_limiters.semaphore import AsyncSemaphore, SyncSemaphore
14
+ from redis_limiters.token_bucket.local_token_bucket import AsyncLocalTokenBucket, SyncLocalTokenBucket
15
+ from redis_limiters.token_bucket.redis_token_bucket import AsyncRedisTokenBucket, SyncRedisTokenBucket
16
+ from redis_limiters.token_bucket.token_bucket import AsyncTokenBucket, SyncTokenBucket
17
+
18
# Public API re-exported at package level (alphabetical order).
__all__ = (
    # asyncio implementations
    "AsyncLocalTokenBucket", "AsyncRedisTokenBucket", "AsyncSemaphore", "AsyncTokenBucket",
    # exception shared by all limiters
    "MaxSleepExceededError",
    # synchronous implementations
    "SyncLocalTokenBucket", "SyncRedisTokenBucket", "SyncSemaphore", "SyncTokenBucket",
)
steindamm/base.py ADDED
@@ -0,0 +1,49 @@
1
+ """Base classes for Redis Lua script handling."""
2
+
3
+ from pathlib import Path
4
+ from typing import Any, ClassVar
5
+
6
+ from pydantic import BaseModel, ConfigDict
7
+ from redis import Redis as SyncRedis
8
+ from redis.asyncio import Redis as AsyncRedis
9
+ from redis.asyncio.cluster import RedisCluster as AsyncRedisCluster
10
+ from redis.cluster import RedisCluster as SyncRedisCluster
11
+ from redis.commands.core import AsyncScript, Script
12
+
13
+
14
class SyncLuaScriptBase(BaseModel):
    """
    Base class for models that run a Redis Lua script synchronously.

    Subclasses set ``script_name`` to the path of a ``.lua`` file relative to this
    module's directory. The script is registered on ``connection`` when the model is
    instantiated and can then be invoked as ``self.script(keys=..., args=...)``.
    """

    model_config: ClassVar[ConfigDict] = ConfigDict(arbitrary_types_allowed=True)

    connection: SyncRedis | SyncRedisCluster
    script_name: ClassVar[str]
    # Defaults to None so callers never pass it; __init__ always assigns the registered script.
    script: Script = None  # type: ignore[assignment]

    def __init__(self, **kwargs: Any) -> None:
        """Validate the model fields, then read and register the Lua script named by ``script_name``."""
        super().__init__(**kwargs)

        # Register the script eagerly on initialization, see
        # https://github.com/redis/redis-py/issues/3712
        # The file is decoded explicitly as UTF-8 so the result does not depend on the host locale.
        source = (Path(__file__).parent / self.script_name).read_text(encoding="utf-8")
        self.script = self.connection.register_script(source)  # type: ignore
31
+
32
+
33
class AsyncLuaScriptBase(BaseModel):
    """
    Base class for models that run a Redis Lua script through an asyncio client.

    Subclasses set ``script_name`` to the path of a ``.lua`` file relative to this
    module's directory. The script is registered on ``connection`` when the model is
    instantiated and can then be awaited as ``await self.script(keys=..., args=...)``.
    """

    model_config: ClassVar[ConfigDict] = ConfigDict(arbitrary_types_allowed=True)

    connection: AsyncRedis | AsyncRedisCluster
    script_name: ClassVar[str]
    # Defaults to None so callers never pass it; __init__ always assigns the registered script.
    script: AsyncScript = None  # type: ignore[assignment]

    def __init__(self, **kwargs: Any) -> None:
        """Validate the model fields, then read and register the Lua script named by ``script_name``."""
        super().__init__(**kwargs)

        # Register the script eagerly on initialization, see
        # https://github.com/redis/redis-py/issues/3712
        # The file is decoded explicitly as UTF-8 so the result does not depend on the host locale.
        source = (Path(__file__).parent / self.script_name).read_text(encoding="utf-8")
        self.script = self.connection.register_script(source)  # type: ignore
@@ -0,0 +1,7 @@
1
+ """Exceptions for redis_limiters package."""
2
+
3
+
4
class MaxSleepExceededError(Exception):
    """Signals that waiting for a limiter exceeded the configured ``max_sleep`` limit."""
steindamm/py.typed ADDED
File without changes
@@ -0,0 +1,45 @@
1
--- Script called from the Semaphore implementation.
---
--- Lua scripts are run atomically by default, and since Redis
--- is single threaded, there are no race conditions to worry about.
---
--- The script checks whether a list exists for the Semaphore, and
--- creates one of length `capacity` if it doesn't.
---
--- keys:
--- * key: The key to use for the list
--- * exists: The key of the marker string used to check whether the list has been created
---
--- args:
--- * capacity: The capacity of the semaphore (i.e., the length of the list)
---
--- returns:
--- * 1 if the list was created, else 0 (the caller only uses this to pick a log message)

redis.replicate_commands()

-- Init config variables
local key = tostring(KEYS[1])
local exists = tostring(KEYS[2])
local capacity = tonumber(ARGV[1])

-- Maximum number of elements pushed per RPUSH call. `unpack` can only spread a
-- limited number of values onto the Lua stack, so large capacities are pushed in batches.
local batch_size = 1000

-- Check if the list has been created.
-- Note, we cannot use `EXISTS` or `LLEN` directly on `key`, as a list in Redis
-- "stops existing" once it is empty (which happens whenever the Semaphore is fully
-- utilized). Instead, a separate marker key records whether the list was created.
-- The marker key is used verbatim: passing it through `string.format` would treat any
-- `%` in the semaphore name as a format directive and write to an undeclared key.
local created = redis.call('SETNX', exists, 1)

if created == 1 then
  -- Fill the list with `capacity` elements of '1',
  -- e.g. capacity 5 results in `RPUSH key 1 1 1 1 1`.
  local batch = {}
  for _ = 1, capacity do
    batch[#batch + 1] = 1
    if #batch == batch_size then
      redis.call('RPUSH', key, unpack(batch))
      batch = {}
    end
  end
  if #batch > 0 then
    redis.call('RPUSH', key, unpack(batch))
  end
  return 1
end

return 0
steindamm/semaphore.py ADDED
@@ -0,0 +1,127 @@
1
+ """Semaphore limiter implementation."""
2
+
3
+ from datetime import datetime
4
+ from logging import getLogger
5
+ from types import TracebackType
6
+ from typing import Annotated, ClassVar
7
+
8
+ from pydantic import BaseModel, Field
9
+ from redis.asyncio.client import Pipeline
10
+ from redis.asyncio.cluster import ClusterPipeline
11
+
12
+ from redis_limiters import MaxSleepExceededError
13
+ from redis_limiters.base import AsyncLuaScriptBase, SyncLuaScriptBase
14
+
15
logger = getLogger(__name__)

# Pydantic-constrained numeric aliases used by the semaphore settings.
PositiveInt = Annotated[int, Field(gt=0)]  # strictly greater than zero
NonNegativeFloat = Annotated[float, Field(ge=0)]  # zero or greater
19
+
20
+
21
+ # TODO: Implement local semaphore as done with token bucket.
22
class SemaphoreBase(BaseModel):
    """
    Shared settings and Redis key naming for the semaphore implementations.

    Attributes:
        name: Identifier of the semaphore; embedded in its Redis keys.
        capacity: Number of slots that can be held at the same time.
        expiry: TTL in seconds applied to the semaphore keys.
        max_sleep: Maximum seconds to wait for a slot; 0 disables the limit.
    """

    name: str
    capacity: PositiveInt = 5
    expiry: PositiveInt = 60
    max_sleep: NonNegativeFloat = 30

    @property
    def key(self) -> str:
        """Redis key of the list whose elements are the free slots."""
        return f"{{limiter}}:semaphore:{self.name}"

    @property
    def exists(self) -> str:
        """Redis key of the marker recording that the slot list has been created."""
        # Same hash tag as `key`, so both keys land in one cluster slot.
        return f"{self.key}-exists"

    def __str__(self) -> str:
        return "Semaphore instance for queue " + self.key
40
+
41
+
42
class SyncSemaphore(SemaphoreBase, SyncLuaScriptBase):
    """
    Synchronous Redis-backed semaphore, used as a context manager.

    Entering creates the slot list on first use (via ``semaphore.lua``) and pops one
    slot with BLPOP; exiting pushes the slot back. Both Redis keys get their TTL
    refreshed to ``expiry`` on every enter and exit.

    Example:
        .. code-block:: python

            with SyncSemaphore(connection=redis_conn, name="jobs", capacity=3):
                do_work()
    """

    script_name: ClassVar[str] = "semaphore.lua"

    def __enter__(self) -> None:
        """
        Create the semaphore if needed, then block until a slot can be taken.

        Raises:
            MaxSleepExceededError: If no slot was obtained within ``max_sleep`` seconds.
                With ``max_sleep == 0`` BLPOP waits indefinitely.
        """
        # The script only creates the list on first use; to understand `exists`, see the Lua script.
        if self.script(
            keys=[self.key, self.exists],
            args=[self.capacity],
        ):
            logger.info("Created new semaphore `%s` with capacity %s", self.name, self.capacity)
        else:
            logger.debug("Skipped creating semaphore, since one exists")

        # BLPOP returns the popped (key, value) pair, or None once the timeout elapses.
        popped = self.connection.blpop([self.key], self.max_sleep)

        # Refresh the TTL of both keys so a semaphore in use does not expire.
        pipeline = self.connection.pipeline()
        pipeline.expire(self.key, self.expiry)
        pipeline.expire(self.exists, self.expiry)
        pipeline.execute()

        # Decide from BLPOP's reply rather than from elapsed wall-clock time:
        # - a slot popped just before the deadline must not be reported as a timeout,
        #   because __exit__ never runs when __enter__ raises and the slot would leak;
        # - a timed-out BLPOP must never count as an acquisition, because __exit__ would
        #   then push back a slot that was never taken and inflate the capacity.
        if popped is None:
            raise MaxSleepExceededError("Max sleep exceeded waiting for Semaphore")

        logger.debug("Acquired semaphore %s", self.name)

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Return the slot to the semaphore and refresh the TTL of both keys."""
        pipeline = self.connection.pipeline()
        pipeline.lpush(self.key, 1)
        pipeline.expire(self.key, self.expiry)
        pipeline.expire(self.exists, self.expiry)
        pipeline.execute()

        logger.debug("Released semaphore %s", self.name)
84
+
85
+
86
class AsyncSemaphore(SemaphoreBase, AsyncLuaScriptBase):
    """
    Asynchronous Redis-backed semaphore, used as an async context manager.

    Entering creates the slot list on first use (via ``semaphore.lua``) and pops one
    slot with BLPOP; exiting pushes the slot back. Both Redis keys get their TTL
    refreshed to ``expiry`` on every enter and exit.

    Example:
        .. code-block:: python

            async with AsyncSemaphore(connection=redis_conn, name="jobs", capacity=3):
                await do_work()
    """

    script_name: ClassVar[str] = "semaphore.lua"

    async def __aenter__(self) -> None:
        """
        Create the semaphore if needed, then wait until a slot can be taken.

        Raises:
            MaxSleepExceededError: If no slot was obtained within ``max_sleep`` seconds.
                With ``max_sleep == 0`` BLPOP waits indefinitely.
        """
        # The script only creates the list on first use; to understand `exists`, see the Lua script.
        if await self.script(
            keys=[self.key, self.exists],
            args=[self.capacity],
        ):
            logger.info("Created new semaphore `%s` with capacity %s", self.name, self.capacity)
        else:
            logger.debug("Skipped creating semaphore, since one exists")

        # BLPOP returns the popped (key, value) pair, or None once the timeout elapses.
        popped = await self.connection.blpop([self.key], self.max_sleep)  # type: ignore[union-attr]

        # Refresh the TTL of both keys so a semaphore in use does not expire.
        pipeline: Pipeline | ClusterPipeline = self.connection.pipeline()
        pipeline.expire(self.key, self.expiry)  # type: ignore[union-attr]
        pipeline.expire(self.exists, self.expiry)  # type: ignore[union-attr]
        await pipeline.execute()

        # Decide from BLPOP's reply rather than from elapsed wall-clock time:
        # - a slot popped just before the deadline must not be reported as a timeout,
        #   because __aexit__ never runs when __aenter__ raises and the slot would leak;
        # - a timed-out BLPOP must never count as an acquisition, because __aexit__ would
        #   then push back a slot that was never taken and inflate the capacity.
        if popped is None:
            raise MaxSleepExceededError(f"Max sleep ({self.max_sleep}s) exceeded waiting for Semaphore")

        logger.debug("Acquired semaphore %s", self.name)

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Return the slot to the semaphore and refresh the TTL of both keys."""
        pipeline: Pipeline[str] | ClusterPipeline[str] = self.connection.pipeline()
        pipeline.lpush(self.key, 1)  # type: ignore[union-attr]
        pipeline.expire(self.key, self.expiry)  # type: ignore[union-attr]
        pipeline.expire(self.exists, self.expiry)  # type: ignore[union-attr]
        await pipeline.execute()

        logger.debug("Released semaphore %s", self.name)
@@ -0,0 +1,118 @@
1
+ """Synchronous and Asynchronous local token bucket implementations."""
2
+
3
+ import asyncio
4
+ import time
5
+ from threading import Lock
6
+ from types import TracebackType
7
+ from typing import ClassVar
8
+
9
+ from redis_limiters.token_bucket.token_bucket_base import TokenBucketBase
10
+
11
+
12
class SyncLocalTokenBucket(TokenBucketBase):
    """
    Synchronous local (in-process, in-memory) token bucket.

    Bucket state lives in the class-level ``_buckets`` dict, shared by all instances
    in the process. Updates to a bucket are serialised by a per-key ``threading.Lock``.

    Args:
        name: Unique identifier for this token bucket.
        capacity: Maximum number of tokens the bucket can hold.
        refill_frequency: Time in seconds between token refills.
        initial_tokens: Starting number of tokens. Defaults to capacity if not specified.
        refill_amount: Number of tokens added per refill.
        max_sleep: Maximum seconds to sleep when rate limited. 0 means no limit.
        expiry: Key expiry time in seconds - currently not implemented for local buckets.
        tokens_to_consume: Number of tokens to consume per operation.

    Example:
        .. code-block:: python

            bucket = SyncLocalTokenBucket(name="api", capacity=10)
            with bucket:
                make_api_call()

    """

    # Class-level storage for bucket state (shared across instances)
    # TODO: Currently there's no cleanup of old buckets.
    # Consider adding periodic cleanup based on expiry.
    _buckets: ClassVar[dict[str, dict]] = {}
    _locks: ClassVar[dict[str, Lock]] = {}
    _main_lock: ClassVar[Lock] = Lock()

    def _get_lock(self) -> Lock:
        """Return the lock guarding this bucket's key, creating it on first use."""
        # Fast path without `_main_lock`, for performance on CPython with the GIL.
        # Entries are never removed from `_locks`, so a hit is always valid.
        # NOTE(review): the original author flagged this lock-free read as unsafe on
        # free-threaded Python builds — confirm before relying on it there.
        lock = self._locks.get(self.key)
        if lock is None:
            # Slow path: create the lock under the class-wide lock, so concurrent
            # first users of a key agree on a single Lock object.
            with self._main_lock:
                lock = self._locks.setdefault(self.key, Lock())
        return lock

    def __enter__(self) -> float:
        """
        Acquire token(s) from the token bucket and sleep until they are available.

        Returns:
            The number of seconds slept, matching ``SyncRedisTokenBucket.__enter__``.
        """
        # Only the bookkeeping runs under the lock; the sleep happens outside it, so
        # other threads using the same bucket are not blocked while this one waits.
        with self._get_lock():
            timestamp = self.execute_local_token_bucket_logic(self._buckets)

        sleep_time = self.parse_timestamp(timestamp)
        time.sleep(sleep_time)
        return sleep_time

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Do nothing: consumed tokens are not returned to the bucket."""
        return
68
+
69
+
70
class AsyncLocalTokenBucket(TokenBucketBase):
    """
    Asynchronous local (in-process, in-memory) token bucket.

    Args:
        name: Unique identifier for this token bucket.
        capacity: Maximum number of tokens the bucket can hold.
        refill_frequency: Time in seconds between token refills.
        initial_tokens: Starting number of tokens. Defaults to capacity if not specified.
        refill_amount: Number of tokens added per refill.
        max_sleep: Maximum seconds to sleep when rate limited. 0 means no limit.
        expiry: Key expiry time in seconds - currently not implemented for local buckets.
        tokens_to_consume: Number of tokens to consume per operation.

    Example:
        .. code-block:: python

            bucket = AsyncLocalTokenBucket(name="api", capacity=10)
            async with bucket:
                await make_api_call()

    Note: If you need to use this class from multiple threads (multiple event loops),
    consider using SyncLocalTokenBucket instead, which provides proper thread safety.

    """

    # Class-level storage for bucket state (shared across instances)
    # TODO: Currently there's no cleanup of old buckets.
    # Consider adding periodic cleanup based on expiry.
    _buckets: ClassVar[dict[str, dict]] = {}

    async def __aenter__(self) -> float:
        """
        Acquire token(s) from the token bucket and sleep until they are available.

        Returns:
            The number of seconds slept, matching ``SyncRedisTokenBucket.__enter__``.
        """
        # No lock needed: asyncio is single-threaded and execute_local_token_bucket_logic
        # is a plain (non-async) call with no await points, so it cannot be interleaved
        # with other coroutines on the same event loop.
        timestamp = self.execute_local_token_bucket_logic(self._buckets)

        sleep_time = self.parse_timestamp(timestamp)
        await asyncio.sleep(sleep_time)
        return sleep_time

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Do nothing: consumed tokens are not returned to the bucket."""
        return
@@ -0,0 +1,137 @@
1
+ """Synchronous and Asynchronous Redis-backed (Standalone or Cluster) token bucket implementations."""
2
+
3
+ import asyncio
4
+ import time
5
+ from types import TracebackType
6
+ from typing import ClassVar, cast
7
+
8
+ from redis_limiters.base import AsyncLuaScriptBase, SyncLuaScriptBase
9
+ from redis_limiters.token_bucket.token_bucket_base import TokenBucketBase, get_current_time_ms
10
+
11
+
12
class SyncRedisTokenBucket(TokenBucketBase, SyncLuaScriptBase):
    """
    Synchronous Redis-backed (Standalone or Cluster) token bucket.

    Args:
        name: Unique identifier for this token bucket.
        connection: Redis connection (SyncRedis or SyncRedisCluster).
        capacity: Maximum number of tokens the bucket can hold.
        refill_frequency: Time in seconds between token refills.
        initial_tokens: Starting number of tokens. Defaults to capacity if not specified.
        refill_amount: Number of tokens added per refill.
        max_sleep: Maximum seconds to sleep when rate limited. 0 means no limit.
        expiry: Key expiry time in seconds.
        tokens_to_consume: Number of tokens to consume per operation.

    Example:
        .. code-block:: python

            from redis import Redis  # or from redis.cluster import RedisCluster
            redis_conn = Redis(host='localhost', port=6379)
            bucket = SyncRedisTokenBucket(connection=redis_conn, name="api", capacity=10)
            with bucket:
                make_api_call()

    """

    script_name: ClassVar[str] = "token_bucket/token_bucket.lua"

    def __enter__(self) -> float:
        """Reserve token(s) through the Lua script, sleep until they are due, and return the seconds slept."""
        now_ms = get_current_time_ms()

        # The order of this list must match the ARGV indices read by token_bucket.lua.
        script_args = [
            self.capacity,
            self.refill_amount,
            # NOTE(review): `or` also replaces an explicit initial_tokens=0 with capacity;
            # confirm in TokenBucketBase whether 0 is a valid "start empty" value.
            self.initial_tokens or self.capacity,
            self.refill_frequency,
            now_ms,
            self.expiry,
            self.tokens_to_consume,
        ]

        # The script replies with the millisecond timestamp at which the tokens are available.
        wake_up_ms: int = cast(int, self.script(keys=[self.key], args=script_args))

        sleep_time = self.parse_timestamp(wake_up_ms)
        time.sleep(sleep_time)
        return sleep_time

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Do nothing: consumed tokens are not returned to the bucket."""
        return
75
+
76
+
77
class AsyncRedisTokenBucket(TokenBucketBase, AsyncLuaScriptBase):
    """
    Asynchronous Redis-backed (Standalone or Cluster) token bucket.

    Args:
        name: Unique identifier for this token bucket.
        connection: Redis connection (AsyncRedis or AsyncRedisCluster).
        capacity: Maximum number of tokens the bucket can hold.
        refill_frequency: Time in seconds between token refills.
        initial_tokens: Starting number of tokens. Defaults to capacity if not specified.
        refill_amount: Number of tokens added per refill.
        max_sleep: Maximum seconds to sleep when rate limited. 0 means no limit.
        expiry: Key expiry time in seconds.
        tokens_to_consume: Number of tokens to consume per operation.

    Example:
        .. code-block:: python

            from redis.asyncio import Redis  # or from redis.asyncio.cluster import RedisCluster
            redis_conn = Redis(host='localhost', port=6379)
            bucket = AsyncRedisTokenBucket(connection=redis_conn, name="api", capacity=10)
            async with bucket:
                await make_api_call()

    """

    script_name: ClassVar[str] = "token_bucket/token_bucket.lua"

    async def __aenter__(self) -> float:
        """
        Acquire token(s) from the token bucket and sleep until they are available.

        Returns:
            The number of seconds slept, matching ``SyncRedisTokenBucket.__enter__``.
        """
        # Retrieve the timestamp for when to wake up from the Redis Lua script.
        milliseconds = get_current_time_ms()
        timestamp: int = cast(
            int,
            await self.script(
                keys=[self.key],
                # The order must match the ARGV indices read by token_bucket.lua.
                args=[
                    self.capacity,
                    self.refill_amount,
                    # NOTE(review): `or` also replaces an explicit initial_tokens=0 with
                    # capacity; confirm in TokenBucketBase whether 0 is a valid value.
                    self.initial_tokens or self.capacity,
                    self.refill_frequency,
                    milliseconds,
                    self.expiry,
                    self.tokens_to_consume,
                ],
            ),
        )

        sleep_time = self.parse_timestamp(timestamp)
        await asyncio.sleep(sleep_time)
        return sleep_time

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Do nothing: consumed tokens are not returned to the bucket."""
        return
@@ -0,0 +1,89 @@
1
--- Lua scripts are run atomically by default, and since Redis
--- is single threaded, there are no race conditions to worry about.
---
--- This script does three things, in order:
--- 1. Retrieves the token bucket state, i.e. the last slot assigned
---    and how many tokens are left to be assigned for that slot.
--- 2. Works out whether we need to move to the next slot(s), or can consume
---    tokens from the current one.
--- 3. Saves the token bucket state and returns the slot. The state is a
---    combination of the last slot assigned (timestamp) and the number of tokens left.
---
--- The token bucket implementation is forward looking, so we're really just handing
--- out the next time there would be tokens in the bucket, and letting the client
--- decide whether to wait until that time.
---
--- keys:
--- * KEYS[1]: The key holding the bucket state, stored as "<slot> <tokens>"
---
--- args:
--- * ARGV[1] capacity: Maximum number of tokens in the bucket
--- * ARGV[2] refill_amount: Tokens added per elapsed slot
--- * ARGV[3] initial_tokens: Tokens in a bucket that has no stored state yet
--- * ARGV[4] refill frequency: Seconds between slots
--- * ARGV[5] milliseconds: Current time in milliseconds, supplied by the caller
--- * ARGV[6] expiry: TTL of the state key, in seconds
--- * ARGV[7] tokens_to_consume: Number of tokens requested by this call
---
--- returns:
--- * The assigned slot, as a millisecond timestamp

redis.replicate_commands()

-- Arguments
local capacity = tonumber(ARGV[1])
local refill_amount = tonumber(ARGV[2])
local initial_tokens = tonumber(ARGV[3])
local time_between_slots = tonumber(ARGV[4]) * 1000 -- Convert to milliseconds
local milliseconds = tonumber(ARGV[5])
local expiry = tonumber(ARGV[6])
local tokens_to_consume = tonumber(ARGV[7]) -- Number of tokens to consume

-- Validate that tokens_to_consume doesn't exceed capacity
if tokens_to_consume > capacity then
  return redis.error_reply("Requested tokens exceed bucket capacity")
end

-- Validate that tokens_to_consume is positive
if tokens_to_consume <= 0 then
  return redis.error_reply("Must consume at least 1 token")
end

-- Keys
local data_key = KEYS[1]

-- Current time in milliseconds (provided by the caller)
local now = milliseconds

-- Default bucket values (used if no bucket exists yet)
local tokens = math.min(initial_tokens, capacity)
local slot = now

-- Retrieve stored state, if any
local data = redis.call('GET', data_key)
if data then
  local last_slot, stored_tokens = data:match('(%S+) (%S+)')
  slot = tonumber(last_slot)
  tokens = tonumber(stored_tokens)

  -- Calculate the number of slots that have passed since the last update
  local slots_passed = math.floor((now - slot) / time_between_slots)
  if slots_passed > 0 then
    -- Refill the tokens based on the number of slots passed, capped by capacity
    tokens = math.min(tokens + slots_passed * refill_amount, capacity)
    -- Start counting slots from this run. An earlier version added a +20 ms execution
    -- penalty here; it was removed because it only added latency to every request.
    -- NOTE(review): resetting to `now` discards the part of a slot that elapsed since
    -- the last refill, so under irregular traffic refills can happen slightly less often
    -- than once per `time_between_slots` — confirm this is intended.
    slot = now
  end
end

-- If not enough tokens are available, move to the next slot(s) and refill accordingly
if tokens < tokens_to_consume then
  -- A bucket that never refills can never satisfy this request. Fail instead of
  -- dividing by zero and persisting an infinite slot, which would corrupt the state.
  if refill_amount <= 0 then
    return redis.error_reply("Refill amount must be positive to wait for more tokens")
  end
  -- Calculate how many additional tokens we need
  local needed_tokens = tokens_to_consume - tokens
  -- Calculate how many slots we need to move forward to get enough tokens
  local needed_slots = math.ceil(needed_tokens / refill_amount)
  slot = slot + needed_slots * time_between_slots
  -- Refill for the skipped slots, clamped to capacity
  tokens = math.min(tokens + needed_slots * refill_amount, capacity)
end

-- Consume tokens
tokens = tokens - tokens_to_consume

-- Save updated state and set expiry
redis.call('SETEX', data_key, expiry, string.format('%d %d', slot, tokens))

-- Return the slot when the requested token(s) will be available
return slot