redis 5.2.1__py3-none-any.whl → 5.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- redis/asyncio/client.py +49 -12
- redis/asyncio/cluster.py +101 -12
- redis/asyncio/connection.py +78 -11
- redis/auth/__init__.py +0 -0
- redis/auth/err.py +31 -0
- redis/auth/idp.py +28 -0
- redis/auth/token.py +126 -0
- redis/auth/token_manager.py +370 -0
- redis/backoff.py +15 -0
- redis/client.py +116 -56
- redis/cluster.py +157 -33
- redis/connection.py +103 -11
- redis/credentials.py +40 -1
- redis/event.py +394 -0
- redis/typing.py +1 -1
- redis/utils.py +65 -0
- {redis-5.2.1.dist-info → redis-5.3.0.dist-info}/METADATA +2 -1
- {redis-5.2.1.dist-info → redis-5.3.0.dist-info}/RECORD +21 -15
- {redis-5.2.1.dist-info → redis-5.3.0.dist-info}/LICENSE +0 -0
- {redis-5.2.1.dist-info → redis-5.3.0.dist-info}/WHEEL +0 -0
- {redis-5.2.1.dist-info → redis-5.3.0.dist-info}/top_level.txt +0 -0
redis/auth/token.py
ADDED
|
@@ -0,0 +1,126 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
from datetime import datetime, timezone
|
|
3
|
+
|
|
4
|
+
import jwt
|
|
5
|
+
from redis.auth.err import InvalidTokenSchemaErr
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class TokenInterface(ABC):
    """
    Abstract contract for a token that carries a value, an expiry and claims.

    Every method is abstract; concrete tokens must implement all of them.
    """

    @abstractmethod
    def is_expired(self) -> bool:
        """
        Tell whether the token is already expired.

        :return: bool
        """

    @abstractmethod
    def ttl(self) -> float:
        """
        Remaining lifetime of the token.

        :return: float
        """

    @abstractmethod
    def try_get(self, key: str) -> str:
        """
        Look up a single entry of the token by its key.

        :param key: Name of the entry to look up.
        :return: str
        """

    @abstractmethod
    def get_value(self) -> str:
        """
        Raw token value.

        :return: str
        """

    @abstractmethod
    def get_expires_at_ms(self) -> float:
        """
        Expiration timestamp of the token in milliseconds.

        :return: float
        """

    @abstractmethod
    def get_received_at_ms(self) -> float:
        """
        Timestamp, in milliseconds, at which the token was received.

        :return: float
        """
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class TokenResponse:
    """
    Wrapper that holds a single token and exposes its total lifetime.
    """

    def __init__(self, token: TokenInterface):
        self._token = token

    def get_token(self) -> TokenInterface:
        """
        :return: The wrapped token object.
        """
        return self._token

    def get_ttl_ms(self) -> float:
        """
        Difference between the token's expiration and reception timestamps.

        :return: float
        """
        wrapped = self._token
        expires_at_ms = wrapped.get_expires_at_ms()
        received_at_ms = wrapped.get_received_at_ms()
        return expires_at_ms - received_at_ms
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class SimpleToken(TokenInterface):
    """
    Token built from explicitly supplied value, timestamps and claims.

    Timestamps are compared against the current UTC time in milliseconds.
    An expiration of -1 is treated as "never expires".
    """

    def __init__(
        self, value: str, expires_at_ms: float, received_at_ms: float, claims: dict
    ) -> None:
        self.value = value
        self.expires_at = expires_at_ms
        self.received_at = received_at_ms
        self.claims = claims

    def ttl(self) -> float:
        """
        Milliseconds left until expiration, or -1 when the token never expires.

        :return: float
        """
        if self.expires_at == -1:
            return -1

        now_ms = datetime.now(timezone.utc).timestamp() * 1000
        return self.expires_at - now_ms

    def is_expired(self) -> bool:
        """
        :return: True once the remaining lifetime is zero or negative.
        """
        # A token without expiration (-1) can never be expired.
        return self.expires_at != -1 and self.ttl() <= 0

    def try_get(self, key: str) -> str:
        """
        :param key: Claim name.
        :return: The claim value, or None when the claim is absent.
        """
        return self.claims.get(key)

    def get_value(self) -> str:
        """
        :return: The raw token value.
        """
        return self.value

    def get_expires_at_ms(self) -> float:
        """
        :return: Expiration timestamp as passed to the constructor.
        """
        return self.expires_at

    def get_received_at_ms(self) -> float:
        """
        :return: Reception timestamp as passed to the constructor.
        """
        return self.received_at
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
class JWToken(TokenInterface):
    """
    Token backed by a JWT string.

    The JWT is decoded without signature verification and must contain
    every claim listed in REQUIRED_FIELDS.
    """

    REQUIRED_FIELDS = {"exp"}

    def __init__(self, token: str):
        """
        :param token: Encoded JWT string.
        :raises InvalidTokenSchemaErr: when a required claim is missing.
        """
        self._value = token
        # The algorithm is taken from the token's own (unverified) header.
        header = jwt.get_unverified_header(self._value)
        self._decoded = jwt.decode(
            self._value,
            options={"verify_signature": False},
            algorithms=[header.get("alg")],
        )
        self._validate_token()

    def is_expired(self) -> bool:
        """
        :return: True when the "exp" claim is not -1 and is not in the future.
        """
        exp = self._decoded["exp"]
        if exp == -1:
            return False

        # "exp" is multiplied by 1000 to compare in milliseconds.
        now_ms = datetime.now(timezone.utc).timestamp() * 1000
        return exp * 1000 <= now_ms

    def ttl(self) -> float:
        """
        :return: Milliseconds until the "exp" claim, or -1 when "exp" is -1.
        """
        exp = self._decoded["exp"]
        if exp == -1:
            return -1

        now_ms = datetime.now(timezone.utc).timestamp() * 1000
        return exp * 1000 - now_ms

    def try_get(self, key: str) -> str:
        """
        :param key: Claim name.
        :return: The decoded claim value, or None when the claim is absent.
        """
        return self._decoded.get(key)

    def get_value(self) -> str:
        """
        :return: The encoded JWT string passed to the constructor.
        """
        return self._value

    def get_expires_at_ms(self) -> float:
        """
        :return: The "exp" claim multiplied by 1000, as a float.
        """
        return float(self._decoded["exp"] * 1000)

    def get_received_at_ms(self) -> float:
        """
        :return: The current UTC time in milliseconds.
        """
        # NOTE(review): this is evaluated on every call rather than captured
        # when the token was created - confirm that callers expect that.
        return datetime.now(timezone.utc).timestamp() * 1000

    def _validate_token(self):
        """
        Ensure every required claim is present in the decoded token.

        :raises InvalidTokenSchemaErr: with the set of missing claim names.
        """
        missing_fields = self.REQUIRED_FIELDS - set(self._decoded)

        if missing_fields:
            raise InvalidTokenSchemaErr(missing_fields)
