limits 4.0.0__py3-none-any.whl → 4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
limits/storage/base.py CHANGED
@@ -3,30 +3,32 @@ from __future__ import annotations
3
3
  import functools
4
4
  import threading
5
5
  from abc import ABC, abstractmethod
6
- from typing import Any, cast
7
6
 
8
7
  from limits import errors
9
8
  from limits.storage.registry import StorageRegistry
10
9
  from limits.typing import (
10
+ Any,
11
11
  Callable,
12
- List,
13
12
  Optional,
14
13
  P,
15
14
  R,
16
- Tuple,
17
15
  Type,
18
16
  Union,
17
+ cast,
19
18
  )
20
19
  from limits.util import LazyDependency
21
20
 
22
21
 
23
- def _wrap_errors(storage: Storage, fn: Callable[P, R]) -> Callable[P, R]:
22
+ def _wrap_errors(
23
+ fn: Callable[P, R],
24
+ ) -> Callable[P, R]:
24
25
  @functools.wraps(fn)
25
26
  def inner(*args: P.args, **kwargs: P.kwargs) -> R:
27
+ instance = cast(Storage, args[0])
26
28
  try:
27
29
  return fn(*args, **kwargs)
28
- except storage.base_exceptions as exc:
29
- if storage.wrap_exceptions:
30
+ except instance.base_exceptions as exc:
31
+ if instance.wrap_exceptions:
30
32
  raise errors.StorageError(exc) from exc
31
33
  raise
32
34
 
@@ -38,12 +40,10 @@ class Storage(LazyDependency, metaclass=StorageRegistry):
38
40
  Base class to extend when implementing a storage backend.
39
41
  """
40
42
 
41
- STORAGE_SCHEME: Optional[List[str]]
43
+ STORAGE_SCHEME: Optional[list[str]]
42
44
  """The storage schemes to register against this implementation"""
43
45
 
44
- def __new__(cls, *args: Any, **kwargs: Any) -> Storage: # type: ignore[misc]
45
- inst = super().__new__(cls)
46
-
46
+ def __init_subclass__(cls, **kwargs: Any) -> None: # type: ignore[explicit-any]
47
47
  for method in {
48
48
  "incr",
49
49
  "get",
@@ -52,9 +52,8 @@ class Storage(LazyDependency, metaclass=StorageRegistry):
52
52
  "reset",
53
53
  "clear",
54
54
  }:
55
- setattr(inst, method, _wrap_errors(inst, getattr(inst, method)))
56
-
57
- return inst
55
+ setattr(cls, method, _wrap_errors(getattr(cls, method)))
56
+ super().__init_subclass__(**kwargs)
58
57
 
59
58
  def __init__(
60
59
  self,
@@ -73,7 +72,7 @@ class Storage(LazyDependency, metaclass=StorageRegistry):
73
72
 
74
73
@property
@abstractmethod
def base_exceptions(self) -> Union[Type[Exception], tuple[Type[Exception], ...]]:
    """
    The exception type(s) raised by the underlying storage library.

    When one of the error-wrapped storage methods raises one of these and
    ``wrap_exceptions`` is enabled, the error is re-raised as
    :exc:`limits.errors.StorageError`; otherwise it propagates unchanged.
    """
    raise NotImplementedError
78
77
 
79
78
  @abstractmethod
@@ -131,24 +130,21 @@ class Storage(LazyDependency, metaclass=StorageRegistry):
131
130
 
132
131
  class MovingWindowSupport(ABC):
133
132
  """
134
- Abstract base for storages that intend to support
135
- the moving window strategy
133
+ Abstract base class for storages that support
134
+ the :ref:`strategies:moving window` strategy
136
135
  """
137
136
 
138
- def __new__(cls, *args: Any, **kwargs: Any) -> MovingWindowSupport: # type: ignore[misc]
139
- inst = super().__new__(cls)
140
-
137
def __init_subclass__(cls, **kwargs: Any) -> None:  # type: ignore[explicit-any]
    """
    Wrap the moving window methods of every subclass with ``_wrap_errors``.

    Errors matching the instance's ``base_exceptions`` are then re-raised as
    :exc:`limits.errors.StorageError` when ``wrap_exceptions`` is enabled.
    """
    moving_window_methods = ("acquire_entry", "get_moving_window")
    for name in moving_window_methods:
        implementation = getattr(cls, name)
        setattr(cls, name, _wrap_errors(implementation))
    super().__init_subclass__(**kwargs)
152
148
 
153
149
  @abstractmethod
154
150
  def acquire_entry(self, key: str, limit: int, expiry: int, amount: int = 1) -> bool:
@@ -161,7 +157,7 @@ class MovingWindowSupport(ABC):
161
157
  raise NotImplementedError
162
158
 
163
159
  @abstractmethod
164
- def get_moving_window(self, key: str, limit: int, expiry: int) -> Tuple[float, int]:
160
+ def get_moving_window(self, key: str, limit: int, expiry: int) -> tuple[float, int]:
165
161
  """
