limits 4.0.1__py3-none-any.whl → 4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -2,26 +2,31 @@ import inspect
2
2
  import threading
3
3
  import time
4
4
  import urllib.parse
5
+ from collections.abc import Iterable
6
+ from math import ceil, floor
5
7
  from types import ModuleType
6
- from typing import cast
7
8
 
8
9
  from limits.errors import ConfigurationError
9
- from limits.storage.base import Storage
10
+ from limits.storage.base import (
11
+ SlidingWindowCounterSupport,
12
+ Storage,
13
+ TimestampedSlidingWindow,
14
+ )
10
15
  from limits.typing import (
16
+ Any,
11
17
  Callable,
12
- List,
13
18
  MemcachedClientP,
14
19
  Optional,
15
20
  P,
16
21
  R,
17
- Tuple,
18
22
  Type,
19
23
  Union,
24
+ cast,
20
25
  )
21
26
  from limits.util import get_dependency
22
27
 
23
28
 
24
- class MemcachedStorage(Storage):
29
+ class MemcachedStorage(Storage, SlidingWindowCounterSupport, TimestampedSlidingWindow):
25
30
  """
26
31
  Rate limit storage with memcached as backend.
27
32
 
@@ -70,7 +75,7 @@ class MemcachedStorage(Storage):
70
75
  options.pop("cluster_library", "pymemcache.client.hash")
71
76
  )
72
77
  self.client_getter = cast(
73
- Callable[[ModuleType, List[Tuple[str, int]]], MemcachedClientP],
78
+ Callable[[ModuleType, list[tuple[str, int]]], MemcachedClientP],
74
79
  options.pop("client_getter", self.get_client),
75
80
  )
76
81
  self.options = options
@@ -86,11 +91,11 @@ class MemcachedStorage(Storage):
86
91
  @property
87
92
  def base_exceptions(
88
93
  self,
89
- ) -> Union[Type[Exception], Tuple[Type[Exception], ...]]: # pragma: no cover
94
+ ) -> Union[Type[Exception], tuple[Type[Exception], ...]]: # pragma: no cover
90
95
  return self.dependency.MemcacheError # type: ignore[no-any-return]
91
96
 
92
97
  def get_client(
93
- self, module: ModuleType, hosts: List[Tuple[str, int]], **kwargs: str
98
+ self, module: ModuleType, hosts: list[tuple[str, int]], **kwargs: str
94
99
  ) -> MemcachedClientP:
95
100
  """
96
101
  returns a memcached client.
@@ -142,8 +147,15 @@ class MemcachedStorage(Storage):
142
147
  """
143
148
  :param key: the key to get the counter value for
144
149
  """
150
+ return int(self.storage.get(key, "0"))
151
+
152
+ def get_many(self, keys: Iterable[str]) -> dict[str, Any]: # type:ignore[explicit-any]
153
+ """
154
+ Return multiple counters at once
145
155
 
146
- return int(self.storage.get(key) or 0)
156
+ :param keys: the keys to get the counter values for
157
+ """
158
+ return self.storage.get_many(keys)
147
159
 
148
160
  def clear(self, key: str) -> None:
149
161
  """
@@ -152,7 +164,12 @@ class MemcachedStorage(Storage):
152
164
  self.storage.delete(key)
153
165
 
154
166
  def incr(
155
- self, key: str, expiry: int, elastic_expiry: bool = False, amount: int = 1
167
+ self,
168
+ key: str,
169
+ expiry: float,
170
+ elastic_expiry: bool = False,
171
+ amount: int = 1,
172
+ set_expiration_key: bool = True,
156
173
  ) -> int:
157
174
  """
158
175
  increments the counter for a given rate limit key
@@ -162,41 +179,67 @@ class MemcachedStorage(Storage):
162
179
  :param elastic_expiry: whether to keep extending the rate limit
163
180
  window every hit.
164
181
  :param amount: the number to increment by