|
|
@@ -0,0 +1,370 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import logging
|
|
3
|
+
import threading
|
|
4
|
+
from datetime import datetime, timezone
|
|
5
|
+
from time import sleep
|
|
6
|
+
from typing import Any, Awaitable, Callable, Union
|
|
7
|
+
|
|
8
|
+
from redis.auth.err import RequestTokenErr, TokenRenewalErr
|
|
9
|
+
from redis.auth.idp import IdentityProviderInterface
|
|
10
|
+
from redis.auth.token import TokenResponse
|
|
11
|
+
|
|
12
|
+
logger = logging.getLogger(__name__)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class CredentialsListener:
    """
    Holder for the callbacks invoked on credential-related events.

    Both plain callables and awaitable callbacks can be registered.
    Each callback is None until one is assigned.
    """

    def __init__(self):
        self._on_next = self._on_error = None

    @property
    def on_next(self) -> Union[Callable[[Any], None], Awaitable]:
        """
        Callback registered for the "next" event, or None.
        """
        return self._on_next

    @on_next.setter
    def on_next(self, callback: Union[Callable[[Any], None], Awaitable]) -> None:
        self._on_next = callback

    @property
    def on_error(self) -> Union[Callable[[Exception], None], Awaitable]:
        """
        Callback registered for the "error" event, or None.
        """
        return self._on_error

    @on_error.setter
    def on_error(self, callback: Union[Callable[[Exception], None], Awaitable]) -> None:
        self._on_error = callback
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class RetryPolicy:
    """
    Value object describing how often and how quickly an operation is retried.
    """

    def __init__(self, max_attempts: int, delay_in_ms: float):
        """
        :param max_attempts: Number of retry attempts.
        :param delay_in_ms: Delay between retries, in milliseconds.
        """
        self.max_attempts = max_attempts
        self.delay_in_ms = delay_in_ms

    def get_max_attempts(self) -> int:
        """
        Retry attempts before exception will be thrown.

        :return: int
        """
        return self.max_attempts

    def get_delay_in_ms(self) -> float:
        """
        Delay between retries in milliseconds.

        :return: float
        """
        return self.delay_in_ms
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
class TokenManagerConfig:
    """
    Read-only settings consumed by the token manager.
    """

    def __init__(
        self,
        expiration_refresh_ratio: float,
        lower_refresh_bound_millis: int,
        token_request_execution_timeout_in_ms: int,
        retry_policy: RetryPolicy,
    ):
        """
        :param expiration_refresh_ratio: Share of the token lifetime after
            which a refresh is triggered.
        :param lower_refresh_bound_millis: Minimum time before expiration,
            in milliseconds, at which a refresh is triggered.
        :param token_request_execution_timeout_in_ms: Maximum time,
            in milliseconds, to wait for a token request.
        :param retry_policy: Retry policy applied to token requests.
        """
        self._retry_policy = retry_policy
        self._expiration_refresh_ratio = expiration_refresh_ratio
        self._lower_refresh_bound_millis = lower_refresh_bound_millis
        self._token_request_execution_timeout_in_ms = (
            token_request_execution_timeout_in_ms
        )

    def get_expiration_refresh_ratio(self) -> float:
        """
        Ratio of a token's lifetime at which a refresh should be triggered.
        For example, a value of 0.75 means the token should be refreshed
        when 75% of its lifetime has elapsed (or when 25% of its lifetime
        remains).

        :return: float
        """
        return self._expiration_refresh_ratio

    def get_lower_refresh_bound_millis(self) -> int:
        """
        Minimum time before token expiration, in milliseconds, at which
        a refresh is triggered.
        This is a fixed lower bound for when a token refresh should occur,
        regardless of the token's total lifetime.
        If set to 0 there is no lower bound and the refresh is triggered
        based on the expiration refresh ratio only.

        :return: int
        """
        return self._lower_refresh_bound_millis

    def get_token_request_execution_timeout_in_ms(self) -> int:
        """
        Maximum time in milliseconds to wait for a token request to complete.

        :return: int
        """
        return self._token_request_execution_timeout_in_ms

    def get_retry_policy(self) -> RetryPolicy:
        """
        Retry policy for token requests.

        :return: RetryPolicy
        """
        return self._retry_policy
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
class TokenManager:
    """
    Acquires tokens from an identity provider and schedules their renewal.

    Each successfully acquired token is handed to the listener's ``on_next``
    callback; failures are handed to its ``on_error`` callback. The next
    renewal is scheduled on an asyncio event loop with ``loop.call_later``.
    """

    def __init__(
        self, identity_provider: IdentityProviderInterface, config: TokenManagerConfig
    ):
        """
        :param identity_provider: Source of tokens (``request_token`` is called).
        :param config: Refresh thresholds and retry policy.
        """
        self._idp = identity_provider
        self._config = config
        # Handle of the next scheduled renewal, cancelled by stop().
        self._next_timer = None
        self._listener = None
        # Handle of the first renewal scheduled by start()/start_async().
        self._init_timer = None
        # Failed token requests counted during the current acquisition.
        self._retries = 0

    def __del__(self):
        logger.info("Token manager are disposed")
        self.stop()

    def start(
        self,
        listener: CredentialsListener,
        skip_initial: bool = False,
    ) -> Callable[[], None]:
        """
        Start token renewal from synchronous code.

        Blocks until the first renewal attempt has finished. When no event
        loop is running in the calling thread, a new loop is started in
        a daemon thread.

        :param listener: Callbacks to notify about renewed tokens and errors.
        :param skip_initial: When True the first acquired token is not passed
            to ``on_next``; later renewals always are.
        :return: The ``stop`` method, to cancel scheduled renewals.
        """
        self._listener = listener

        try:
            # NOTE(review): if a loop is already running in this thread, the
            # blocking .result() call below waits on that same thread -
            # confirm async callers always use start_async() instead.
            loop = asyncio.get_running_loop()
        except RuntimeError:
            # Run loop in a separate thread to unblock main thread.
            loop = asyncio.new_event_loop()
            thread = threading.Thread(
                target=_start_event_loop_in_thread, args=(loop,), daemon=True
            )
            thread.start()

        # Event to block for initial execution.
        init_event = asyncio.Event()
        self._init_timer = loop.call_later(
            0, self._renew_token, skip_initial, init_event
        )
        logger.info("Token manager started")

        # Blocks in thread-safe manner.
        asyncio.run_coroutine_threadsafe(init_event.wait(), loop).result()
        return self.stop

    async def start_async(
        self,
        listener: CredentialsListener,
        block_for_initial: bool = False,
        initial_delay_in_ms: float = 0,
        skip_initial: bool = False,
    ) -> Callable[[], None]:
        """
        Start token renewal on the currently running event loop.

        :param listener: Awaitable callbacks to notify about tokens and errors.
        :param block_for_initial: When True, wait until the first renewal
            attempt has finished before returning.
        :param initial_delay_in_ms: Delay before the first renewal attempt,
            in milliseconds.
        :param skip_initial: When True the first acquired token is not passed
            to ``on_next``; later renewals always are.
        :return: The ``stop`` method, to cancel scheduled renewals.
        """
        self._listener = listener

        loop = asyncio.get_running_loop()
        init_event = asyncio.Event()

        # Wraps the async callback with async wrapper to schedule with loop.call_later()
        wrapped = _async_to_sync_wrapper(
            loop, self._renew_token_async, skip_initial, init_event
        )
        self._init_timer = loop.call_later(initial_delay_in_ms / 1000, wrapped)
        logger.info("Token manager started")

        if block_for_initial:
            await init_event.wait()

        return self.stop

    def stop(self):
        """
        Cancel the initial and the next scheduled renewal, if any.
        """
        if self._init_timer is not None:
            self._init_timer.cancel()
        if self._next_timer is not None:
            self._next_timer.cancel()

    def acquire_token(self, force_refresh=False) -> TokenResponse:
        """
        Request a token from the identity provider, retrying on failure.

        :param force_refresh: Forwarded to ``request_token``.
        :return: TokenResponse wrapping the acquired token.
        :raises RequestTokenErr: once the configured retry attempts are used up.
        """
        while True:
            try:
                token = self._idp.request_token(force_refresh)
            except RequestTokenErr:
                retry_policy = self._config.get_retry_policy()
                if self._retries >= retry_policy.get_max_attempts():
                    # Reset so that a later acquisition gets its full retry
                    # budget instead of failing on the first error.
                    self._retries = 0
                    raise

                self._retries += 1
                sleep(retry_policy.get_delay_in_ms() / 1000)
            else:
                self._retries = 0
                return TokenResponse(token)

    async def acquire_token_async(self, force_refresh=False) -> TokenResponse:
        """
        Same as ``acquire_token`` but waits between retries with asyncio.sleep.

        :param force_refresh: Forwarded to ``request_token``.
        :return: TokenResponse wrapping the acquired token.
        :raises RequestTokenErr: once the configured retry attempts are used up.
        """
        while True:
            try:
                token = self._idp.request_token(force_refresh)
            except RequestTokenErr:
                retry_policy = self._config.get_retry_policy()
                if self._retries >= retry_policy.get_max_attempts():
                    # Reset so that a later acquisition gets its full retry
                    # budget instead of failing on the first error.
                    self._retries = 0
                    raise

                self._retries += 1
                await asyncio.sleep(retry_policy.get_delay_in_ms() / 1000)
            else:
                self._retries = 0
                return TokenResponse(token)

    def _calculate_renewal_delay(self, expire_date: float, issue_date: float) -> float:
        """
        Compute how long to wait before the next renewal.

        The earlier of the lower-bound based and ratio based refresh points
        is used.

        :param expire_date: Expiration timestamp in milliseconds.
        :param issue_date: Reception timestamp in milliseconds.
        :return: Delay in seconds; 0 when the refresh point is in the past.
        """
        delay_for_lower_refresh = self._delay_for_lower_refresh(expire_date)
        delay_for_ratio_refresh = self._delay_for_ratio_refresh(expire_date, issue_date)
        delay = min(delay_for_ratio_refresh, delay_for_lower_refresh)

        # Millisecond delay is converted to seconds for loop.call_later().
        return 0 if delay < 0 else delay / 1000

    def _delay_for_lower_refresh(self, expire_date: float):
        """
        Milliseconds from now until ``lower_refresh_bound_millis`` before expiry.
        """
        return (
            expire_date
            - self._config.get_lower_refresh_bound_millis()
            - (datetime.now(timezone.utc).timestamp() * 1000)
        )

    def _delay_for_ratio_refresh(self, expire_date: float, issue_date: float):
        """
        Milliseconds from now until the configured share of the lifetime elapsed.
        """
        token_ttl = expire_date - issue_date
        refresh_before = token_ttl - (
            token_ttl * self._config.get_expiration_refresh_ratio()
        )

        return (
            expire_date
            - refresh_before
            - (datetime.now(timezone.utc).timestamp() * 1000)
        )

    def _renew_token(
        self, skip_initial: bool = False, init_event: asyncio.Event = None
    ):
        """
        Task to renew token from identity provider.
        Schedules renewal tasks based on token TTL.

        :param skip_initial: When True the acquired token is not passed
            to ``on_next``.
        :param init_event: Event that is set once this attempt has finished.
        :return: The TokenResponse when a further renewal was scheduled,
            otherwise None.
        """

        try:
            token_res = self.acquire_token(force_refresh=True)
            token = token_res.get_token()
            delay = self._calculate_renewal_delay(
                token.get_expires_at_ms(),
                token.get_received_at_ms(),
            )

            if token.is_expired():
                raise TokenRenewalErr("Requested token is expired")

            if self._listener.on_next is None:
                logger.warning(
                    "No registered callback for token renewal task. Renewal cancelled"
                )
                return

            if not skip_initial:
                try:
                    self._listener.on_next(token)
                except Exception as e:
                    # Keep the listener failure as the explicit cause.
                    raise TokenRenewalErr(e) from e

            if delay <= 0:
                return

            loop = asyncio.get_running_loop()
            self._next_timer = loop.call_later(delay, self._renew_token)
            logger.info(f"Next token renewal scheduled in {delay} seconds")
            return token_res
        except Exception as e:
            if self._listener.on_error is None:
                raise

            self._listener.on_error(e)
        finally:
            # Always release a caller blocked on the initial attempt.
            if init_event:
                init_event.set()

    async def _renew_token_async(
        self, skip_initial: bool = False, init_event: asyncio.Event = None
    ):
        """
        Async task to renew tokens from identity provider.
        Schedules renewal tasks based on token TTL.

        :param skip_initial: When True the acquired token is not passed
            to ``on_next``.
        :param init_event: Event that is set once this attempt has finished.
        """

        try:
            token_res = await self.acquire_token_async(force_refresh=True)
            token = token_res.get_token()
            delay = self._calculate_renewal_delay(
                token.get_expires_at_ms(),
                token.get_received_at_ms(),
            )

            if token.is_expired():
                raise TokenRenewalErr("Requested token is expired")

            if self._listener.on_next is None:
                logger.warning(
                    "No registered callback for token renewal task. Renewal cancelled"
                )
                return

            if not skip_initial:
                try:
                    await self._listener.on_next(token)
                except Exception as e:
                    # Keep the listener failure as the explicit cause.
                    raise TokenRenewalErr(e) from e

            if delay <= 0:
                return

            loop = asyncio.get_running_loop()
            wrapped = _async_to_sync_wrapper(loop, self._renew_token_async)
            logger.info(f"Next token renewal scheduled in {delay} seconds")
            # Keep the handle so that stop() can cancel the pending renewal,
            # matching the synchronous path.
            self._next_timer = loop.call_later(delay, wrapped)
        except Exception as e:
            if self._listener.on_error is None:
                raise

            await self._listener.on_error(e)
        finally:
            # Always release a caller blocked on the initial attempt.
            if init_event:
                init_event.set()
