limits 4.0.1__py3-none-any.whl → 4.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39) hide show
  1. limits/__init__.py +3 -1
  2. limits/_version.py +4 -4
  3. limits/aio/__init__.py +2 -0
  4. limits/aio/storage/__init__.py +4 -1
  5. limits/aio/storage/base.py +70 -24
  6. limits/aio/storage/etcd.py +8 -2
  7. limits/aio/storage/memcached.py +159 -33
  8. limits/aio/storage/memory.py +100 -13
  9. limits/aio/storage/mongodb.py +217 -9
  10. limits/aio/storage/redis/__init__.py +341 -0
  11. limits/aio/storage/redis/bridge.py +121 -0
  12. limits/aio/storage/redis/coredis.py +209 -0
  13. limits/aio/storage/redis/redispy.py +257 -0
  14. limits/aio/strategies.py +124 -1
  15. limits/errors.py +2 -0
  16. limits/limits.py +10 -11
  17. limits/resources/redis/lua_scripts/acquire_sliding_window.lua +45 -0
  18. limits/resources/redis/lua_scripts/sliding_window.lua +17 -0
  19. limits/storage/__init__.py +6 -3
  20. limits/storage/base.py +92 -24
  21. limits/storage/etcd.py +8 -2
  22. limits/storage/memcached.py +143 -34
  23. limits/storage/memory.py +99 -12
  24. limits/storage/mongodb.py +204 -11
  25. limits/storage/redis.py +159 -138
  26. limits/storage/redis_cluster.py +5 -3
  27. limits/storage/redis_sentinel.py +14 -35
  28. limits/storage/registry.py +3 -3
  29. limits/strategies.py +121 -5
  30. limits/typing.py +55 -19
  31. limits/util.py +29 -18
  32. limits-4.2.dist-info/METADATA +268 -0
  33. limits-4.2.dist-info/RECORD +42 -0
  34. limits/aio/storage/redis.py +0 -470
  35. limits-4.0.1.dist-info/METADATA +0 -192
  36. limits-4.0.1.dist-info/RECORD +0 -37
  37. {limits-4.0.1.dist-info → limits-4.2.dist-info}/LICENSE.txt +0 -0
  38. {limits-4.0.1.dist-info → limits-4.2.dist-info}/WHEEL +0 -0
  39. {limits-4.0.1.dist-info → limits-4.2.dist-info}/top_level.txt +0 -0
limits/storage/base.py CHANGED
@@ -3,30 +3,32 @@ from __future__ import annotations
3
3
  import functools
4
4
  import threading
5
5
  from abc import ABC, abstractmethod
6
- from typing import Any, cast
7
6
 
8
7
  from limits import errors
9
8
  from limits.storage.registry import StorageRegistry
10
9
  from limits.typing import (
10
+ Any,
11
11
  Callable,
12
- List,
13
12
  Optional,
14
13
  P,
15
14
  R,
16
- Tuple,
17
15
  Type,
18
16
  Union,
17
+ cast,
19
18
  )
20
19
  from limits.util import LazyDependency
21
20
 
22
21
 
23
- def _wrap_errors(storage: Storage, fn: Callable[P, R]) -> Callable[P, R]:
22
+ def _wrap_errors(
23
+ fn: Callable[P, R],
24
+ ) -> Callable[P, R]:
24
25
  @functools.wraps(fn)
25
26
  def inner(*args: P.args, **kwargs: P.kwargs) -> R:
27
+ instance = cast(Storage, args[0])
26
28
  try:
27
29
  return fn(*args, **kwargs)
28
- except storage.base_exceptions as exc:
29
- if storage.wrap_exceptions:
30
+ except instance.base_exceptions as exc:
31
+ if instance.wrap_exceptions:
30
32
  raise errors.StorageError(exc) from exc
31
33
  raise
32
34
 
@@ -38,12 +40,10 @@ class Storage(LazyDependency, metaclass=StorageRegistry):
38
40
  Base class to extend when implementing a storage backend.
39
41
  """
40
42
 
41
- STORAGE_SCHEME: Optional[List[str]]
43
+ STORAGE_SCHEME: Optional[list[str]]
42
44
  """The storage schemes to register against this implementation"""
43
45
 
44
- def __new__(cls, *args: Any, **kwargs: Any) -> Storage: # type: ignore[misc]
45
- inst = super().__new__(cls)
46
-
46
+ def __init_subclass__(cls, **kwargs: Any) -> None: # type: ignore[explicit-any]
47
47
  for method in {
48
48
  "incr",
49
49
  "get",
@@ -52,9 +52,8 @@ class Storage(LazyDependency, metaclass=StorageRegistry):
52
52
  "reset",
53
53
  "clear",
54
54
  }:
55
- setattr(inst, method, _wrap_errors(inst, getattr(inst, method)))
56
-
57
- return inst
55
+ setattr(cls, method, _wrap_errors(getattr(cls, method)))
56
+ super().__init_subclass__(**kwargs)
58
57
 
59
58
  def __init__(
60
59
  self,
@@ -73,7 +72,7 @@ class Storage(LazyDependency, metaclass=StorageRegistry):
73
72
 
74
73
  @property
75
74
  @abstractmethod
76
- def base_exceptions(self) -> Union[Type[Exception], Tuple[Type[Exception], ...]]:
75
+ def base_exceptions(self) -> Union[Type[Exception], tuple[Type[Exception], ...]]:
77
76
  raise NotImplementedError
78
77
 
79
78
  @abstractmethod
@@ -131,24 +130,21 @@ class Storage(LazyDependency, metaclass=StorageRegistry):
131
130
 
132
131
  class MovingWindowSupport(ABC):
133
132
  """
134
- Abstract base for storages that intend to support
135
- the moving window strategy
133
+ Abstract base class for storages that support
134
+ the :ref:`strategies:moving window` strategy
136
135
  """
137
136
 
138
- def __new__(cls, *args: Any, **kwargs: Any) -> MovingWindowSupport: # type: ignore[misc]
139
- inst = super().__new__(cls)
140
-
137
+ def __init_subclass__(cls, **kwargs: Any) -> None: # type: ignore[explicit-any]
141
138
  for method in {
142
139
  "acquire_entry",
143
140
  "get_moving_window",
144
141
  }:
145
142
  setattr(
146
- inst,
143
+ cls,
147
144
  method,
148
- _wrap_errors(cast(Storage, inst), getattr(inst, method)),
145
+ _wrap_errors(getattr(cls, method)),
149
146
  )
150
-
151
- return inst
147
+ super().__init_subclass__(**kwargs)
152
148
 
153
149
  @abstractmethod
154
150
  def acquire_entry(self, key: str, limit: int, expiry: int, amount: int = 1) -> bool:
@@ -161,7 +157,7 @@ class MovingWindowSupport(ABC):
161
157
  raise NotImplementedError
162
158
 
163
159
  @abstractmethod
164
- def get_moving_window(self, key: str, limit: int, expiry: int) -> Tuple[float, int]:
160
+ def get_moving_window(self, key: str, limit: int, expiry: int) -> tuple[float, int]:
165
161
  """
166
162
  returns the starting point and the number of entries in the moving
167
163
  window
@@ -171,3 +167,75 @@ class MovingWindowSupport(ABC):
171
167
  :return: (start of window, number of acquired entries)
172
168
  """
173
169
  raise NotImplementedError
170
+
171
+
172
+ class SlidingWindowCounterSupport(ABC):
173
+ """
174
+ Abstract base class for storages that support
175
+ the :ref:`strategies:sliding window counter` strategy.
176
+ """
177
+
178
+ def __init_subclass__(cls, **kwargs: Any) -> None: # type: ignore[explicit-any]
179
+ for method in {"acquire_sliding_window_entry", "get_sliding_window"}:
180
+ setattr(
181
+ cls,
182
+ method,
183
+ _wrap_errors(getattr(cls, method)),
184
+ )
185
+ super().__init_subclass__(**kwargs)
186
+
187
+ @abstractmethod
188
+ def acquire_sliding_window_entry(
189
+ self, key: str, limit: int, expiry: int, amount: int = 1
190
+ ) -> bool:
191
+ """
192
+ Acquire an entry if the weighted count of the current and previous
193
+ windows is less than or equal to the limit
194
+
195
+ :param key: rate limit key to acquire an entry in
196
+ :param limit: amount of entries allowed
197
+ :param expiry: expiry of the entry
198
+ :param amount: the number of entries to acquire
199
+ """
200
+ raise NotImplementedError
201
+
202
+ @abstractmethod
203
+ def get_sliding_window(
204
+ self, key: str, expiry: int
205
+ ) -> tuple[int, float, int, float]:
206
+ """
207
+ Return the previous and current window information.
208
+
209
+ :param key: the rate limit key
210
+ :param expiry: the rate limit expiry, needed to compute the key in some implementations
211
+ :return: a tuple of (int, float, int, float) with the following information:
212
+ - previous window counter
213
+ - previous window TTL
214
+ - current window counter
215
+ - current window TTL
216
+ """
217
+ raise NotImplementedError
218
+
219
+
220
+ class TimestampedSlidingWindow:
221
+ """Helper class for storages that support the sliding window counter strategy, using timestamp-based keys."""
222
+
223
+ @classmethod
224
+ def sliding_window_keys(cls, key: str, expiry: int, at: float) -> tuple[str, str]:
225
+ """
226
+ returns the previous and the current window's keys.
227
+
228
+ :param key: the key to get the window's keys from
229
+ :param expiry: the expiry of the limit item, in seconds
230
+ :param at: the timestamp to compute the window keys for (required; pass ``time.time()`` for the current windows)
231
+
232
+ Returns a tuple with the previous and the current key: (previous, current).
233
+
234
+ Example:
235
+ - key = "mykey"
236
+ - expiry = 60
237
+ - at = 1738576292.6631825
238
+
239
+ The return value will be the tuple ``("mykey/28976270", "mykey/28976271")``.
240
+ """
241
+ return f"{key}/{int((at - expiry) / expiry)}", f"{key}/{int(at / expiry)}"
limits/storage/etcd.py CHANGED
@@ -1,9 +1,11 @@
1
+ from __future__ import annotations
2
+
1
3
  import time
2
4
  import urllib.parse
3
- from typing import TYPE_CHECKING, Optional, Tuple, Type, Union
4
5
 
5
6
  from limits.errors import ConcurrentUpdateError
6
7
  from limits.storage.base import Storage
8
+ from limits.typing import TYPE_CHECKING, Optional, Union
7
9
 
8
10
  if TYPE_CHECKING:
9
11
  import etcd3
@@ -26,6 +28,7 @@ class EtcdStorage(Storage):
26
28
  self,
27
29
  uri: str,
28
30
  max_retries: int = MAX_RETRIES,
31
+ wrap_exceptions: bool = False,
29
32
  **options: str,
30
33
  ) -> None:
31
34
  """
@@ -33,6 +36,8 @@ class EtcdStorage(Storage):
33
36
  ``etcd://host:port``,
34
37
  :param max_retries: Maximum number of attempts to retry
35
38
  in the case of concurrent updates to a rate limit key
39
+ :param wrap_exceptions: Whether to wrap storage exceptions in
40
+ :exc:`limits.errors.StorageError` before raising it.
36
41
  :param options: all remaining keyword arguments are passed
37
42
  directly to the constructor of :class:`etcd3.Etcd3Client`
38
43
  :raise ConfigurationError: when :pypi:`etcd3` is not available
@@ -43,11 +48,12 @@ class EtcdStorage(Storage):
43
48
  parsed.hostname, parsed.port, **options
44
49
  )
45
50
  self.max_retries = max_retries
51
+ super().__init__(uri, wrap_exceptions=wrap_exceptions)
46
52
 
47
53
  @property
48
54
  def base_exceptions(
49
55
  self,
50
- ) -> Union[Type[Exception], Tuple[Type[Exception], ...]]: # pragma: no cover
56
+ ) -> Union[type[Exception], tuple[type[Exception], ...]]: # pragma: no cover
51
57
  return self.lib.Etcd3Exception # type: ignore[no-any-return]
52
58
 
53
59
  def prefixed_key(self, key: str) -> bytes:
@@ -1,27 +1,34 @@
1
+ from __future__ import annotations
2
+
1
3
  import inspect
2
4
  import threading
3
5
  import time
4
6
  import urllib.parse
7
+ from collections.abc import Iterable
8
+ from math import ceil, floor
5
9
  from types import ModuleType
6
- from typing import cast
7
10
 
8
11
  from limits.errors import ConfigurationError
9
- from limits.storage.base import Storage
12
+ from limits.storage.base import (
13
+ SlidingWindowCounterSupport,
14
+ Storage,
15
+ TimestampedSlidingWindow,
16
+ )
10
17
  from limits.typing import (
18
+ Any,
11
19
  Callable,
12
- List,
13
20
  MemcachedClientP,
14
21
  Optional,
15
22
  P,
16
23
  R,
17
- Tuple,
18
24
  Type,
19
25
  Union,
26
+ cast,
20
27
  )
21
28
  from limits.util import get_dependency
22
29
 
23
30
 
24
- class MemcachedStorage(Storage):
31
+ class MemcachedStorage(Storage, SlidingWindowCounterSupport, TimestampedSlidingWindow):
25
32
  """
26
33
  Rate limit storage with memcached as backend.
27
34
 
@@ -70,7 +77,7 @@ class MemcachedStorage(Storage):
70
77
  options.pop("cluster_library", "pymemcache.client.hash")
71
78
  )
72
79
  self.client_getter = cast(
73
- Callable[[ModuleType, List[Tuple[str, int]]], MemcachedClientP],
80
+ Callable[[ModuleType, list[tuple[str, int]]], MemcachedClientP],
74
81
  options.pop("client_getter", self.get_client),
75
82
  )
76
83
  self.options = options
@@ -86,11 +93,11 @@ class MemcachedStorage(Storage):
86
93
  @property
87
94
  def base_exceptions(
88
95
  self,
89
- ) -> Union[Type[Exception], Tuple[Type[Exception], ...]]: # pragma: no cover
96
+ ) -> Union[Type[Exception], tuple[Type[Exception], ...]]: # pragma: no cover
90
97
  return self.dependency.MemcacheError # type: ignore[no-any-return]
91
98
 
92
99
  def get_client(
93
- self, module: ModuleType, hosts: List[Tuple[str, int]], **kwargs: str
100
+ self, module: ModuleType, hosts: list[tuple[str, int]], **kwargs: str
94
101
  ) -> MemcachedClientP:
95
102
  """
96
103
  returns a memcached client.
@@ -142,8 +149,15 @@ class MemcachedStorage(Storage):
142
149
  """
143
150
  :param key: the key to get the counter value for
144
151
  """
152
+ return int(self.storage.get(key, "0"))
153
+
154
+ def get_many(self, keys: Iterable[str]) -> dict[str, Any]: # type:ignore[explicit-any]
155
+ """
156
+ Return multiple counters at once
145
157
 
146
- return int(self.storage.get(key) or 0)
158
+ :param keys: the keys to get the counter values for
159
+ """
160
+ return self.storage.get_many(keys)
147
161
 
148
162
  def clear(self, key: str) -> None:
149
163
  """
@@ -152,7 +166,12 @@ class MemcachedStorage(Storage):
152
166
  self.storage.delete(key)
153
167
 
154
168
  def incr(
155
- self, key: str, expiry: int, elastic_expiry: bool = False, amount: int = 1
169
+ self,
170
+ key: str,
171
+ expiry: float,
172
+ elastic_expiry: bool = False,
173
+ amount: int = 1,
174
+ set_expiration_key: bool = True,
156
175
  ) -> int:
157
176
  """
158
177
  increments the counter for a given rate limit key
@@ -162,41 +181,67 @@ class MemcachedStorage(Storage):
162
181
  :param elastic_expiry: whether to keep extending the rate limit
163
182
  window every hit.
164
183
  :param amount: the number to increment by
184
+ :param set_expiration_key: set the expiration key with the expiration time if needed. If set to False, the key will still expire, but memcached cannot provide the expiration time.
165
185
  """
166
-
167
- if not self.call_memcached_func(
168
- self.storage.add, key, amount, expiry, noreply=False
169
- ):
170
- value = self.storage.incr(key, amount) or amount
171
-
186
+ value = self.call_memcached_func(self.storage.incr, key, amount, noreply=False)
187
+ if value is not None:
172
188
  if elastic_expiry:
173
- self.call_memcached_func(self.storage.touch, key, expiry)
174
- self.call_memcached_func(
175
- self.storage.set,
176
- key + "/expires",
177
- expiry + time.time(),
178
- expire=expiry,
179
- noreply=False,
180
- )
189
+ self.call_memcached_func(self.storage.touch, key, ceil(expiry))
190
+ if set_expiration_key:
191
+ self.call_memcached_func(
192
+ self.storage.set,
193
+ self._expiration_key(key),
194
+ expiry + time.time(),
195
+ expire=ceil(expiry),
196
+ noreply=False,
197
+ )
181
198
 
182
199
  return value
183
200
  else:
184
- self.call_memcached_func(
185
- self.storage.set,
186
- key + "/expires",
187
- expiry + time.time(),
188
- expire=expiry,
189
- noreply=False,
190
- )
191
-
192
- return amount
201
+ if not self.call_memcached_func(
202
+ self.storage.add, key, amount, ceil(expiry), noreply=False
203
+ ):
204
+ value = self.storage.incr(key, amount) or amount
205
+
206
+ if elastic_expiry:
207
+ self.call_memcached_func(self.storage.touch, key, ceil(expiry))
208
+ if set_expiration_key:
209
+ self.call_memcached_func(
210
+ self.storage.set,
211
+ self._expiration_key(key),
212
+ expiry + time.time(),
213
+ expire=ceil(expiry),
214
+ noreply=False,
215
+ )
216
+
217
+ return value
218
+ else:
219
+ if set_expiration_key:
220
+ self.call_memcached_func(
221
+ self.storage.set,
222
+ self._expiration_key(key),
223
+ expiry + time.time(),
224
+ expire=ceil(expiry),
225
+ noreply=False,
226
+ )
227
+
228
+ return amount
193
229
 
194
230
  def get_expiry(self, key: str) -> float:
195
231
  """
196
232
  :param key: the key to get the expiry for
197
233
  """
198
234
 
199
- return float(self.storage.get(key + "/expires") or time.time())
235
+ return float(self.storage.get(self._expiration_key(key)) or time.time())
236
+
237
+ def _expiration_key(self, key: str) -> str:
238
+ """
239
+ Return the expiration key for the given counter key.
240
+
241
+ Memcached doesn't natively return the expiration time or TTL for a given key,
242
+ so we implement the expiration time on a separate key.
243
+ """
244
+ return key + "/expires"
200
245
 
201
246
  def check(self) -> bool:
202
247
  """
@@ -212,3 +257,67 @@ class MemcachedStorage(Storage):
212
257
 
213
258
  def reset(self) -> Optional[int]:
214
259
  raise NotImplementedError
260
+
261
+ def acquire_sliding_window_entry(
262
+ self,
263
+ key: str,
264
+ limit: int,
265
+ expiry: int,
266
+ amount: int = 1,
267
+ ) -> bool:
268
+ if amount > limit:
269
+ return False
270
+ now = time.time()
271
+ previous_key, current_key = self.sliding_window_keys(key, expiry, now)
272
+ previous_count, previous_ttl, current_count, _ = self._get_sliding_window_info(
273
+ previous_key, current_key, expiry, now=now
274
+ )
275
+ weighted_count = previous_count * previous_ttl / expiry + current_count
276
+ if floor(weighted_count) + amount > limit:
277
+ return False
278
+ else:
279
+ # Hit, increase the current counter.
280
+ # If the counter doesn't exist yet, set twice the theoretical expiry.
281
+ # We don't need the expiration key as it is estimated with the timestamps directly.
282
+ current_count = self.incr(
283
+ current_key, 2 * expiry, amount=amount, set_expiration_key=False
284
+ )
285
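+ actualised_previous_ttl = min(0, previous_ttl - (time.time() - now))  # NOTE(review): min() clamps this TTL to <= 0, so the previous window never adds weight in the re-check below; max(0, ...) looks intended - confirm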
286
+ weighted_count = (
287
+ previous_count * actualised_previous_ttl / expiry + current_count
288
+ )
289
+ if floor(weighted_count) > limit:
290
+ # Another hit won the race condition: revert the incrementation and refuse this hit
291
+ # Limitation: during high concurrency at the end of the window,
292
+ # the counter is shifted and cannot be decremented, so fewer requests than expected are allowed.
293
+ self.call_memcached_func(
294
+ self.storage.decr,
295
+ current_key,
296
+ amount,
297
+ noreply=True,
298
+ )
299
+ return False
300
+ return True
301
+
302
+ def get_sliding_window(
303
+ self, key: str, expiry: int
304
+ ) -> tuple[int, float, int, float]:
305
+ now = time.time()
306
+ previous_key, current_key = self.sliding_window_keys(key, expiry, now)
307
+ return self._get_sliding_window_info(previous_key, current_key, expiry, now)
308
+
309
+ def _get_sliding_window_info(
310
+ self, previous_key: str, current_key: str, expiry: int, now: float
311
+ ) -> tuple[int, float, int, float]:
312
+ result = self.get_many([previous_key, current_key])
313
+ previous_count, current_count = (
314
+ int(result.get(previous_key, 0)),
315
+ int(result.get(current_key, 0)),
316
+ )
317
+
318
+ if previous_count == 0:
319
+ previous_ttl = float(0)
320
+ else:
321
+ previous_ttl = (1 - (((now - expiry) / expiry) % 1)) * expiry
322
+ current_ttl = (1 - ((now / expiry) % 1)) * expiry + expiry
323
+ return previous_count, previous_ttl, current_count, current_ttl
limits/storage/memory.py CHANGED
@@ -1,10 +1,18 @@
1
+ from __future__ import annotations
2
+
1
3
  import threading
2
4
  import time
3
- from collections import Counter
5
+ from collections import Counter, defaultdict
6
+ from math import floor
4
7
 
5
8
  import limits.typing
6
- from limits.storage.base import MovingWindowSupport, Storage
7
- from limits.typing import Dict, List, Optional, Tuple, Type, Union
9
+ from limits.storage.base import (
10
+ MovingWindowSupport,
11
+ SlidingWindowCounterSupport,
12
+ Storage,
13
+ TimestampedSlidingWindow,
14
+ )
15
+ from limits.typing import Optional, Type, Union
8
16
 
9
17
 
10
18
  class LockableEntry(threading._RLock): # type: ignore
@@ -14,7 +22,9 @@ class LockableEntry(threading._RLock): # type: ignore
14
22
  super().__init__()
15
23
 
16
24
 
17
- class MemoryStorage(Storage, MovingWindowSupport):
25
+ class MemoryStorage(
26
+ Storage, MovingWindowSupport, SlidingWindowCounterSupport, TimestampedSlidingWindow
27
+ ):
18
28
  """
19
29
  rate limit storage using :class:`collections.Counter`
20
30
  as an in memory storage for fixed and elastic window strategies,
@@ -28,8 +38,9 @@ class MemoryStorage(Storage, MovingWindowSupport):
28
38
  self, uri: Optional[str] = None, wrap_exceptions: bool = False, **_: str
29
39
  ):
30
40
  self.storage: limits.typing.Counter[str] = Counter()
31
- self.expirations: Dict[str, float] = {}
32
- self.events: Dict[str, List[LockableEntry]] = {}
41
+ self.locks: defaultdict[str, threading.RLock] = defaultdict(threading.RLock)
42
+ self.expirations: dict[str, float] = {}
43
+ self.events: dict[str, list[LockableEntry]] = {}
33
44
  self.timer = threading.Timer(0.01, self.__expire_events)
34
45
  self.timer.start()
35
46
  super().__init__(uri, wrap_exceptions=wrap_exceptions, **_)
@@ -37,7 +48,7 @@ class MemoryStorage(Storage, MovingWindowSupport):
37
48
  @property
38
49
  def base_exceptions(
39
50
  self,
40
- ) -> Union[Type[Exception], Tuple[Type[Exception], ...]]: # pragma: no cover
51
+ ) -> Union[Type[Exception], tuple[Type[Exception], ...]]: # pragma: no cover
41
52
  return ValueError
42
53
 
43
54
  def __expire_events(self) -> None:
@@ -51,6 +62,7 @@ class MemoryStorage(Storage, MovingWindowSupport):
51
62
  if self.expirations[key] <= time.time():
52
63
  self.storage.pop(key, None)
53
64
  self.expirations.pop(key, None)
65
+ self.locks.pop(key, None)
54
66
 
55
67
  def __schedule_expiry(self) -> None:
56
68
  if not self.timer.is_alive():
@@ -58,7 +70,7 @@ class MemoryStorage(Storage, MovingWindowSupport):
58
70
  self.timer.start()
59
71
 
60
72
  def incr(
61
- self, key: str, expiry: int, elastic_expiry: bool = False, amount: int = 1
73
+ self, key: str, expiry: float, elastic_expiry: bool = False, amount: int = 1
62
74
  ) -> int:
63
75
  """
64
76
  increments the counter for a given rate limit key
@@ -71,10 +83,25 @@ class MemoryStorage(Storage, MovingWindowSupport):
71
83
  """
72
84
  self.get(key)
73
85
  self.__schedule_expiry()
74
- self.storage[key] += amount
86
+ with self.locks[key]:
87
+ self.storage[key] += amount
88
+
89
+ if elastic_expiry or self.storage[key] == amount:
90
+ self.expirations[key] = time.time() + expiry
91
+
92
+ return self.storage.get(key, 0)
93
+
94
+ def decr(self, key: str, amount: int = 1) -> int:
95
+ """
96
+ decrements the counter for a given rate limit key
75
97
 
76
- if elastic_expiry or self.storage[key] == amount:
77
- self.expirations[key] = time.time() + expiry
98
+ :param key: the key to decrement
99
+ :param amount: the number to decrement by
100
+ """
101
+ self.get(key)
102
+ self.__schedule_expiry()
103
+ with self.locks[key]:
104
+ self.storage[key] = max(self.storage[key] - amount, 0)
78
105
 
79
106
  return self.storage.get(key, 0)
80
107
 
@@ -86,6 +113,7 @@ class MemoryStorage(Storage, MovingWindowSupport):
86
113
  if self.expirations.get(key, 0) <= time.time():
87
114
  self.storage.pop(key, None)
88
115
  self.expirations.pop(key, None)
116
+ self.locks.pop(key, None)
89
117
 
90
118
  return self.storage.get(key, 0)
91
119
 
@@ -96,6 +124,7 @@ class MemoryStorage(Storage, MovingWindowSupport):
96
124
  self.storage.pop(key, None)
97
125
  self.expirations.pop(key, None)
98
126
  self.events.pop(key, None)
127
+ self.locks.pop(key, None)
99
128
 
100
129
  def acquire_entry(self, key: str, limit: int, expiry: int, amount: int = 1) -> bool:
101
130
  """
@@ -143,7 +172,7 @@ class MemoryStorage(Storage, MovingWindowSupport):
143
172
  else 0
144
173
  )
145
174
 
146
- def get_moving_window(self, key: str, limit: int, expiry: int) -> Tuple[float, int]:
175
+ def get_moving_window(self, key: str, limit: int, expiry: int) -> tuple[float, int]:
147
176
  """
148
177
  returns the starting point and the number of entries in the moving
149
178
  window
@@ -161,6 +190,63 @@ class MemoryStorage(Storage, MovingWindowSupport):
161
190
 
162
191
  return timestamp, acquired
163
192
 
193
+ def acquire_sliding_window_entry(
194
+ self,
195
+ key: str,
196
+ limit: int,
197
+ expiry: int,
198
+ amount: int = 1,
199
+ ) -> bool:
200
+ if amount > limit:
201
+ return False
202
+ now = time.time()
203
+ previous_key, current_key = self.sliding_window_keys(key, expiry, now)
204
+ (
205
+ previous_count,
206
+ previous_ttl,
207
+ current_count,
208
+ _,
209
+ ) = self._get_sliding_window_info(previous_key, current_key, expiry, now)
210
+ weighted_count = previous_count * previous_ttl / expiry + current_count
211
+ if floor(weighted_count) + amount > limit:
212
+ return False
213
+ else:
214
+ # Hit, increase the current counter.
215
+ # If the counter doesn't exist yet, set twice the theoretical expiry.
216
+ current_count = self.incr(current_key, 2 * expiry, amount=amount)
217
+ weighted_count = previous_count * previous_ttl / expiry + current_count
218
+ if floor(weighted_count) > limit:
219
+ # Another hit won the race condition: revert the incrementation and refuse this hit
220
+ # Limitation: during high concurrency at the end of the window,
221
+ # the counter is shifted and cannot be decremented, so fewer requests than expected are allowed.
222
+ self.decr(current_key, amount)
223
+ # print("Concurrent call, reverting the counter increase")
224
+ return False
225
+ return True
226
+
227
+ def _get_sliding_window_info(
228
+ self,
229
+ previous_key: str,
230
+ current_key: str,
231
+ expiry: int,
232
+ now: float,
233
+ ) -> tuple[int, float, int, float]:
234
+ previous_count = self.get(previous_key)
235
+ current_count = self.get(current_key)
236
+ if previous_count == 0:
237
+ previous_ttl = float(0)
238
+ else:
239
+ previous_ttl = (1 - (((now - expiry) / expiry) % 1)) * expiry
240
+ current_ttl = (1 - ((now / expiry) % 1)) * expiry + expiry
241
+ return previous_count, previous_ttl, current_count, current_ttl
242
+
243
+ def get_sliding_window(
244
+ self, key: str, expiry: int
245
+ ) -> tuple[int, float, int, float]:
246
+ now = time.time()
247
+ previous_key, current_key = self.sliding_window_keys(key, expiry, now)
248
+ return self._get_sliding_window_info(previous_key, current_key, expiry, now)
249
+
164
250
  def check(self) -> bool:
165
251
  """
166
252
  check if storage is healthy
@@ -173,4 +259,5 @@ class MemoryStorage(Storage, MovingWindowSupport):
173
259
  self.storage.clear()
174
260
  self.expirations.clear()
175
261
  self.events.clear()
262
+ self.locks.clear()
176
263
  return num_items