182
+ :param set_expiration_key: set the expiration key with the expiration time if needed. If set to False, the key will still expire, but memcached cannot provide the expiration time.
165
183
  """
166
-
167
- if not self.call_memcached_func(
168
- self.storage.add, key, amount, expiry, noreply=False
169
- ):
170
- value = self.storage.incr(key, amount) or amount
171
-
184
+ value = self.call_memcached_func(self.storage.incr, key, amount, noreply=False)
185
+ if value is not None:
172
186
  if elastic_expiry:
173
- self.call_memcached_func(self.storage.touch, key, expiry)
174
- self.call_memcached_func(
175
- self.storage.set,
176
- key + "/expires",
177
- expiry + time.time(),
178
- expire=expiry,
179
- noreply=False,
180
- )
187
+ self.call_memcached_func(self.storage.touch, key, ceil(expiry))
188
+ if set_expiration_key:
189
+ self.call_memcached_func(
190
+ self.storage.set,
191
+ self._expiration_key(key),
192
+ expiry + time.time(),
193
+ expire=ceil(expiry),
194
+ noreply=False,
195
+ )
181
196
 
182
197
  return value
183
198
  else:
184
- self.call_memcached_func(
185
- self.storage.set,
186
- key + "/expires",
187
- expiry + time.time(),
188
- expire=expiry,
189
- noreply=False,
190
- )
191
-
192
- return amount
199
+ if not self.call_memcached_func(
200
+ self.storage.add, key, amount, ceil(expiry), noreply=False
201
+ ):
202
+ value = self.storage.incr(key, amount) or amount
203
+
204
+ if elastic_expiry:
205
+ self.call_memcached_func(self.storage.touch, key, ceil(expiry))
206
+ if set_expiration_key:
207
+ self.call_memcached_func(
208
+ self.storage.set,
209
+ self._expiration_key(key),
210
+ expiry + time.time(),
211
+ expire=ceil(expiry),
212
+ noreply=False,
213
+ )
214
+
215
+ return value
216
+ else:
217
+ if set_expiration_key:
218
+ self.call_memcached_func(
219
+ self.storage.set,
220
+ self._expiration_key(key),
221
+ expiry + time.time(),
222
+ expire=ceil(expiry),
223
+ noreply=False,
224
+ )
225
+
226
+ return amount
193
227
 
194
228
  def get_expiry(self, key: str) -> float:
195
229
  """
196
230
  :param key: the key to get the expiry for
197
231
  """
198
232
 
199
- return float(self.storage.get(key + "/expires") or time.time())
233
+ return float(self.storage.get(self._expiration_key(key)) or time.time())
234
+
235
+ def _expiration_key(self, key: str) -> str:
236
+ """
237
+ Return the expiration key for the given counter key.
238
+
239
+ Memcached doesn't natively return the expiration time or TTL for a given key,
240
+ so we implement the expiration time on a separate key.
241
+ """
242
+ return key + "/expires"
200
243
 
201
244
  def check(self) -> bool:
202
245
  """
@@ -212,3 +255,67 @@ class MemcachedStorage(Storage):
212
255
 
213
256
  def reset(self) -> Optional[int]:
214
257
  raise NotImplementedError
258
+
259
+ def acquire_sliding_window_entry(
260
+ self,
261
+ key: str,
262
+ limit: int,
263
+ expiry: int,
264
+ amount: int = 1,
265
+ ) -> bool:
266
+ if amount > limit:
267
+ return False
268
+ now = time.time()
269
+ previous_key, current_key = self.sliding_window_keys(key, expiry, now)
270
+ previous_count, previous_ttl, current_count, _ = self._get_sliding_window_info(
271
+ previous_key, current_key, expiry, now=now
272
+ )
273
+ weighted_count = previous_count * previous_ttl / expiry + current_count
274
+ if floor(weighted_count) + amount > limit:
275
+ return False
276
+ else:
277
+ # Hit, increase the current counter.
278
+ # If the counter doesn't exist yet, set twice the theorical expiry.
279
+ # We don't need the expiration key as it is estimated with the timestamps directly.
280
+ current_count = self.incr(
281
+ current_key, 2 * expiry, amount=amount, set_expiration_key=False
282
+ )
283
+ actualised_previous_ttl = min(0, previous_ttl - (time.time() - now))
284
+ weighted_count = (
285
+ previous_count * actualised_previous_ttl / expiry + current_count
286
+ )
287
+ if floor(weighted_count) > limit:
288
+ # Another hit won the race condition: revert the incrementation and refuse this hit
289
+ # Limitation: during high concurrency at the end of the window,
290
+ # the counter is shifted and cannot be decremented, so less requests than expected are allowed.
291
+ self.call_memcached_func(
292
+ self.storage.decr,
293
+ current_key,
294
+ amount,
295
+ noreply=True,
296
+ )
297
+ return False
298
+ return True
299
+
300
+ def get_sliding_window(
301
+ self, key: str, expiry: int
302
+ ) -> tuple[int, float, int, float]:
303
+ now = time.time()
304
+ previous_key, current_key = self.sliding_window_keys(key, expiry, now)
305
+ return self._get_sliding_window_info(previous_key, current_key, expiry, now)
306
+
307
+ def _get_sliding_window_info(
308
+ self, previous_key: str, current_key: str, expiry: int, now: float
309
+ ) -> tuple[int, float, int, float]:
310
+ result = self.get_many([previous_key, current_key])
311
+ previous_count, current_count = (
312
+ int(result.get(previous_key, 0)),
313
+ int(result.get(current_key, 0)),
314
+ )
315
+
316
+ if previous_count == 0:
317
+ previous_ttl = float(0)
318
+ else:
319
+ previous_ttl = (1 - (((now - expiry) / expiry) % 1)) * expiry
320
+ current_ttl = (1 - ((now / expiry) % 1)) * expiry + expiry
321
+ return previous_count, previous_ttl, current_count, current_ttl
limits/storage/memory.py CHANGED
@@ -1,10 +1,16 @@
1
1
  import threading
