limits 4.0.1__py3-none-any.whl → 4.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- limits/__init__.py +3 -1
- limits/_version.py +4 -4
- limits/aio/__init__.py +2 -0
- limits/aio/storage/__init__.py +4 -1
- limits/aio/storage/base.py +70 -24
- limits/aio/storage/etcd.py +8 -2
- limits/aio/storage/memcached.py +159 -33
- limits/aio/storage/memory.py +100 -13
- limits/aio/storage/mongodb.py +217 -9
- limits/aio/storage/redis/__init__.py +341 -0
- limits/aio/storage/redis/bridge.py +121 -0
- limits/aio/storage/redis/coredis.py +209 -0
- limits/aio/storage/redis/redispy.py +257 -0
- limits/aio/strategies.py +124 -1
- limits/errors.py +2 -0
- limits/limits.py +10 -11
- limits/resources/redis/lua_scripts/acquire_sliding_window.lua +45 -0
- limits/resources/redis/lua_scripts/sliding_window.lua +17 -0
- limits/storage/__init__.py +6 -3
- limits/storage/base.py +92 -24
- limits/storage/etcd.py +8 -2
- limits/storage/memcached.py +143 -34
- limits/storage/memory.py +99 -12
- limits/storage/mongodb.py +204 -11
- limits/storage/redis.py +159 -138
- limits/storage/redis_cluster.py +5 -3
- limits/storage/redis_sentinel.py +14 -35
- limits/storage/registry.py +3 -3
- limits/strategies.py +121 -5
- limits/typing.py +55 -19
- limits/util.py +29 -18
- limits-4.2.dist-info/METADATA +268 -0
- limits-4.2.dist-info/RECORD +42 -0
- limits/aio/storage/redis.py +0 -470
- limits-4.0.1.dist-info/METADATA +0 -192
- limits-4.0.1.dist-info/RECORD +0 -37
- {limits-4.0.1.dist-info → limits-4.2.dist-info}/LICENSE.txt +0 -0
- {limits-4.0.1.dist-info → limits-4.2.dist-info}/WHEEL +0 -0
- {limits-4.0.1.dist-info → limits-4.2.dist-info}/top_level.txt +0 -0
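The most visible additions in this range are the sliding window counter strategy (marked versionadded 4.1) and the new async Redis bridge modules. A minimal usage sketch, not taken from the package, of the new strategy against the async in-memory backend (names follow the diff below; the memory.py and redis storage changes listed above suggest both backends implement the new sliding window hooks):

import asyncio

from limits import parse
from limits.aio.storage import MemoryStorage
from limits.aio.strategies import SlidingWindowCounterRateLimiter


async def main() -> None:
    # assumed example limit: 10 hits per minute per user
    limiter = SlidingWindowCounterRateLimiter(MemoryStorage())
    per_minute = parse("10/minute")

    for attempt in range(12):
        allowed = await limiter.hit(per_minute, "user-42")
        print(attempt, "allowed" if allowed else "rejected")


asyncio.run(main())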
limits/aio/storage/redis/redispy.py
ADDED

@@ -0,0 +1,257 @@
+from __future__ import annotations
+
+import time
+from typing import TYPE_CHECKING, Type, cast
+
+from limits.aio.storage.redis.bridge import RedisBridge
+from limits.errors import ConfigurationError
+from limits.typing import AsyncRedisClient, Callable, Optional, Union
+
+if TYPE_CHECKING:
+    import redis.commands
+
+
+class RedispyBridge(RedisBridge):
+    DEFAULT_CLUSTER_OPTIONS: dict[str, Union[float, str, bool]] = {
+        "max_connections": 1000,
+    }
+    "Default options passed to :class:`redis.asyncio.RedisCluster`"
+
+    @property
+    def base_exceptions(self) -> Union[Type[Exception], tuple[Type[Exception], ...]]:
+        return (self.dependency.RedisError,)
+
+    def use_sentinel(
+        self,
+        service_name: Optional[str],
+        use_replicas: bool,
+        sentinel_kwargs: Optional[dict[str, Union[str, float, bool]]],
+        **options: Union[str, float, bool],
+    ) -> None:
+        sentinel_configuration = []
+
+        connection_options = options.copy()
+
+        sep = self.parsed_uri.netloc.find("@") + 1
+
+        for loc in self.parsed_uri.netloc[sep:].split(","):
+            host, port = loc.split(":")
+            sentinel_configuration.append((host, int(port)))
+        service_name = (
+            self.parsed_uri.path.replace("/", "")
+            if self.parsed_uri.path
+            else service_name
+        )
+
+        if service_name is None:
+            raise ConfigurationError("'service_name' not provided")
+
+        self.sentinel = self.dependency.asyncio.Sentinel(
+            sentinel_configuration,
+            sentinel_kwargs={**self.parsed_auth, **(sentinel_kwargs or {})},
+            **{**self.parsed_auth, **connection_options},
+        )
+        self.storage = self.sentinel.master_for(service_name)
+        self.storage_replica = self.sentinel.slave_for(service_name)
+        self.connection_getter = lambda readonly: (
+            self.storage_replica if readonly and use_replicas else self.storage
+        )
+
+    def use_basic(self, **options: Union[str, float, bool]) -> None:
+        if connection_pool := options.pop("connection_pool", None):
+            self.storage = self.dependency.asyncio.Redis(
+                connection_pool=connection_pool, **options
+            )
+        else:
+            self.storage = self.dependency.asyncio.Redis.from_url(self.uri, **options)
+
+        self.connection_getter = lambda _: self.storage
+
+    def use_cluster(self, **options: Union[str, float, bool]) -> None:
+        sep = self.parsed_uri.netloc.find("@") + 1
+        cluster_hosts = []
+
+        for loc in self.parsed_uri.netloc[sep:].split(","):
+            host, port = loc.split(":")
+            cluster_hosts.append(
+                self.dependency.asyncio.cluster.ClusterNode(host=host, port=int(port))
+            )
+
+        self.storage = self.dependency.asyncio.RedisCluster(
+            startup_nodes=cluster_hosts,
+            **{**self.DEFAULT_CLUSTER_OPTIONS, **self.parsed_auth, **options},
+        )
+        self.connection_getter = lambda _: self.storage
+
+    lua_moving_window: "redis.commands.core.Script"
+    lua_acquire_moving_window: "redis.commands.core.Script"
+    lua_sliding_window: "redis.commands.core.Script"
+    lua_acquire_sliding_window: "redis.commands.core.Script"
+    lua_clear_keys: "redis.commands.core.Script"
+    lua_incr_expire: "redis.commands.core.Script"
+    connection_getter: Callable[[bool], AsyncRedisClient]
+
+    def get_connection(self, readonly: bool = False) -> AsyncRedisClient:
+        return self.connection_getter(readonly)
+
+    def register_scripts(self) -> None:
+        # Redis-py uses a slightly different script registration
+        self.lua_moving_window = self.get_connection().register_script(
+            self.SCRIPT_MOVING_WINDOW
+        )
+        self.lua_acquire_moving_window = self.get_connection().register_script(
+            self.SCRIPT_ACQUIRE_MOVING_WINDOW
+        )
+        self.lua_clear_keys = self.get_connection().register_script(
+            self.SCRIPT_CLEAR_KEYS
+        )
+        self.lua_incr_expire = self.get_connection().register_script(
+            self.SCRIPT_INCR_EXPIRE
+        )
+        self.lua_sliding_window = self.get_connection().register_script(
+            self.SCRIPT_SLIDING_WINDOW
+        )
+        self.lua_acquire_sliding_window = self.get_connection().register_script(
+            self.SCRIPT_ACQUIRE_SLIDING_WINDOW
+        )
+
+    async def incr(
+        self,
+        key: str,
+        expiry: int,
+        elastic_expiry: bool = False,
+        amount: int = 1,
+    ) -> int:
+        """
+        increments the counter for a given rate limit key
+
+        :param key: the key to increment
+        :param expiry: amount in seconds for the key to expire in
+        :param amount: the number to increment by
+        """
+        key = self.prefixed_key(key)
+
+        if elastic_expiry:
+            value = await self.get_connection().incrby(key, amount)
+            await self.get_connection().expire(key, expiry)
+            return value
+        else:
+            return cast(int, await self.lua_incr_expire([key], [expiry, amount]))
+
+    async def get(self, key: str) -> int:
+        """
+        :param key: the key to get the counter value for
+        """
+        key = self.prefixed_key(key)
+        return int(await self.get_connection(readonly=True).get(key) or 0)
+
+    async def clear(self, key: str) -> None:
+        """
+        :param key: the key to clear rate limits for
+        """
+        key = self.prefixed_key(key)
+        await self.get_connection().delete(key)
+
+    async def lua_reset(self) -> Optional[int]:
+        return cast(int, await self.lua_clear_keys([self.prefixed_key("*")]))
+
+    async def get_moving_window(
+        self, key: str, limit: int, expiry: int
+    ) -> tuple[float, int]:
+        """
+        returns the starting point and the number of entries in the moving
+        window
+
+        :param key: rate limit key
+        :param expiry: expiry of entry
+        :return: (previous count, previous TTL, current count, current TTL)
+        """
+        key = self.prefixed_key(key)
+        timestamp = time.time()
+        window = await self.lua_moving_window([key], [timestamp - expiry, limit])
+        if window:
+            return float(window[0]), window[1]
+        return timestamp, 0
+
+    async def get_sliding_window(
+        self, previous_key: str, current_key: str, expiry: int
+    ) -> tuple[int, float, int, float]:
+        if window := await self.lua_sliding_window(
+            [self.prefixed_key(previous_key), self.prefixed_key(current_key)], [expiry]
+        ):
+            return (
+                int(window[0] or 0),
+                max(0, float(window[1] or 0)) / 1000,
+                int(window[2] or 0),
+                max(0, float(window[3] or 0)) / 1000,
+            )
+        return 0, 0.0, 0, 0.0
+
+    async def acquire_entry(
+        self,
+        key: str,
+        limit: int,
+        expiry: int,
+        amount: int = 1,
+    ) -> bool:
+        """
+        :param key: rate limit key to acquire an entry in
+        :param limit: amount of entries allowed
+        :param expiry: expiry of the entry
+        """
+        key = self.prefixed_key(key)
+        timestamp = time.time()
+        acquired = await self.lua_acquire_moving_window(
+            [key], [timestamp, limit, expiry, amount]
+        )
+
+        return bool(acquired)
+
+    async def acquire_sliding_window_entry(
+        self,
+        previous_key: str,
+        current_key: str,
+        limit: int,
+        expiry: int,
+        amount: int = 1,
+    ) -> bool:
+        previous_key = self.prefixed_key(previous_key)
+        current_key = self.prefixed_key(current_key)
+        acquired = await self.lua_acquire_sliding_window(
+            [previous_key, current_key], [limit, expiry, amount]
+        )
+        return bool(acquired)
+
+    async def get_expiry(self, key: str) -> float:
+        """
+        :param key: the key to get the expiry for
+        """
+        key = self.prefixed_key(key)
+        return max(await self.get_connection().ttl(key), 0) + time.time()
+
+    async def check(self) -> bool:
+        """
+        check if storage is healthy
+        """
+        try:
+            await self.get_connection().ping()
+
+            return True
+        except:  # noqa
+            return False
+
+    async def reset(self) -> Optional[int]:
+        prefix = self.prefixed_key("*")
+        keys = await self.storage.keys(
+            prefix, target_nodes=self.dependency.asyncio.cluster.RedisCluster.ALL_NODES
+        )
+        count = 0
+        for key in keys:
+            count += await self.storage.delete(key)
+        return count
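A minimal sketch, not part of the package, of how this bridge is normally reached: the async Redis storage is built from an async+redis:// URI (the URI and a reachable local Redis are assumptions here) and then driven through a strategy:

import asyncio

from limits import parse
from limits.aio.strategies import MovingWindowRateLimiter
from limits.storage import storage_from_string


async def main() -> None:
    # assumed example URI; requires the redis (or coredis) dependency to be installed
    storage = storage_from_string("async+redis://localhost:6379/0")
    print("healthy:", await storage.check())

    limiter = MovingWindowRateLimiter(storage)
    print("allowed:", await limiter.hit(parse("5/second"), "client-1"))


asyncio.run(main())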
limits/aio/strategies.py
CHANGED
@@ -2,13 +2,20 @@
 Asynchronous rate limiting strategies
 """
 
+from __future__ import annotations
+
+import time
 from abc import ABC, abstractmethod
-from
+from math import floor, inf
+
+from deprecated.sphinx import deprecated, versionadded
 
 from ..limits import RateLimitItem
 from ..storage import StorageTypes
+from ..typing import cast
 from ..util import WindowStats
 from .storage import MovingWindowSupport, Storage
+from .storage.base import SlidingWindowCounterSupport
 
 
 class RateLimiter(ABC):

@@ -183,6 +190,121 @@ class FixedWindowRateLimiter(RateLimiter):
         return WindowStats(reset, remaining)
 
 
+@versionadded(version="4.1")
+class SlidingWindowCounterRateLimiter(RateLimiter):
+    """
+    Reference: :ref:`strategies:sliding window counter`
+    """
+
+    def __init__(self, storage: StorageTypes):
+        if not hasattr(storage, "get_sliding_window") or not hasattr(
+            storage, "acquire_sliding_window_entry"
+        ):
+            raise NotImplementedError(
+                "SlidingWindowCounterRateLimiting is not implemented for storage "
+                "of type %s" % storage.__class__
+            )
+        super().__init__(storage)
+
+    def _weighted_count(
+        self,
+        item: RateLimitItem,
+        previous_count: int,
+        previous_expires_in: float,
+        current_count: int,
+    ) -> float:
+        """
+        Return the approximated by weighting the previous window count and adding the current window count.
+        """
+        return previous_count * previous_expires_in / item.get_expiry() + current_count
+
+    async def hit(self, item: RateLimitItem, *identifiers: str, cost: int = 1) -> bool:
+        """
+        Consume the rate limit
+
+        :param item: The rate limit item
+        :param identifiers: variable list of strings to uniquely identify this
+            instance of the limit
+        :param cost: The cost of this hit, default 1
+        """
+        return await cast(
+            SlidingWindowCounterSupport, self.storage
+        ).acquire_sliding_window_entry(
+            item.key_for(*identifiers),
+            item.amount,
+            item.get_expiry(),
+            cost,
+        )
+
+    async def test(self, item: RateLimitItem, *identifiers: str, cost: int = 1) -> bool:
+        """
+        Check if the rate limit can be consumed
+
+        :param item: The rate limit item
+        :param identifiers: variable list of strings to uniquely identify this
+            instance of the limit
+        :param cost: The expected cost to be consumed, default 1
+        """
+
+        previous_count, previous_expires_in, current_count, _ = await cast(
+            SlidingWindowCounterSupport, self.storage
+        ).get_sliding_window(item.key_for(*identifiers), item.get_expiry())
+
+        return (
+            self._weighted_count(
+                item, previous_count, previous_expires_in, current_count
+            )
+            < item.amount - cost + 1
+        )
+
+    async def get_window_stats(
+        self, item: RateLimitItem, *identifiers: str
+    ) -> WindowStats:
+        """
+        Query the reset time and remaining amount for the limit.
+
+        :param item: The rate limit item
+        :param identifiers: variable list of strings to uniquely identify this
+            instance of the limit
+        :return: (reset time, remaining)
+        """
+
+        (
+            previous_count,
+            previous_expires_in,
+            current_count,
+            current_expires_in,
+        ) = await cast(SlidingWindowCounterSupport, self.storage).get_sliding_window(
+            item.key_for(*identifiers), item.get_expiry()
+        )
+
+        remaining = max(
+            0,
+            item.amount
+            - floor(
+                self._weighted_count(
+                    item, previous_count, previous_expires_in, current_count
+                )
+            ),
+        )
+
+        now = time.time()
+
+        if not (previous_count or current_count):
+            return WindowStats(now, remaining)
+
+        expiry = item.get_expiry()
+
+        previous_reset_in, current_reset_in = inf, inf
+        if previous_count:
+            previous_reset_in = previous_expires_in % (expiry / previous_count)
+        if current_count:
+            current_reset_in = current_expires_in % expiry
+
+        return WindowStats(now + min(previous_reset_in, current_reset_in), remaining)
+
+
+@deprecated(version="4.1")
 class FixedWindowElasticExpiryRateLimiter(FixedWindowRateLimiter):
     """
     Reference: :ref:`strategies:fixed window with elastic expiry`

@@ -208,6 +330,7 @@ class FixedWindowElasticExpiryRateLimiter(FixedWindowRateLimiter):
 
 
 STRATEGIES = {
+    "sliding-window-counter": SlidingWindowCounterRateLimiter,
     "fixed-window": FixedWindowRateLimiter,
     "fixed-window-elastic-expiry": FixedWindowElasticExpiryRateLimiter,
     "moving-window": MovingWindowRateLimiter,
limits/errors.py
CHANGED
limits/limits.py
CHANGED
@@ -3,14 +3,13 @@
 from __future__ import annotations
 
 from functools import total_ordering
-from typing import Dict, NamedTuple, Optional, Tuple, Type, Union, cast
 
-from limits.typing import ClassVar,
+from limits.typing import ClassVar, NamedTuple, cast
 
 
-def safe_string(value:
+def safe_string(value: bytes | str | int | float) -> str:
     """
-
+    normalize a byte/str/int or float to a str
     """
 
     if isinstance(value, bytes):

@@ -33,15 +32,15 @@ TIME_TYPES = dict(
     second=Granularity(1, "second"),
 )
 
-GRANULARITIES:
+GRANULARITIES: dict[str, type[RateLimitItem]] = {}
 
 
 class RateLimitItemMeta(type):
     def __new__(
         cls,
         name: str,
-        parents:
-        dct:
+        parents: tuple[type, ...],
+        dct: dict[str, Granularity | list[str]],
     ) -> RateLimitItemMeta:
         if "__slots__" not in dct:
             dct["__slots__"] = []

@@ -49,7 +48,7 @@ class RateLimitItemMeta(type):
 
         if "GRANULARITY" in dct:
             GRANULARITIES[dct["GRANULARITY"][1]] = cast(
-
+                type[RateLimitItem], granularity
             )
 
         return granularity

@@ -77,7 +76,7 @@ class RateLimitItem(metaclass=RateLimitItemMeta):
     """
 
     def __init__(
-        self, amount: int, multiples:
+        self, amount: int, multiples: int | None = 1, namespace: str = "LIMITER"
     ):
         self.namespace = namespace
         self.amount = int(amount)

@@ -101,14 +100,14 @@ class RateLimitItem(metaclass=RateLimitItemMeta):
 
         return self.GRANULARITY.seconds * self.multiples
 
-    def key_for(self, *identifiers: str) -> str:
+    def key_for(self, *identifiers: bytes | str | int | float) -> str:
         """
         Constructs a key for the current limit and any additional
         identifiers provided.
 
         :param identifiers: a list of strings to append to the key
         :return: a string key identifying this resource with
-            each identifier
+            each identifier separated with a '/' delimiter.
         """
         remainder = "/".join(
             [safe_string(k) for k in identifiers]
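The key_for change above only widens the accepted identifier types; keys remain '/'-delimited strings built through safe_string. A short sketch (the key value in the comment is an expectation, not taken from the diff):

from limits import RateLimitItemPerMinute

item = RateLimitItemPerMinute(10)
# identifiers may now be bytes, str, int or float; each is normalized by safe_string()
print(item.key_for("user", 42))  # expected to look like "LIMITER/user/42/10/1/minute"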
limits/resources/redis/lua_scripts/acquire_sliding_window.lua
ADDED

@@ -0,0 +1,45 @@
+-- Time is in milliseconds in this script: TTL, expiry...
+
+local limit = tonumber(ARGV[1])
+local expiry = tonumber(ARGV[2]) * 1000
+local amount = tonumber(ARGV[3])
+
+if amount > limit then
+    return false
+end
+
+local current_ttl = tonumber(redis.call('pttl', KEYS[2]))
+
+if current_ttl > 0 and current_ttl < expiry then
+    -- Current window expired, shift it to the previous window
+    redis.call('rename', KEYS[2], KEYS[1])
+    redis.call('set', KEYS[2], 0, 'PX', current_ttl + expiry)
+end
+
+local previous_count = tonumber(redis.call('get', KEYS[1])) or 0
+local previous_ttl = tonumber(redis.call('pttl', KEYS[1])) or 0
+local current_count = tonumber(redis.call('get', KEYS[2])) or 0
+current_ttl = tonumber(redis.call('pttl', KEYS[2])) or 0
+
+-- If the values don't exist yet, consider the TTL is 0
+if previous_ttl <= 0 then
+    previous_ttl = 0
+end
+if current_ttl <= 0 then
+    current_ttl = 0
+end
+local weighted_count = math.floor(previous_count * previous_ttl / expiry) + current_count
+
+if (weighted_count + amount) > limit then
+    return false
+end
+
+-- If the current counter exists, increase its value
+if redis.call('exists', KEYS[2]) == 1 then
+    redis.call('incrby', KEYS[2], amount)
+else
+    -- Otherwise, set the value with twice the expiry time
+    redis.call('set', KEYS[2], amount, 'PX', expiry * 2)
+end
+
+return true
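For reference, a sketch (client setup and key names are assumptions, not part of the package) of the calling convention RedispyBridge uses for this script: two keys, the previous and current window, and three arguments, limit, expiry in seconds and amount:

import redis.asyncio as redis


async def acquire(client: redis.Redis, script_source: str) -> bool:
    # register_script() mirrors RedispyBridge.register_scripts() above
    acquire_sliding_window = client.register_script(script_source)
    # KEYS = [previous window key, current window key]; ARGV = [limit, expiry seconds, amount]
    acquired = await acquire_sliding_window(
        ["LIMITER/prev-window", "LIMITER/curr-window"], [10, 60, 1]
    )
    return bool(acquired)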
limits/resources/redis/lua_scripts/sliding_window.lua
ADDED

@@ -0,0 +1,17 @@
+local expiry = tonumber(ARGV[1]) * 1000
+local previous_count = redis.call('get', KEYS[1])
+local previous_ttl = redis.call('pttl', KEYS[1])
+local current_count = redis.call('get', KEYS[2])
+local current_ttl = redis.call('pttl', KEYS[2])
+
+if current_ttl > 0 and current_ttl < expiry then
+    -- Current window expired, shift it to the previous window
+    redis.call('rename', KEYS[2], KEYS[1])
+    redis.call('set', KEYS[2], 0, 'PX', current_ttl + expiry)
+    previous_count = redis.call('get', KEYS[1])
+    previous_ttl = redis.call('pttl', KEYS[1])
+    current_count = redis.call('get', KEYS[2])
+    current_ttl = redis.call('pttl', KEYS[2])
+end
+
+return {previous_count, previous_ttl, current_count, current_ttl}
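Putting the two scripts and _weighted_count together, a small worked example with made-up numbers (the scripts report TTLs in milliseconds; the Python strategy converts them to seconds before weighting):

# illustrative values only, mirroring _weighted_count() in the strategies diff above
expiry = 60                 # item.get_expiry(): a 60 second window
previous_count = 30         # hits recorded in the previous window
previous_expires_in = 45.0  # seconds until the previous window's counter expires
current_count = 10          # hits recorded in the current window

weighted = previous_count * previous_expires_in / expiry + current_count
print(weighted)  # 32.5: a 40/minute limit would still allow a hit, a 30/minute limit would not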
limits/storage/__init__.py
CHANGED
@@ -3,13 +3,15 @@ Implementations of storage backends to be used with
 :class:`limits.strategies.RateLimiter` strategies
 """
 
+from __future__ import annotations
+
 import urllib
-from typing import Union, cast
 
-import limits
+import limits # noqa
 
 from ..errors import ConfigurationError
-from
+from ..typing import Union, cast
+from .base import MovingWindowSupport, SlidingWindowCounterSupport, Storage
 from .etcd import EtcdStorage
 from .memcached import MemcachedStorage
 from .memory import MemoryStorage

@@ -67,6 +69,7 @@ __all__ = [
     "storage_from_string",
     "Storage",
     "MovingWindowSupport",
+    "SlidingWindowCounterSupport",
     "EtcdStorage",
     "MongoDBStorageBase",
     "MemoryStorage",