limits 4.7.3__py3-none-any.whl → 5.0.0rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- limits/_version.py +3 -3
- limits/aio/storage/__init__.py +0 -2
- limits/aio/storage/base.py +1 -5
- limits/aio/storage/memcached/__init__.py +184 -0
- limits/aio/storage/memcached/bridge.py +73 -0
- limits/aio/storage/memcached/emcache.py +112 -0
- limits/aio/storage/memcached/memcachio.py +104 -0
- limits/aio/storage/memory.py +41 -48
- limits/aio/storage/mongodb.py +26 -31
- limits/aio/storage/redis/__init__.py +2 -4
- limits/aio/storage/redis/bridge.py +0 -1
- limits/aio/storage/redis/coredis.py +2 -6
- limits/aio/storage/redis/redispy.py +1 -8
- limits/aio/strategies.py +1 -28
- limits/resources/redis/lua_scripts/acquire_moving_window.lua +5 -2
- limits/resources/redis/lua_scripts/moving_window.lua +23 -14
- limits/storage/__init__.py +0 -2
- limits/storage/base.py +1 -5
- limits/storage/memcached.py +8 -29
- limits/storage/memory.py +16 -35
- limits/storage/mongodb.py +25 -34
- limits/storage/redis.py +1 -7
- limits/strategies.py +1 -31
- limits/typing.py +1 -50
- {limits-4.7.3.dist-info → limits-5.0.0rc2.dist-info}/METADATA +8 -14
- limits-5.0.0rc2.dist-info/RECORD +44 -0
- limits/aio/storage/etcd.py +0 -146
- limits/aio/storage/memcached.py +0 -281
- limits/storage/etcd.py +0 -139
- limits-4.7.3.dist-info/RECORD +0 -43
- {limits-4.7.3.dist-info → limits-5.0.0rc2.dist-info}/WHEEL +0 -0
- {limits-4.7.3.dist-info → limits-5.0.0rc2.dist-info}/licenses/LICENSE.txt +0 -0
- {limits-4.7.3.dist-info → limits-5.0.0rc2.dist-info}/top_level.txt +0 -0
limits/_version.py
CHANGED
|
@@ -8,11 +8,11 @@ import json
|
|
|
8
8
|
|
|
9
9
|
version_json = '''
|
|
10
10
|
{
|
|
11
|
-
"date": "2025-04-
|
|
11
|
+
"date": "2025-04-15T12:47:18-0700",
|
|
12
12
|
"dirty": false,
|
|
13
13
|
"error": null,
|
|
14
|
-
"full-revisionid": "
|
|
15
|
-
"version": "
|
|
14
|
+
"full-revisionid": "ca0e9ca30c696af1102471218171c07ce8ee7644",
|
|
15
|
+
"version": "5.0.0rc2"
|
|
16
16
|
}
|
|
17
17
|
''' # END VERSION_JSON
|
|
18
18
|
|
limits/aio/storage/__init__.py
CHANGED
|
@@ -6,14 +6,12 @@ Implementations of storage backends to be used with
|
|
|
6
6
|
from __future__ import annotations
|
|
7
7
|
|
|
8
8
|
from .base import MovingWindowSupport, SlidingWindowCounterSupport, Storage
|
|
9
|
-
from .etcd import EtcdStorage
|
|
10
9
|
from .memcached import MemcachedStorage
|
|
11
10
|
from .memory import MemoryStorage
|
|
12
11
|
from .mongodb import MongoDBStorage
|
|
13
12
|
from .redis import RedisClusterStorage, RedisSentinelStorage, RedisStorage
|
|
14
13
|
|
|
15
14
|
__all__ = [
|
|
16
|
-
"EtcdStorage",
|
|
17
15
|
"MemcachedStorage",
|
|
18
16
|
"MemoryStorage",
|
|
19
17
|
"MongoDBStorage",
|
limits/aio/storage/base.py
CHANGED
|
@@ -75,16 +75,12 @@ class Storage(LazyDependency, metaclass=StorageRegistry):
|
|
|
75
75
|
raise NotImplementedError
|
|
76
76
|
|
|
77
77
|
@abstractmethod
async def incr(self, key: str, expiry: int, amount: int = 1) -> int:
    """
    Increase the counter stored under ``key``.

    Concrete storages must override this coroutine.

    :param key: the rate limit key whose counter should be increased
    :param expiry: number of seconds after which the key should expire
    :param amount: how much to add to the counter
    :return: the updated counter value
    """
    raise NotImplementedError
|
|
@@ -0,0 +1,184 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import time
|
|
4
|
+
from math import floor
|
|
5
|
+
|
|
6
|
+
from deprecated.sphinx import versionadded, versionchanged
|
|
7
|
+
from packaging.version import Version
|
|
8
|
+
|
|
9
|
+
from limits.aio.storage import SlidingWindowCounterSupport, Storage
|
|
10
|
+
from limits.aio.storage.memcached.bridge import MemcachedBridge
|
|
11
|
+
from limits.aio.storage.memcached.emcache import EmcacheBridge
|
|
12
|
+
from limits.aio.storage.memcached.memcachio import MemcachioBridge
|
|
13
|
+
from limits.storage.base import TimestampedSlidingWindow
|
|
14
|
+
from limits.typing import Literal
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@versionadded(version="2.1")
@versionchanged(
    version="5.0",
    reason="Switched default implementation to :pypi:`memcachio`",
)
class MemcachedStorage(Storage, SlidingWindowCounterSupport, TimestampedSlidingWindow):
    """
    Rate limit storage with memcached as backend.

    Depends on :pypi:`memcachio` by default, or on :pypi:`emcache` when
    ``implementation="emcache"`` is passed to the constructor.
    """

    STORAGE_SCHEME = ["async+memcached"]
    """The storage scheme for memcached to be used in an async context"""

    DEPENDENCIES = {
        "memcachio": Version("0.3"),
        "emcache": Version("0.0"),
    }

    # client specific adapter that issues the actual memcached commands
    bridge: MemcachedBridge
    storage_exceptions: tuple[type[Exception], ...]

    def __init__(
        self,
        uri: str,
        wrap_exceptions: bool = False,
        implementation: Literal["memcachio", "emcache"] = "memcachio",
        **options: float | str | bool,
    ) -> None:
        """
        :param uri: memcached location of the form
         ``async+memcached://host:port,host:port``
        :param wrap_exceptions: Whether to wrap storage exceptions in
         :exc:`limits.errors.StorageError` before raising it.
        :param implementation: Whether to use the client implementation from

         - ``memcachio``: :class:`memcachio.Client`
         - ``emcache``: :class:`emcache.Client`

         Any value other than ``"emcache"`` selects :pypi:`memcachio`.
        :param options: all remaining keyword arguments are passed
         directly to the constructor of the selected client
         (:class:`memcachio.Client` or :func:`emcache.create_client`)
        :raise ConfigurationError: when the selected client library
         (:pypi:`memcachio` or :pypi:`emcache`) is not available
        """
        if implementation == "emcache":
            self.bridge = EmcacheBridge(
                uri, self.dependencies["emcache"].module, **options
            )
        else:
            self.bridge = MemcachioBridge(
                uri, self.dependencies["memcachio"].module, **options
            )
        super().__init__(uri, wrap_exceptions=wrap_exceptions, **options)

    @property
    def base_exceptions(
        self,
    ) -> type[Exception] | tuple[type[Exception], ...]:  # pragma: no cover
        # Delegated so that each client library reports its own error types.
        return self.bridge.base_exceptions

    async def get(self, key: str) -> int:
        """
        :param key: the key to get the counter value for
        """
        return await self.bridge.get(key)

    async def clear(self, key: str) -> None:
        """
        :param key: the key to clear rate limits for
        """
        await self.bridge.clear(key)

    async def incr(
        self,
        key: str,
        expiry: float,
        amount: int = 1,
        set_expiration_key: bool = True,
    ) -> int:
        """
        increments the counter for a given rate limit key

        :param key: the key to increment
        :param expiry: amount in seconds for the key to expire in
        :param amount: the number to increment by
        :param set_expiration_key: if set to False, the expiration time won't
         be stored but the key will still expire
        :return: the counter value after the increment
        """
        return await self.bridge.incr(
            key, expiry, amount, set_expiration_key=set_expiration_key
        )

    async def get_expiry(self, key: str) -> float:
        """
        :param key: the key to get the expiry for
        """
        return await self.bridge.get_expiry(key)

    async def reset(self) -> int | None:
        """
        Not supported by the memcached storage.

        :raise NotImplementedError: always
        """
        raise NotImplementedError

    async def check(self) -> bool:
        """
        Check if the memcached backend is reachable (delegated to the bridge).
        """
        return await self.bridge.check()

    async def acquire_sliding_window_entry(
        self,
        key: str,
        limit: int,
        expiry: int,
        amount: int = 1,
    ) -> bool:
        """
        Try to acquire ``amount`` entries in the sliding window counter of ``key``.

        :param key: rate limit key
        :param limit: maximum number of entries allowed in the window
        :param expiry: length of the window in seconds
        :param amount: number of entries to acquire
        :return: ``True`` if the entries were acquired, ``False`` otherwise
        """
        if amount > limit:
            return False
        now = time.time()
        previous_key, current_key = self.sliding_window_keys(key, expiry, now)
        (
            previous_count,
            previous_ttl,
            current_count,
            _,
        ) = await self._get_sliding_window_info(previous_key, current_key, expiry, now)
        t0 = time.time()
        weighted_count = previous_count * previous_ttl / expiry + current_count
        if floor(weighted_count) + amount > limit:
            return False
        else:
            # Hit, increase the current counter.
            # If the counter doesn't exist yet, set twice the theoretical expiry.
            # We don't need the expiration key as it is estimated with the timestamps directly.
            current_count = await self.incr(
                current_key, 2 * expiry, amount=amount, set_expiration_key=False
            )
            t1 = time.time()
            # Re-weight the previous window with the time spent in the round trip.
            actualised_previous_ttl = max(0, previous_ttl - (t1 - t0))
            weighted_count = (
                previous_count * actualised_previous_ttl / expiry + current_count
            )
            if floor(weighted_count) > limit:
                # Another hit won the race condition: revert the increment and refuse this hit
                # Limitation: during high concurrency at the end of the window,
                # the counter is shifted and cannot be decremented, so less requests than expected are allowed.
                await self.bridge.decr(current_key, amount, noreply=True)
                return False
            return True

    async def get_sliding_window(
        self, key: str, expiry: int
    ) -> tuple[int, float, int, float]:
        """
        :param key: rate limit key
        :param expiry: length of the window in seconds
        :return: ``(previous_count, previous_ttl, current_count, current_ttl)``
        """
        now = time.time()
        previous_key, current_key = self.sliding_window_keys(key, expiry, now)
        return await self._get_sliding_window_info(
            previous_key, current_key, expiry, now
        )

    async def _get_sliding_window_info(
        self, previous_key: str, current_key: str, expiry: int, now: float
    ) -> tuple[int, float, int, float]:
        # Fetch both window counters in a single round trip.
        result = await self.bridge.get_many([previous_key, current_key])

        previous_count = result.get(previous_key.encode("utf-8"), 0)
        current_count = result.get(current_key.encode("utf-8"), 0)

        # TTLs are derived from ``now`` and the window length instead of being
        # read from memcached.
        if previous_count == 0:
            previous_ttl = float(0)
        else:
            previous_ttl = (1 - (((now - expiry) / expiry) % 1)) * expiry
        current_ttl = (1 - ((now / expiry) % 1)) * expiry + expiry

        return previous_count, previous_ttl, current_count, current_ttl
|
|
@@ -0,0 +1,73 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import urllib
|
|
4
|
+
from abc import ABC, abstractmethod
|
|
5
|
+
from types import ModuleType
|
|
6
|
+
|
|
7
|
+
from limits.typing import Iterable
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class MemcachedBridge(ABC):
    """
    Abstract adapter between the async memcached storage and a concrete
    client library.

    Parses the storage URI into ``(host, port)`` pairs and client options;
    subclasses implement the memcached commands.
    """

    def __init__(
        self,
        uri: str,
        dependency: ModuleType,
        **options: float | str | bool,
    ) -> None:
        """
        :param uri: memcached location of the form
         ``scheme://[user:password@]host:port[,host:port...]``
        :param dependency: the imported client library module
        :param options: extra keyword arguments for the client. Credentials
         found in ``uri`` are added as ``username`` / ``password``.
        :raise ValueError: if a location in ``uri`` is not ``host:port``
        """
        # ``import urllib`` alone does not load the ``urllib.parse`` submodule;
        # import it explicitly instead of relying on another module having done so.
        import urllib.parse

        self.uri = uri
        self.parsed_uri = urllib.parse.urlparse(self.uri)
        self.dependency = dependency
        self.hosts: list[tuple[str, int]] = []
        self.options = options

        # Drop any ``user:password@`` prefix (same split rule as ``urlparse``);
        # the remainder is a comma separated list of ``host:port`` locations.
        locations = self.parsed_uri.netloc.strip().rpartition("@")[2]
        for loc in locations.split(","):
            # Split on the last colon so bracketed IPv6 hosts (``[::1]:11211``)
            # work; a location without a port still raises ValueError.
            host, port = loc.rsplit(":", 1)
            self.hosts.append((host.strip("[]"), int(port)))

        if self.parsed_uri.username:
            self.options["username"] = self.parsed_uri.username
        if self.parsed_uri.password:
            self.options["password"] = self.parsed_uri.password

    def _expiration_key(self, key: str) -> str:
        """
        Return the expiration key for the given counter key.

        Memcached doesn't natively return the expiration time or TTL for a given key,
        so we implement the expiration time on a separate key.
        """
        return key + "/expires"

    @property
    @abstractmethod
    def base_exceptions(
        self,
    ) -> type[Exception] | tuple[type[Exception], ...]: ...

    @abstractmethod
    async def get(self, key: str) -> int: ...

    @abstractmethod
    async def get_many(self, keys: Iterable[str]) -> dict[bytes, int]: ...

    @abstractmethod
    async def clear(self, key: str) -> None: ...

    @abstractmethod
    async def decr(self, key: str, amount: int = 1, noreply: bool = False) -> int: ...

    @abstractmethod
    async def incr(
        self,
        key: str,
        expiry: float,
        amount: int = 1,
        set_expiration_key: bool = True,
    ) -> int: ...

    @abstractmethod
    async def get_expiry(self, key: str) -> float: ...

    @abstractmethod
    async def check(self) -> bool: ...
|
|
@@ -0,0 +1,112 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import time
|
|
4
|
+
from math import ceil
|
|
5
|
+
from types import ModuleType
|
|
6
|
+
|
|
7
|
+
from limits.typing import TYPE_CHECKING, Iterable
|
|
8
|
+
|
|
9
|
+
from .bridge import MemcachedBridge
|
|
10
|
+
|
|
11
|
+
if TYPE_CHECKING:
|
|
12
|
+
import emcache
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class EmcacheBridge(MemcachedBridge):
    """
    Memcached bridge implemented with the :pypi:`emcache` client.
    """

    def __init__(
        self,
        uri: str,
        dependency: ModuleType,
        **options: float | str | bool,
    ) -> None:
        super().__init__(uri, dependency, **options)
        # Created lazily by ``get_storage`` since ``create_client`` must be awaited.
        self._storage: emcache.Client | None = None

    async def get_storage(self) -> emcache.Client:
        """Return the emcache client, creating it on first use."""
        if not self._storage:
            self._storage = await self.dependency.create_client(
                [self.dependency.MemcachedHostAddress(h, p) for h, p in self.hosts],
                **self.options,
            )
        assert self._storage
        return self._storage

    async def get(self, key: str) -> int:
        """Return the counter stored at ``key``, or ``0`` if it is missing."""
        item = await (await self.get_storage()).get(key.encode("utf-8"))
        return item and int(item.value) or 0

    async def get_many(self, keys: Iterable[str]) -> dict[bytes, int]:
        """
        Return the counters for ``keys`` keyed by their utf-8 encoded name.

        :param keys: the keys to get the counter values for
        """
        results = await (await self.get_storage()).get_many(
            [k.encode("utf-8") for k in keys]
        )
        return {k: int(item.value) if item else 0 for k, item in results.items()}

    async def clear(self, key: str) -> None:
        """Delete ``key``; a key that does not exist is ignored."""
        try:
            await (await self.get_storage()).delete(key.encode("utf-8"))
        except self.dependency.NotFoundCommandError:
            pass

    async def decr(self, key: str, amount: int = 1, noreply: bool = False) -> int:
        """
        Decrement ``key`` by ``amount``.

        :return: the new value, or ``0`` if the key does not exist or no
         value was returned
        """
        storage = await self.get_storage()
        limit_key = key.encode("utf-8")
        try:
            value = await storage.decrement(limit_key, amount, noreply=noreply) or 0
        except self.dependency.NotFoundCommandError:
            value = 0
        return value

    async def incr(
        self, key: str, expiry: float, amount: int = 1, set_expiration_key: bool = True
    ) -> int:
        """
        Increment ``key`` by ``amount``, creating it with ``expiry`` if missing.

        :param set_expiration_key: also store the absolute expiry time under
         the companion expiration key when the counter is created
        :return: the counter value after the increment
        """
        storage = await self.get_storage()
        limit_key = key.encode("utf-8")
        expire_key = self._expiration_key(key).encode()
        try:
            return await storage.increment(limit_key, amount) or amount
        except self.dependency.NotFoundCommandError:
            try:
                await storage.add(limit_key, f"{amount}".encode(), exptime=ceil(expiry))
                if set_expiration_key:
                    await storage.set(
                        expire_key,
                        str(expiry + time.time()).encode("utf-8"),
                        exptime=ceil(expiry),
                        noreply=False,
                    )
                value = amount
            except self.dependency.NotStoredStorageCommandError:
                # Could not add the key, probably because a concurrent call has added it
                value = await storage.increment(limit_key, amount) or amount
            return value

    async def get_expiry(self, key: str) -> float:
        """
        Return the stored absolute expiry time of ``key``, or the current
        time if no expiration key exists.
        """
        storage = await self.get_storage()
        item = await storage.get(self._expiration_key(key).encode("utf-8"))

        return item and float(item.value) or time.time()

    @property
    def base_exceptions(
        self,
    ) -> type[Exception] | tuple[type[Exception], ...]:  # pragma: no cover
        return (
            self.dependency.ClusterNoAvailableNodes,
            self.dependency.CommandError,
        )

    async def check(self) -> bool:
        """
        Check if storage is healthy by calling the ``get`` command
        on the key ``limiter-check``
        """
        try:
            storage = await self.get_storage()
            await storage.get(b"limiter-check")

            return True
        except Exception:
            # Catch only ordinary errors so that task cancellation and
            # interpreter shutdown signals still propagate.
            return False
|
|
@@ -0,0 +1,104 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import time
|
|
4
|
+
from math import ceil
|
|
5
|
+
from types import ModuleType
|
|
6
|
+
from typing import TYPE_CHECKING, Iterable
|
|
7
|
+
|
|
8
|
+
from .bridge import MemcachedBridge
|
|
9
|
+
|
|
10
|
+
if TYPE_CHECKING:
|
|
11
|
+
import memcachio
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class MemcachioBridge(MemcachedBridge):
    """
    Memcached bridge implemented with the :pypi:`memcachio` client.
    """

    def __init__(
        self,
        uri: str,
        dependency: ModuleType,
        **options: float | str | bool,
    ) -> None:
        super().__init__(uri, dependency, **options)
        # Created lazily by ``get_storage`` on first use.
        self._storage: memcachio.Client[bytes] | None = None

    @property
    def base_exceptions(
        self,
    ) -> type[Exception] | tuple[type[Exception], ...]:
        return (
            self.dependency.errors.NoAvailableNodes,
            self.dependency.errors.MemcachioConnectionError,
        )

    async def get_storage(self) -> memcachio.Client[bytes]:
        """Return the memcachio client, creating it on first use."""
        if not self._storage:
            self._storage = self.dependency.Client(
                list(self.hosts),
                **self.options,
            )
        assert self._storage
        return self._storage

    async def get(self, key: str) -> int:
        """Return the counter stored at ``key``, or ``0`` if it is missing."""
        return (await self.get_many([key])).get(key.encode("utf-8"), 0)

    async def get_many(self, keys: Iterable[str]) -> dict[bytes, int]:
        """
        Return multiple counters at once

        :param keys: the keys to get the counter values for
        """
        results = await (await self.get_storage()).get(
            *[k.encode("utf-8") for k in keys]
        )
        return {k: int(v.value) for k, v in results.items()}

    async def clear(self, key: str) -> None:
        """Delete ``key``."""
        await (await self.get_storage()).delete(key.encode("utf-8"))

    async def decr(self, key: str, amount: int = 1, noreply: bool = False) -> int:
        """
        Decrement ``key`` by ``amount``.

        :return: the new value, or ``0`` when the client returns no value
        """
        storage = await self.get_storage()
        limit_key = key.encode("utf-8")
        return await storage.decr(limit_key, amount, noreply=noreply) or 0

    async def incr(
        self, key: str, expiry: float, amount: int = 1, set_expiration_key: bool = True
    ) -> int:
        """
        Increment ``key`` by ``amount``, creating it with ``expiry`` if missing.

        :param set_expiration_key: also store the absolute expiry time under
         the companion expiration key when the counter is created
        :return: the counter value after the increment
        """
        storage = await self.get_storage()
        limit_key = key.encode("utf-8")
        expire_key = self._expiration_key(key).encode()
        if (value := (await storage.incr(limit_key, amount))) is None:
            # ``incr`` returned nothing: the counter does not exist yet.
            if await storage.add(limit_key, f"{amount}".encode(), expiry=ceil(expiry)):
                if set_expiration_key:
                    await storage.set(
                        expire_key,
                        str(expiry + time.time()).encode("utf-8"),
                        expiry=ceil(expiry),
                        noreply=False,
                    )
                return amount
            else:
                # A concurrent call created the counter first: increment it instead.
                return await storage.incr(limit_key, amount) or amount
        return value

    async def get_expiry(self, key: str) -> float:
        """
        Return the stored absolute expiry time of ``key``, or the current
        time if no expiration key exists.
        """
        storage = await self.get_storage()
        expiration_key = self._expiration_key(key).encode("utf-8")
        item = (await storage.get(expiration_key)).get(expiration_key, None)

        return item and float(item.value) or time.time()

    async def check(self) -> bool:
        """
        Check if storage is healthy by calling the ``get`` command
        on the key ``limiter-check``
        """
        try:
            storage = await self.get_storage()
            await storage.get(b"limiter-check")

            return True
        except Exception:
            # Catch only ordinary errors so that task cancellation and
            # interpreter shutdown signals still propagate.
            return False
|
limits/aio/storage/memory.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
import asyncio
|
|
4
|
+
import bisect
|
|
4
5
|
import time
|
|
5
6
|
from collections import Counter, defaultdict
|
|
6
7
|
from math import floor
|
|
@@ -28,7 +29,7 @@ class MemoryStorage(
|
|
|
28
29
|
):
|
|
29
30
|
"""
|
|
30
31
|
rate limit storage using :class:`collections.Counter`
|
|
31
|
-
as an in memory storage for fixed
|
|
32
|
+
as an in memory storage for fixed & sliding window strategies,
|
|
32
33
|
and a simple list to implement moving window strategy.
|
|
33
34
|
"""
|
|
34
35
|
|
|
@@ -61,20 +62,29 @@ class MemoryStorage(
|
|
|
61
62
|
asyncio.ensure_future(self.__schedule_expiry())
|
|
62
63
|
|
|
63
64
|
async def __expire_events(self) -> None:
    """
    Drop expired moving window events and expired counters.

    Exits quietly if the task running it is cancelled.
    """
    try:
        now = time.time()
        for key in list(self.events.keys()):
            async with self.locks[key]:
                if events := self.events.get(key, []):
                    # Events are prepended on acquire, so the list is ordered by
                    # descending expiry and the expired entries form its tail.
                    # The cutoff is computed under the lock: computing it
                    # beforehand lets concurrently prepended entries shift the
                    # boundary, which would drop still-valid events.
                    cutoff = bisect.bisect_left(
                        events, -now, key=lambda event: -event.expiry
                    )
                    self.events[key] = events[:cutoff]
                if not self.events.get(key, None):
                    self.events.pop(key, None)
                    self.locks.pop(key, None)

        for key in list(self.expirations.keys()):
            if self.expirations[key] <= time.time():
                self.storage.pop(key, None)
                self.expirations.pop(key, None)
                self.locks.pop(key, None)
    except asyncio.CancelledError:
        return
|
|
78
88
|
|
|
79
89
|
async def __schedule_expiry(self) -> None:
|
|
80
90
|
if not self.timer or self.timer.done():
|
|
@@ -86,26 +96,20 @@ class MemoryStorage(
|
|
|
86
96
|
) -> type[Exception] | tuple[type[Exception], ...]: # pragma: no cover
|
|
87
97
|
return ValueError
|
|
88
98
|
|
|
89
|
-
async def incr(self, key: str, expiry: float, amount: int = 1) -> int:
    """
    Add ``amount`` to the in-memory counter stored for ``key``.

    If the counter equals ``amount`` after this call (it was absent or zero
    before), its expiration is set to ``expiry`` seconds from now.

    :param key: the key to increment
    :param expiry: amount in seconds for the key to expire in
    :param amount: the number to increment by
    :return: the counter value after the increment
    """
    # NOTE(review): presumably reading the key first lets ``get`` drop an
    # already expired counter — confirm against ``get``.
    await self.get(key)
    await self.__schedule_expiry()
    async with self.locks[key]:
        updated = self.storage[key] + amount
        self.storage[key] = updated
        first_hit = updated == amount
        if first_hit:
            self.expirations[key] = time.time() + expiry
    return self.storage.get(key, amount)
|
|
110
114
|
|
|
111
115
|
async def decr(self, key: str, amount: int = 1) -> int:
|
|
@@ -165,8 +169,7 @@ class MemoryStorage(
|
|
|
165
169
|
if entry and entry.atime >= timestamp - expiry:
|
|
166
170
|
return False
|
|
167
171
|
else:
|
|
168
|
-
self.events[key][:0] = [Entry(expiry)
|
|
169
|
-
|
|
172
|
+
self.events[key][:0] = [Entry(expiry)] * amount
|
|
170
173
|
return True
|
|
171
174
|
|
|
172
175
|
async def get_expiry(self, key: str) -> float:
|
|
@@ -176,22 +179,6 @@ class MemoryStorage(
|
|
|
176
179
|
|
|
177
180
|
return self.expirations.get(key, time.time())
|
|
178
181
|
|
|
179
|
-
async def get_num_acquired(self, key: str, expiry: int) -> int:
|
|
180
|
-
"""
|
|
181
|
-
returns the number of entries already acquired
|
|
182
|
-
|
|
183
|
-
:param key: rate limit key to acquire an entry in
|
|
184
|
-
:param expiry: expiry of the entry
|
|
185
|
-
"""
|
|
186
|
-
timestamp = time.time()
|
|
187
|
-
|
|
188
|
-
return (
|
|
189
|
-
len([k for k in self.events.get(key, []) if k.atime >= timestamp - expiry])
|
|
190
|
-
if self.events.get(key)
|
|
191
|
-
else 0
|
|
192
|
-
)
|
|
193
|
-
|
|
194
|
-
# FIXME: arg limit is not used
|
|
195
182
|
async def get_moving_window(
|
|
196
183
|
self, key: str, limit: int, expiry: int
|
|
197
184
|
) -> tuple[float, int]:
|
|
@@ -203,14 +190,14 @@ class MemoryStorage(
|
|
|
203
190
|
:param expiry: expiry of entry
|
|
204
191
|
:return: (start of window, number of acquired entries)
|
|
205
192
|
"""
|
|
206
|
-
timestamp = time.time()
|
|
207
|
-
acquired = await self.get_num_acquired(key, expiry)
|
|
208
|
-
|
|
209
|
-
for item in self.events.get(key, [])[::-1]:
|
|
210
|
-
if item.atime >= timestamp - expiry:
|
|
211
|
-
return item.atime, acquired
|
|
212
193
|
|
|
213
|
-
|
|
194
|
+
timestamp = time.time()
|
|
195
|
+
if events := self.events.get(key, []):
|
|
196
|
+
oldest = bisect.bisect_left(
|
|
197
|
+
events, -(timestamp - expiry), key=lambda entry: -entry.atime
|
|
198
|
+
)
|
|
199
|
+
return events[oldest - 1].atime, oldest
|
|
200
|
+
return timestamp, 0
|
|
214
201
|
|
|
215
202
|
async def acquire_sliding_window_entry(
|
|
216
203
|
self,
|
|
@@ -242,7 +229,6 @@ class MemoryStorage(
|
|
|
242
229
|
# Limitation: during high concurrency at the end of the window,
|
|
243
230
|
# the counter is shifted and cannot be decremented, so less requests than expected are allowed.
|
|
244
231
|
await self.decr(current_key, amount)
|
|
245
|
-
# print("Concurrent call, reverting the counter increase")
|
|
246
232
|
return False
|
|
247
233
|
return True
|
|
248
234
|
|
|
@@ -286,3 +272,10 @@ class MemoryStorage(
|
|
|
286
272
|
self.locks.clear()
|
|
287
273
|
|
|
288
274
|
return num_items
|
|
275
|
+
|
|
276
|
+
def __del__(self) -> None:
    """
    Cancel the pending expiry timer, if any, when the storage is collected.
    """
    # ``timer`` is missing if ``__init__`` failed before assigning it; reading it
    # directly would raise AttributeError inside the finalizer.
    timer = getattr(self, "timer", None)
    try:
        if timer and not timer.done():
            timer.cancel()
    except RuntimeError:  # noqa
        # NOTE(review): presumably raised when the timer's event loop has
        # already been closed — nothing left to cancel in that case.
        pass
|