2
2
  import time
3
- from collections import Counter
3
+ from collections import Counter, defaultdict
4
+ from math import floor
4
5
 
5
6
  import limits.typing
6
- from limits.storage.base import MovingWindowSupport, Storage
7
- from limits.typing import Dict, List, Optional, Tuple, Type, Union
7
+ from limits.storage.base import (
8
+ MovingWindowSupport,
9
+ SlidingWindowCounterSupport,
10
+ Storage,
11
+ TimestampedSlidingWindow,
12
+ )
13
+ from limits.typing import Optional, Type, Union
8
14
 
9
15
 
10
16
  class LockableEntry(threading._RLock): # type: ignore
@@ -14,7 +20,9 @@ class LockableEntry(threading._RLock): # type: ignore
14
20
  super().__init__()
15
21
 
16
22
 
17
- class MemoryStorage(Storage, MovingWindowSupport):
23
+ class MemoryStorage(
24
+ Storage, MovingWindowSupport, SlidingWindowCounterSupport, TimestampedSlidingWindow
25
+ ):
18
26
  """
19
27
  rate limit storage using :class:`collections.Counter`
20
28
  as an in memory storage for fixed and elastic window strategies,
@@ -28,8 +36,9 @@ class MemoryStorage(Storage, MovingWindowSupport):
28
36
  self, uri: Optional[str] = None, wrap_exceptions: bool = False, **_: str
29
37
  ):
30
38
  self.storage: limits.typing.Counter[str] = Counter()
31
- self.expirations: Dict[str, float] = {}
32
- self.events: Dict[str, List[LockableEntry]] = {}
39
+ self.locks: defaultdict[str, threading.RLock] = defaultdict(threading.RLock)
40
+ self.expirations: dict[str, float] = {}
41
+ self.events: dict[str, list[LockableEntry]] = {}
33
42
  self.timer = threading.Timer(0.01, self.__expire_events)
34
43
  self.timer.start()
35
44
  super().__init__(uri, wrap_exceptions=wrap_exceptions, **_)
@@ -37,7 +46,7 @@ class MemoryStorage(Storage, MovingWindowSupport):
37
46
  @property
38
47
  def base_exceptions(
39
48
  self,
40
- ) -> Union[Type[Exception], Tuple[Type[Exception], ...]]: # pragma: no cover
49
+ ) -> Union[Type[Exception], tuple[Type[Exception], ...]]: # pragma: no cover
41
50
  return ValueError
42
51
 
43
52
  def __expire_events(self) -> None:
@@ -51,6 +60,7 @@ class MemoryStorage(Storage, MovingWindowSupport):
51
60
  if self.expirations[key] <= time.time():
52
61
  self.storage.pop(key, None)
53
62
  self.expirations.pop(key, None)
63
+ self.locks.pop(key, None)
54
64
 
55
65
  def __schedule_expiry(self) -> None:
56
66
  if not self.timer.is_alive():
@@ -58,7 +68,7 @@ class MemoryStorage(Storage, MovingWindowSupport):
58
68
  self.timer.start()
59
69
 
60
70
  def incr(
61
- self, key: str, expiry: int, elastic_expiry: bool = False, amount: int = 1
71
+ self, key: str, expiry: float, elastic_expiry: bool = False, amount: int = 1
62
72
  ) -> int:
63
73
  """
64
74
  increments the counter for a given rate limit key