166
162
  returns the starting point and the number of entries in the moving
167
163
  window
@@ -171,3 +167,75 @@ class MovingWindowSupport(ABC):
171
167
  :return: (start of window, number of acquired entries)
172
168
  """
173
169
  raise NotImplementedError
170
+
171
+
172
class SlidingWindowCounterSupport(ABC):
    """
    Abstract base class for storages that support
    the :ref:`strategies:sliding window counter` strategy.
    """

    def __init_subclass__(cls, **kwargs: Any) -> None:  # type: ignore[explicit-any]
        # Route storage errors raised by the sliding window methods of every
        # subclass through ``_wrap_errors``.
        sliding_window_methods = (
            "acquire_sliding_window_entry",
            "get_sliding_window",
        )
        for name in sliding_window_methods:
            setattr(cls, name, _wrap_errors(getattr(cls, name)))
        super().__init_subclass__(**kwargs)

    @abstractmethod
    def acquire_sliding_window_entry(
        self, key: str, limit: int, expiry: int, amount: int = 1
    ) -> bool:
        """
        Acquire ``amount`` entries if the weighted count of the previous and
        current windows, plus ``amount``, stays within ``limit``.

        :param key: rate limit key to acquire an entry in
        :param limit: amount of entries allowed
        :param expiry: expiry of the entry
        :param amount: the number of entries to acquire
        :return: whether the entries were acquired
        """
        raise NotImplementedError

    @abstractmethod
    def get_sliding_window(
        self, key: str, expiry: int
    ) -> tuple[int, float, int, float]:
        """
        Return the previous and current window information.

        :param key: the rate limit key
        :param expiry: the rate limit expiry, needed to compute the key in some implementations
        :return: a tuple of (int, float, int, float) with the following information:
          - previous window counter
          - previous window TTL
          - current window counter
          - current window TTL
        """
        raise NotImplementedError
218
+
219
+
220
class TimestampedSlidingWindow:
    """Helper class for storages that support the sliding window counter, with timestamp based keys."""

    @classmethod
    def sliding_window_keys(cls, key: str, expiry: int, at: float) -> tuple[str, str]:
        """
        returns the previous and the current window's keys.

        The current window is numbered ``int(at / expiry)`` and the previous
        window is always the one immediately before it.

        :param key: the key to get the window's keys from
        :param expiry: the expiry of the limit item, in seconds
        :param at: the timestamp to get the keys from, typically ``time.time()``

        Returns a tuple with the previous and the current key: (previous, current).

        Example:
         - key = "mykey"
         - expiry = 60
         - at = 1738576292.6631825

        The return value will be the tuple ``("mykey/28976270", "mykey/28976271")``.
        """
        current_window = int(at / expiry)
        # Derive the previous index from the current one instead of computing
        # ``int((at - expiry) / expiry)`` separately: that separate computation
        # truncates towards zero (giving the *same* index as the current window
        # when ``at < expiry``) and may round differently at window boundaries.
        previous_window = current_window - 1
        return f"{key}/{previous_window}", f"{key}/{current_window}"
limits/storage/etcd.py CHANGED
@@ -1,9 +1,9 @@
1
1
  import time
2
2
  import urllib.parse
3
- from typing import TYPE_CHECKING, Optional, Tuple, Type, Union
4
3
 
5
4
  from limits.errors import ConcurrentUpdateError
6
5
  from limits.storage.base import Storage
6
+ from limits.typing import TYPE_CHECKING, Optional, Union
7
7
 
8
8
  if TYPE_CHECKING:
9
9
  import etcd3
@@ -26,6 +26,7 @@ class EtcdStorage(Storage):
26
26
  self,
27
27
  uri: str,
28
28
  max_retries: int = MAX_RETRIES,
29
+ wrap_exceptions: bool = False,
29
30
  **options: str,
30
31
  ) -> None:
31
32
  """