|
|
341
|
+
|
|
342
|
+
|
|
343
|
+
def _async_to_sync_wrapper(loop, coro_func, *args, **kwargs):
|
|
344
|
+
"""
|
|
345
|
+
Wraps an asynchronous function so it can be used with loop.call_later.
|
|
346
|
+
|
|
347
|
+
:param loop: The event loop in which the coroutine will be executed.
|
|
348
|
+
:param coro_func: The coroutine function to wrap.
|
|
349
|
+
:param args: Positional arguments to pass to the coroutine function.
|
|
350
|
+
:param kwargs: Keyword arguments to pass to the coroutine function.
|
|
351
|
+
:return: A regular function suitable for loop.call_later.
|
|
352
|
+
"""
|
|
353
|
+
|
|
354
|
+
def wrapped():
|
|
355
|
+
# Schedule the coroutine in the event loop
|
|
356
|
+
asyncio.ensure_future(coro_func(*args, **kwargs), loop=loop)
|
|
357
|
+
|
|
358
|
+
return wrapped
|
|
359
|
+
|
|
360
|
+
|
|
361
|
+
def _start_event_loop_in_thread(event_loop: asyncio.AbstractEventLoop):
|
|
362
|
+
"""
|
|
363
|
+
Starts event loop in a thread.
|
|
364
|
+
Used to be able to schedule tasks using loop.call_later.
|
|
365
|
+
|
|
366
|
+
:param event_loop:
|
|
367
|
+
:return:
|
|
368
|
+
"""
|
|
369
|
+
asyncio.set_event_loop(event_loop)
|
|
370
|
+
event_loop.run_forever()
|
redis/backoff.py
CHANGED
|
@@ -110,5 +110,20 @@ class DecorrelatedJitterBackoff(AbstractBackoff):
|
|
|
110
110
|
return self._previous_backoff
|
|
111
111
|
|
|
112
112
|
|
|
113
|
+
class ExponentialWithJitterBackoff(AbstractBackoff):
    """Exponential backoff upon failure, with jitter"""

    def __init__(self, cap: float = DEFAULT_CAP, base: float = DEFAULT_BASE) -> None:
        """
        `cap`: maximum backoff time in seconds
        `base`: base backoff time in seconds
        """
        self._cap = cap
        self._base = base

    def compute(self, failures: int) -> float:
        """
        Return a random backoff that grows exponentially with `failures`.

        The value is `random.random() * base * 2**failures`, limited to `cap`.

        `failures`: number of failures so far
        """
        jitter = random.random()
        try:
            uncapped = jitter * self._base * 2**failures
        except OverflowError:
            # 2**failures is an arbitrarily large int; once it no longer fits
            # into a float the multiplication raises instead of saturating.
            # The uncapped backoff is out of float range, so use the cap.
            return self._cap

        return min(self._cap, uncapped)
|
|
126
|
+
|
|
127
|
+
|
|
113
128
|
def default_backoff():
    """Return a new EqualJitterBackoff instance."""
    return EqualJitterBackoff()
|