@@ -71,10 +81,25 @@ class MemoryStorage(Storage, MovingWindowSupport):
71
81
  """
72
82
  self.get(key)
73
83
  self.__schedule_expiry()
74
- self.storage[key] += amount
84
+ with self.locks[key]:
85
+ self.storage[key] += amount
75
86
 
76
- if elastic_expiry or self.storage[key] == amount:
77
- self.expirations[key] = time.time() + expiry
87
+ if elastic_expiry or self.storage[key] == amount:
88
+ self.expirations[key] = time.time() + expiry
89
+
90
+ return self.storage.get(key, 0)
91
+
92
+ def decr(self, key: str, amount: int = 1) -> int:
93
+ """
94
+ decrements the counter for a given rate limit key
95
+
96
+ :param key: the key to decrement
97
+ :param amount: the number to decrement by
98
+ """
99
+ self.get(key)
100
+ self.__schedule_expiry()
101
+ with self.locks[key]:
102
+ self.storage[key] = max(self.storage[key] - amount, 0)
78
103
 
79
104
  return self.storage.get(key, 0)
80
105
 
@@ -86,6 +111,7 @@ class MemoryStorage(Storage, MovingWindowSupport):
86
111
  if self.expirations.get(key, 0) <= time.time():
87
112
  self.storage.pop(key, None)
88
113
  self.expirations.pop(key, None)
114
+ self.locks.pop(key, None)
89
115
 
90
116
  return self.storage.get(key, 0)
91
117
 
@@ -96,6 +122,7 @@ class MemoryStorage(Storage, MovingWindowSupport):
96
122
  self.storage.pop(key, None)
97
123
  self.expirations.pop(key, None)
98
124
  self.events.pop(key, None)
125
+ self.locks.pop(key, None)
99
126
 
100
127
  def acquire_entry(self, key: str, limit: int, expiry: int, amount: int = 1) -> bool:
101
128
  """
@@ -143,7 +170,7 @@ class MemoryStorage(Storage, MovingWindowSupport):
143
170
  else 0
144
171
  )
145
172
 
146
- def get_moving_window(self, key: str, limit: int, expiry: int) -> Tuple[float, int]:
173
+ def get_moving_window(self, key: str, limit: int, expiry: int) -> tuple[float, int]:
147
174
  """
148
175
  returns the starting point and the number of entries in the moving
149
176
  window
@@ -161,6 +188,63 @@ class MemoryStorage(Storage, MovingWindowSupport):
161
188
 
162
189
  return timestamp, acquired
163
190
 
191
+ def acquire_sliding_window_entry(
192
+ self,
193
+ key: str,
194
+ limit: int,
195
+ expiry: int,
196
+ amount: int = 1,
197
+ ) -> bool:
198
+ if amount > limit:
199
+ return False
200
+ now = time.time()
201
+ previous_key, current_key = self.sliding_window_keys(key, expiry, now)
202
+ (
203
+ previous_count,
204
+ previous_ttl,
205
+ current_count,
206
+ _,
207
+ ) = self._get_sliding_window_info(previous_key, current_key, expiry, now)
208
+ weighted_count = previous_count * previous_ttl / expiry + current_count
209
+ if floor(weighted_count) + amount > limit:
210
+ return False
211
+ else:
212
+ # Hit, increase the current counter.
213
+ # If the counter doesn't exist yet, set twice the theorical expiry.
214
+ current_count = self.incr(current_key, 2 * expiry, amount=amount)
215
+ weighted_count = previous_count * previous_ttl / expiry + current_count
216
+ if floor(weighted_count) > limit:
217
+ # Another hit won the race condition: revert the incrementation and refuse this hit
218
+ # Limitation: during high concurrency at the end of the window,
219
+ # the counter is shifted and cannot be decremented, so less requests than expected are allowed.
220
+ self.decr(current_key, amount)
221
+ # print("Concurrent call, reverting the counter increase")
222
+ return False
223
+ return True
224
+
225
+ def _get_sliding_window_info(
226
+ self,
227
+ previous_key: str,
228
+ current_key: str,
229
+ expiry: int,
230
+ now: float,
231
+ ) -> tuple[int, float, int, float]:
232
+ previous_count = self.get(previous_key)
233
+ current_count = self.get(current_key)
234
+ if previous_count == 0:
235
+ previous_ttl = float(0)
236
+ else:
237
+ previous_ttl = (1 - (((now - expiry) / expiry) % 1)) * expiry
238
+ current_ttl = (1 - ((now / expiry) % 1)) * expiry + expiry
239
+ return previous_count, previous_ttl, current_count, current_ttl
240
+
241
+ def get_sliding_window(
242
+ self, key: str, expiry: int
243
+ ) -> tuple[int, float, int, float]:
244
+ now = time.time()
245
+ previous_key, current_key = self.sliding_window_keys(key, expiry, now)
246
+ return self._get_sliding_window_info(previous_key, current_key, expiry, now)
247
+
164
248
  def check(self) -> bool:
165
249
  """
166
250
  check if storage is healthy
@@ -173,4 +257,5 @@ class MemoryStorage(Storage, MovingWindowSupport):
173
257
  self.storage.clear()
174
258
  self.expirations.clear()
175
259
  self.events.clear()
260
+ self.locks.clear()
176
261
  return num_items