@@ -33,6 +34,8 @@ class EtcdStorage(Storage):
33
34
  ``etcd://host:port``,
34
35
  :param max_retries: Maximum number of attempts to retry
35
36
  in the case of concurrent updates to a rate limit key
37
+ :param wrap_exceptions: Whether to wrap storage exceptions in
38
+ :exc:`limits.errors.StorageError` before raising it.
36
39
  :param options: all remaining keyword arguments are passed
37
40
  directly to the constructor of :class:`etcd3.Etcd3Client`
38
41
  :raise ConfigurationError: when :pypi:`etcd3` is not available
@@ -43,11 +46,12 @@ class EtcdStorage(Storage):
43
46
  parsed.hostname, parsed.port, **options
44
47
  )
45
48
  self.max_retries = max_retries
49
+ super().__init__(uri, wrap_exceptions=wrap_exceptions)
46
50
 
47
51
@property
def base_exceptions(
    self,
) -> Union[type[Exception], tuple[type[Exception], ...]]:  # pragma: no cover
    """
    The exception type treated as a storage error: ``Etcd3Exception`` from
    ``self.lib`` (presumably the loaded :pypi:`etcd3` module — confirm).
    """
    library = self.lib
    return library.Etcd3Exception  # type: ignore[no-any-return]
52
56
 
53
57
  def prefixed_key(self, key: str) -> bytes:
limits/storage/memcached.py CHANGED
@@ -2,26 +2,31 @@ import inspect
2
2
  import threading
3
3
  import time
4
4
  import urllib.parse
5
+ from collections.abc import Iterable
6
+ from math import ceil, floor
5
7
  from types import ModuleType
6
- from typing import cast
7
8
 
8
9
  from limits.errors import ConfigurationError
9
- from limits.storage.base import Storage
10
+ from limits.storage.base import (
11
+ SlidingWindowCounterSupport,
12
+ Storage,
13
+ TimestampedSlidingWindow,
14
+ )
10
15
  from limits.typing import (
16
+ Any,
11
17
  Callable,
12
- List,
13
18
  MemcachedClientP,
14
19
  Optional,
15
20
  P,
16
21
  R,
17
- Tuple,
18
22
  Type,
19
23
  Union,
24
+ cast,
20
25
  )
21
26
  from limits.util import get_dependency
22
27
 
23
28
 
24
- class MemcachedStorage(Storage):
29
+ class MemcachedStorage(Storage, SlidingWindowCounterSupport, TimestampedSlidingWindow):
25
30
  """
26
31
  Rate limit storage with memcached as backend.
27
32
 
@@ -70,15 +75,14 @@ class MemcachedStorage(Storage):
70
75
  options.pop("cluster_library", "pymemcache.client.hash")
71
76
  )
72
77
  self.client_getter = cast(
73
- Callable[[ModuleType, List[Tuple[str, int]]], MemcachedClientP],
78
+ Callable[[ModuleType, list[tuple[str, int]]], MemcachedClientP],
74
79
  options.pop("client_getter", self.get_client),
75
80
  )
76
81
  self.options = options
77
82
 
78
83
  if not get_dependency(self.library):
79
84
  raise ConfigurationError(
80
- "memcached prerequisite not available."
81
- " please install %s" % self.library
85
+ "memcached prerequisite not available. please install %s" % self.library
82
86
  ) # pragma: no cover
83
87
  self.local_storage = threading.local()
84
88
  self.local_storage.storage = None
@@ -87,11 +91,11 @@ class MemcachedStorage(Storage):
87
91
@property
def base_exceptions(
    self,
) -> Union[Type[Exception], tuple[Type[Exception], ...]]:  # pragma: no cover
    """
    The exception type treated as a storage error: ``MemcacheError`` from
    ``self.dependency`` (presumably the configured memcached client library
    — confirm).
    """
    client_library = self.dependency
    return client_library.MemcacheError  # type: ignore[no-any-return]
92
96
 
93
97
  def get_client(
94
- self, module: ModuleType, hosts: List[Tuple[str, int]], **kwargs: str
98
+ self, module: ModuleType, hosts: list[tuple[str, int]], **kwargs: str
95
99
  ) -> MemcachedClientP:
96
100
  """
97
101
  returns a memcached client.
@@ -143,8 +147,15 @@ class MemcachedStorage(Storage):
143
147
  """
144
148
  :param key: the key to get the counter value for
145
149
  """
150
+ return int(self.storage.get(key, "0"))
151
+
152
def get_many(self, keys: Iterable[str]) -> dict[str, Any]:  # type:ignore[explicit-any]
    """
    Fetch several counters with a single client ``get_many`` call.

    :param keys: the keys to get the counter values for
    :return: the mapping returned by the client, passed through unchanged
    """
    client = self.storage
    return client.get_many(keys)
148
159
 
149
160
  def clear(self, key: str) -> None:
150
161
  """
@@ -153,7 +164,12 @@ class MemcachedStorage(Storage):
153
164
  self.storage.delete(key)
154
165
 
155
166
  def incr(
156
- self, key: str, expiry: int, elastic_expiry: bool = False, amount: int = 1
167
+ self,
168
+ key: str,
169
+ expiry: float,
170
+ elastic_expiry: bool = False,
171
+ amount: int = 1,
172
+ set_expiration_key: bool = True,
157
173
  ) -> int:
158
174
  """
159
175
  increments the counter for a given rate limit key
@@ -163,41 +179,67 @@ class MemcachedStorage(Storage):
163
179
  :param elastic_expiry: whether to keep extending the rate limit
164
180
  window every hit.
165
181
  :param amount: the number to increment by
182
+ :param set_expiration_key: set the expiration key with the expiration time if needed. If set to False, the key will still expire, but memcached cannot provide the expiration time.
166
183
  """
167
-
168
- if not self.call_memcached_func(
169
- self.storage.add, key, amount, expiry, noreply=False
170
- ):
171
- value = self.storage.incr(key, amount) or amount
172
-
184
+ value = self.call_memcached_func(self.storage.incr, key, amount, noreply=False)
185
+ if value is not None:
173
186
  if elastic_expiry:
174
- self.call_memcached_func(self.storage.touch, key, expiry)
175
- self.call_memcached_func(
176
- self.storage.set,
177
- key + "/expires",
178
- expiry + time.time(),
179
- expire=expiry,
180
- noreply=False,
181
- )
187
+ self.call_memcached_func(self.storage.touch, key, ceil(expiry))
188
+ if set_expiration_key:
189
+ self.call_memcached_func(
190
+ self.storage.set,
191
+ self._expiration_key(key),
192
+ expiry + time.time(),
193
+ expire=ceil(expiry),
194
+ noreply=False,
195
+ )
182
196
 
183
197
  return value
184
198
  else:
185
- self.call_memcached_func(
186
- self.storage.set,
187
- key + "/expires",
188
- expiry + time.time(),
189
- expire=expiry,
190
- noreply=False,
191
- )
192
-
193
- return amount
199
+ if not self.call_memcached_func(
200
+ self.storage.add, key, amount, ceil(expiry), noreply=False
201
+ ):
202
+ value = self.storage.incr(key, amount) or amount
203
+
204
+ if elastic_expiry:
205
+ self.call_memcached_func(self.storage.touch, key, ceil(expiry))
206
+ if set_expiration_key:
207
+ self.call_memcached_func(
208
+ self.storage.set,
209
+ self._expiration_key(key),
210
+ expiry + time.time(),
211
+ expire=ceil(expiry),
212
+ noreply=False,
213
+ )
214
+
215
+ return value
216
+ else:
217
+ if set_expiration_key:
218
+ self.call_memcached_func(
219
+ self.storage.set,
220
+ self._expiration_key(key),
221
+ expiry + time.time(),
222
+ expire=ceil(expiry),
223
+ noreply=False,
224
+ )
225
+
226
+ return amount
194
227
 
195
228
def get_expiry(self, key: str) -> float:
    """
    :param key: the key to get the expiry for
    :return: the absolute expiration timestamp stored in the companion
     expiration key, or the current time when nothing (or a falsy value)
     is stored there
    """
    stored_expiry = self.storage.get(self._expiration_key(key))
    return float(stored_expiry) if stored_expiry else float(time.time())
234
+
235
def _expiration_key(self, key: str) -> str:
    """
    Build the name of the companion key holding ``key``'s expiration time.

    Memcached does not report a key's TTL or expiration time, so it is
    stored separately under ``<key>/expires``.
    """
    return f"{key}/expires"
201
243
 
202
244
  def check(self) -> bool:
203
245
  """
@@ -213,3 +255,67 @@ class MemcachedStorage(Storage):
213
255
 
214
256
def reset(self) -> Optional[int]:
    """
    Resetting all rate limits is not supported by this storage.

    :raise NotImplementedError: always
    """
    raise NotImplementedError
258
+
259
def acquire_sliding_window_entry(
    self,
    key: str,
    limit: int,
    expiry: int,
    amount: int = 1,
) -> bool:
    """
    Acquire ``amount`` entries in the sliding window counter for ``key``.

    The previous window's counter is weighted by ``previous_ttl / expiry``
    and added to the current window's counter. The hit is accepted only if
    that weighted count, rounded down, plus ``amount`` does not exceed
    ``limit``.

    :param key: rate limit key to acquire entries in
    :param limit: amount of entries allowed
    :param expiry: window length, in seconds
    :param amount: the number of entries to acquire
    :return: ``True`` if the entries were acquired, ``False`` otherwise
    """
    if amount > limit:
        return False
    now = time.time()
    previous_key, current_key = self.sliding_window_keys(key, expiry, now)
    previous_count, previous_ttl, current_count, _ = self._get_sliding_window_info(
        previous_key, current_key, expiry, now=now
    )
    weighted_count = previous_count * previous_ttl / expiry + current_count
    if floor(weighted_count) + amount > limit:
        return False

    # Hit: increase the current counter. If the counter doesn't exist yet it
    # is created with twice the theoretical expiry, so it can still be read
    # as the previous window's counter during the next window. No expiration
    # key is needed since TTLs are derived from the timestamps directly.
    current_count = self.incr(
        current_key, 2 * expiry, amount=amount, set_expiration_key=False
    )
    # Shrink the previous window's TTL by the time spent since ``now`` and
    # clamp it from below at 0. ``min(0, ...)`` would always produce a
    # non-positive TTL and drop the previous window from the race check.
    actualised_previous_ttl = max(0.0, previous_ttl - (time.time() - now))
    weighted_count = (
        previous_count * actualised_previous_ttl / expiry + current_count
    )
    if floor(weighted_count) > limit:
        # Another hit won the race condition: revert the increment and refuse
        # this hit. Limitation: during high concurrency at the end of the
        # window, the counter is shifted and cannot be decremented, so fewer
        # requests than expected are allowed.
        self.call_memcached_func(
            self.storage.decr,
            current_key,
            amount,
            noreply=True,
        )
        return False
    return True
299
+
300
def get_sliding_window(
    self, key: str, expiry: int
) -> tuple[int, float, int, float]:
    """
    Return (previous count, previous TTL, current count, current TTL) for
    ``key``, with both windows resolved from a single timestamp.
    """
    timestamp = time.time()
    window_keys = self.sliding_window_keys(key, expiry, timestamp)
    return self._get_sliding_window_info(*window_keys, expiry, timestamp)
306
+
307
def _get_sliding_window_info(
    self, previous_key: str, current_key: str, expiry: int, now: float
) -> tuple[int, float, int, float]:
    """
    Read both window counters in one ``get_many`` call and derive their TTLs
    from ``now``.

    :return: (previous count, previous TTL, current count, current TTL);
     the previous TTL is 0 when the previous counter is 0.
    """
    counters = self.get_many([previous_key, current_key])
    previous_count = int(counters.get(previous_key, 0))
    current_count = int(counters.get(current_key, 0))

    # An empty previous window carries no weight, so its TTL is reported as 0.
    previous_ttl = (
        (1 - (((now - expiry) / expiry) % 1)) * expiry
        if previous_count != 0
        else float(0)
    )
    # Remaining time in the current window, plus one full window.
    current_ttl = (1 - ((now / expiry) % 1)) * expiry + expiry
    return previous_count, previous_ttl, current_count, current_ttl
limits/storage/memory.py CHANGED
@@ -1,10 +1,16 @@
1
1
  import threading
2
2
  import time
3
- from collections import Counter
3
+ from collections import Counter, defaultdict
4
+ from math import floor
4
5
 
5
6
  import limits.typing
6
- from limits.storage.base import MovingWindowSupport, Storage
7
- from limits.typing import Dict, List, Optional, Tuple, Type, Union
7
+ from limits.storage.base import (
8
+ MovingWindowSupport,
9
+ SlidingWindowCounterSupport,
10
+ Storage,
11
+ TimestampedSlidingWindow,
12
+ )
13
+ from limits.typing import Optional, Type, Union
8
14
 
9
15
 
10
16
  class LockableEntry(threading._RLock): # type: ignore
@@ -14,7 +20,9 @@ class LockableEntry(threading._RLock): # type: ignore
14
20
  super().__init__()
15
21
 
16
22
 
17
- class MemoryStorage(Storage, MovingWindowSupport):
23
+ class MemoryStorage(
24
+ Storage, MovingWindowSupport, SlidingWindowCounterSupport, TimestampedSlidingWindow
25
+ ):
18
26
  """
19
27
  rate limit storage using :class:`collections.Counter`
20
28
  as an in memory storage for fixed and elastic window strategies,
@@ -28,8 +36,9 @@ class MemoryStorage(Storage, MovingWindowSupport):
28
36
  self, uri: Optional[str] = None, wrap_exceptions: bool = False, **_: str
29
37
  ):
30
38
  self.storage: limits.typing.Counter[str] = Counter()
31
- self.expirations: Dict[str, float] = {}
32
- self.events: Dict[str, List[LockableEntry]] = {}
39
+ self.locks: defaultdict[str, threading.RLock] = defaultdict(threading.RLock)
40
+ self.expirations: dict[str, float] = {}
41
+ self.events: dict[str, list[LockableEntry]] = {}
33
42
  self.timer = threading.Timer(0.01, self.__expire_events)
34
43
  self.timer.start()
35
44
  super().__init__(uri, wrap_exceptions=wrap_exceptions, **_)
@@ -37,7 +46,7 @@ class MemoryStorage(Storage, MovingWindowSupport):
37
46
@property
def base_exceptions(
    self,
) -> Union[Type[Exception], tuple[Type[Exception], ...]]:  # pragma: no cover
    """The exception type treated as a storage error by the in-memory storage."""
    storage_error = ValueError
    return storage_error
42
51
 
43
52
  def __expire_events(self) -> None:
@@ -51,6 +60,7 @@ class MemoryStorage(Storage, MovingWindowSupport):
51
60
  if self.expirations[key] <= time.time():
52
61
  self.storage.pop(key, None)
53
62
  self.expirations.pop(key, None)
63
+ self.locks.pop(key, None)
54
64
 
55
65
  def __schedule_expiry(self) -> None:
56
66
  if not self.timer.is_alive():
@@ -58,7 +68,7 @@ class MemoryStorage(Storage, MovingWindowSupport):
58
68
  self.timer.start()
59
69
 
60
70
  def incr(
61
- self, key: str, expiry: int, elastic_expiry: bool = False, amount: int = 1
71
+ self, key: str, expiry: float, elastic_expiry: bool = False, amount: int = 1
62
72
  ) -> int:
63
73
  """
64
74
  increments the counter for a given rate limit key
@@ -71,10 +81,25 @@ class MemoryStorage(Storage, MovingWindowSupport):
71
81
  """
72
82
  self.get(key)
73
83
  self.__schedule_expiry()
74
- self.storage[key] += amount
84
+ with self.locks[key]:
85
+ self.storage[key] += amount
75
86
 
76
- if elastic_expiry or self.storage[key] == amount:
77
- self.expirations[key] = time.time() + expiry
87
+ if elastic_expiry or self.storage[key] == amount:
88
+ self.expirations[key] = time.time() + expiry
89
+
90
+ return self.storage.get(key, 0)
91
+
92
def decr(self, key: str, amount: int = 1) -> int:
    """
    decrements the counter for a given rate limit key, never going below 0

    :param key: the key to decrement
    :param amount: the number to decrement by
    :return: the counter value after the decrement
    """
    # ``get`` drops the key first if its expiration time has passed.
    self.get(key)
    self.__schedule_expiry()
    key_lock = self.locks[key]
    with key_lock:
        remaining = self.storage[key] - amount
        self.storage[key] = remaining if remaining > 0 else 0

    return self.storage.get(key, 0)
80
105
 
@@ -86,6 +111,7 @@ class MemoryStorage(Storage, MovingWindowSupport):
86
111
  if self.expirations.get(key, 0) <= time.time():
87
112
  self.storage.pop(key, None)
88
113
  self.expirations.pop(key, None)
114
+ self.locks.pop(key, None)
89
115
 
90
116
  return self.storage.get(key, 0)
91
117
 
@@ -96,6 +122,7 @@ class MemoryStorage(Storage, MovingWindowSupport):
96
122
  self.storage.pop(key, None)
97
123
  self.expirations.pop(key, None)
98
124
  self.events.pop(key, None)
125
+ self.locks.pop(key, None)
99
126
 
100
127
  def acquire_entry(self, key: str, limit: int, expiry: int, amount: int = 1) -> bool:
101
128
  """
@@ -143,7 +170,7 @@ class MemoryStorage(Storage, MovingWindowSupport):
143
170
  else 0
144
171
  )
145
172
 
146
- def get_moving_window(self, key: str, limit: int, expiry: int) -> Tuple[float, int]:
173
+ def get_moving_window(self, key: str, limit: int, expiry: int) -> tuple[float, int]:
147
174
  """
148
175
  returns the starting point and the number of entries in the moving
149
176
  window
@@ -161,6 +188,63 @@ class MemoryStorage(Storage, MovingWindowSupport):
161
188
 
162
189
  return timestamp, acquired
163
190
 
191
def acquire_sliding_window_entry(
    self,
    key: str,
    limit: int,
    expiry: int,
    amount: int = 1,
) -> bool:
    """
    Acquire ``amount`` entries in the sliding window counter for ``key``.

    The previous window's counter is weighted by ``previous_ttl / expiry``
    and added to the current window's counter. The hit is accepted only if
    that weighted count, rounded down, plus ``amount`` does not exceed
    ``limit``.

    :param key: rate limit key to acquire entries in
    :param limit: amount of entries allowed
    :param expiry: window length, in seconds
    :param amount: the number of entries to acquire
    :return: ``True`` if the entries were acquired, ``False`` otherwise
    """
    if amount > limit:
        return False
    now = time.time()
    previous_key, current_key = self.sliding_window_keys(key, expiry, now)
    (
        previous_count,
        previous_ttl,
        current_count,
        _,
    ) = self._get_sliding_window_info(previous_key, current_key, expiry, now)
    weighted_count = previous_count * previous_ttl / expiry + current_count
    if floor(weighted_count) + amount > limit:
        return False

    # Hit: increase the current counter. If the counter doesn't exist yet it
    # is created with twice the theoretical expiry, so it can still be read
    # as the previous window's counter during the next window.
    current_count = self.incr(current_key, 2 * expiry, amount=amount)
    weighted_count = previous_count * previous_ttl / expiry + current_count
    if floor(weighted_count) > limit:
        # Another hit won the race condition: revert the increment and refuse
        # this hit. Limitation: during high concurrency at the end of the
        # window, the counter is shifted and cannot be decremented, so fewer
        # requests than expected are allowed.
        self.decr(current_key, amount)
        return False
    return True
+
225
def _get_sliding_window_info(
    self,
    previous_key: str,
    current_key: str,
    expiry: int,
    now: float,
) -> tuple[int, float, int, float]:
    """
    Read both window counters via ``get`` and derive their TTLs from ``now``.

    :return: (previous count, previous TTL, current count, current TTL);
     the previous TTL is 0 when the previous counter is 0.
    """
    previous_count = self.get(previous_key)
    current_count = self.get(current_key)

    # An empty previous window carries no weight, so its TTL is reported as 0.
    previous_ttl = (
        (1 - (((now - expiry) / expiry) % 1)) * expiry
        if previous_count != 0
        else float(0)
    )
    # Remaining time in the current window, plus one full window.
    current_ttl = (1 - ((now / expiry) % 1)) * expiry + expiry
    return previous_count, previous_ttl, current_count, current_ttl
240
+
241
def get_sliding_window(
    self, key: str, expiry: int
) -> tuple[int, float, int, float]:
    """
    Return (previous count, previous TTL, current count, current TTL) for
    ``key``, with both windows resolved from a single timestamp.
    """
    timestamp = time.time()
    window_keys = self.sliding_window_keys(key, expiry, timestamp)
    return self._get_sliding_window_info(*window_keys, expiry, timestamp)
247
+
164
248
  def check(self) -> bool:
165
249
  """
166
250
  check if storage is healthy
@@ -173,4 +257,5 @@ class MemoryStorage(Storage, MovingWindowSupport):
173
257
  self.storage.clear()
174
258
  self.expirations.clear()
175
259
  self.events.clear()
260
+ self.locks.clear()
176
261
  return num_items