redis 5.2.1__py3-none-any.whl → 5.3.0b3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- redis/asyncio/client.py +42 -1
- redis/asyncio/cluster.py +54 -0
- redis/asyncio/connection.py +65 -8
- redis/auth/__init__.py +0 -0
- redis/auth/err.py +31 -0
- redis/auth/idp.py +28 -0
- redis/auth/token.py +126 -0
- redis/auth/token_manager.py +370 -0
- redis/client.py +51 -3
- redis/cluster.py +39 -0
- redis/connection.py +73 -0
- redis/credentials.py +40 -1
- redis/event.py +394 -0
- {redis-5.2.1.dist-info → redis-5.3.0b3.dist-info}/METADATA +1 -1
- {redis-5.2.1.dist-info → redis-5.3.0b3.dist-info}/RECORD +18 -12
- {redis-5.2.1.dist-info → redis-5.3.0b3.dist-info}/LICENSE +0 -0
- {redis-5.2.1.dist-info → redis-5.3.0b3.dist-info}/WHEEL +0 -0
- {redis-5.2.1.dist-info → redis-5.3.0b3.dist-info}/top_level.txt +0 -0
redis/auth/token_manager.py ADDED

```diff
@@ -0,0 +1,370 @@
+import asyncio
+import logging
+import threading
+from datetime import datetime, timezone
+from time import sleep
+from typing import Any, Awaitable, Callable, Union
+
+from redis.auth.err import RequestTokenErr, TokenRenewalErr
+from redis.auth.idp import IdentityProviderInterface
+from redis.auth.token import TokenResponse
+
+logger = logging.getLogger(__name__)
+
+
+class CredentialsListener:
+    """
+    Listeners that will be notified on events related to credentials.
+    Accepts callbacks and awaitable callbacks.
+    """
+
+    def __init__(self):
+        self._on_next = None
+        self._on_error = None
+
+    @property
+    def on_next(self) -> Union[Callable[[Any], None], Awaitable]:
+        return self._on_next
+
+    @on_next.setter
+    def on_next(self, callback: Union[Callable[[Any], None], Awaitable]) -> None:
+        self._on_next = callback
+
+    @property
+    def on_error(self) -> Union[Callable[[Exception], None], Awaitable]:
+        return self._on_error
+
+    @on_error.setter
+    def on_error(self, callback: Union[Callable[[Exception], None], Awaitable]) -> None:
+        self._on_error = callback
+
+
```
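`CredentialsListener` is a plain callback holder: the renewal task (see `TokenManager` below) assigns and later invokes `on_next`/`on_error`. A minimal wiring sketch, with hypothetical callback bodies:

```python
listener = CredentialsListener()
# on_next receives the freshly acquired token; on_error receives the failure.
listener.on_next = lambda token: print("new token received:", token)
listener.on_error = lambda exc: print("token renewal failed:", exc)
```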
```diff
+class RetryPolicy:
+    def __init__(self, max_attempts: int, delay_in_ms: float):
+        self.max_attempts = max_attempts
+        self.delay_in_ms = delay_in_ms
+
+    def get_max_attempts(self) -> int:
+        """
+        Retry attempts before an exception is thrown.
+
+        :return: int
+        """
+        return self.max_attempts
+
+    def get_delay_in_ms(self) -> float:
+        """
+        Delay between retries, in milliseconds.
+
+        :return: float
+        """
+        return self.delay_in_ms
+
+
+class TokenManagerConfig:
+    def __init__(
+        self,
+        expiration_refresh_ratio: float,
+        lower_refresh_bound_millis: int,
+        token_request_execution_timeout_in_ms: int,
+        retry_policy: RetryPolicy,
+    ):
+        self._expiration_refresh_ratio = expiration_refresh_ratio
+        self._lower_refresh_bound_millis = lower_refresh_bound_millis
+        self._token_request_execution_timeout_in_ms = (
+            token_request_execution_timeout_in_ms
+        )
+        self._retry_policy = retry_policy
+
+    def get_expiration_refresh_ratio(self) -> float:
+        """
+        Represents the ratio of a token's lifetime at which a refresh should be triggered.  # noqa: E501
+        For example, a value of 0.75 means the token should be refreshed
+        when 75% of its lifetime has elapsed (or when 25% of its lifetime remains).
+
+        :return: float
+        """
+
+        return self._expiration_refresh_ratio
+
+    def get_lower_refresh_bound_millis(self) -> int:
+        """
+        Represents the minimum time in milliseconds before token expiration
+        at which a refresh should be triggered.
+        This value sets a fixed lower bound for when a token refresh should occur,
+        regardless of the token's total lifetime.
+        If set to 0 there is no lower bound and the refresh is triggered
+        based on the expiration_refresh_ratio only.
+
+        :return: int
+        """
+        return self._lower_refresh_bound_millis
+
+    def get_token_request_execution_timeout_in_ms(self) -> int:
+        """
+        Represents the maximum time in milliseconds to wait
+        for a token request to complete.
+
+        :return: int
+        """
+        return self._token_request_execution_timeout_in_ms
+
+    def get_retry_policy(self) -> RetryPolicy:
+        """
+        Represents the retry policy for token requests.
+
+        :return: RetryPolicy
+        """
+        return self._retry_policy
+
+
```
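A configuration built from the two classes above, using only constructors shown in this diff; the values are illustrative: retry a failed token request up to 3 times with 100 ms between attempts, refresh at 75% of the token lifetime but no later than 10 s before expiry, and give a single request at most 2 s to complete.

```python
retry_policy = RetryPolicy(max_attempts=3, delay_in_ms=100)
config = TokenManagerConfig(
    expiration_refresh_ratio=0.75,
    lower_refresh_bound_millis=10_000,
    token_request_execution_timeout_in_ms=2_000,
    retry_policy=retry_policy,
)
```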
```diff
+class TokenManager:
+    def __init__(
+        self, identity_provider: IdentityProviderInterface, config: TokenManagerConfig
+    ):
+        self._idp = identity_provider
+        self._config = config
+        self._next_timer = None
+        self._listener = None
+        self._init_timer = None
+        self._retries = 0
+
+    def __del__(self):
+        logger.info("Token manager is disposed")
+        self.stop()
+
+    def start(
+        self,
+        listener: CredentialsListener,
+        skip_initial: bool = False,
+    ) -> Callable[[], None]:
+        self._listener = listener
+
+        try:
+            loop = asyncio.get_running_loop()
+        except RuntimeError:
+            # Run the loop in a separate thread to unblock the main thread.
+            loop = asyncio.new_event_loop()
+            thread = threading.Thread(
+                target=_start_event_loop_in_thread, args=(loop,), daemon=True
+            )
+            thread.start()
+
+        # Event to block for the initial execution.
+        init_event = asyncio.Event()
+        self._init_timer = loop.call_later(
+            0, self._renew_token, skip_initial, init_event
+        )
+        logger.info("Token manager started")
+
+        # Blocks in a thread-safe manner.
+        asyncio.run_coroutine_threadsafe(init_event.wait(), loop).result()
+        return self.stop
+
+    async def start_async(
+        self,
+        listener: CredentialsListener,
+        block_for_initial: bool = False,
+        initial_delay_in_ms: float = 0,
+        skip_initial: bool = False,
+    ) -> Callable[[], None]:
+        self._listener = listener
+
+        loop = asyncio.get_running_loop()
+        init_event = asyncio.Event()
+
+        # Wraps the async callback with a sync wrapper so it can be
+        # scheduled with loop.call_later().
+        wrapped = _async_to_sync_wrapper(
+            loop, self._renew_token_async, skip_initial, init_event
+        )
+        self._init_timer = loop.call_later(initial_delay_in_ms / 1000, wrapped)
+        logger.info("Token manager started")
+
+        if block_for_initial:
+            await init_event.wait()
+
+        return self.stop
+
+    def stop(self):
+        if self._init_timer is not None:
+            self._init_timer.cancel()
+        if self._next_timer is not None:
+            self._next_timer.cancel()
+
+    def acquire_token(self, force_refresh=False) -> TokenResponse:
+        try:
+            token = self._idp.request_token(force_refresh)
+        except RequestTokenErr as e:
+            if self._retries < self._config.get_retry_policy().get_max_attempts():
+                self._retries += 1
+                sleep(self._config.get_retry_policy().get_delay_in_ms() / 1000)
+                return self.acquire_token(force_refresh)
+            else:
+                raise e
+
+        self._retries = 0
+        return TokenResponse(token)
+
+    async def acquire_token_async(self, force_refresh=False) -> TokenResponse:
+        try:
+            token = self._idp.request_token(force_refresh)
+        except RequestTokenErr as e:
+            if self._retries < self._config.get_retry_policy().get_max_attempts():
+                self._retries += 1
+                await asyncio.sleep(
+                    self._config.get_retry_policy().get_delay_in_ms() / 1000
+                )
+                return await self.acquire_token_async(force_refresh)
+            else:
+                raise e
+
+        self._retries = 0
+        return TokenResponse(token)
+
+    def _calculate_renewal_delay(self, expire_date: float, issue_date: float) -> float:
+        delay_for_lower_refresh = self._delay_for_lower_refresh(expire_date)
+        delay_for_ratio_refresh = self._delay_for_ratio_refresh(expire_date, issue_date)
+        delay = min(delay_for_ratio_refresh, delay_for_lower_refresh)
+
+        return 0 if delay < 0 else delay / 1000
+
+    def _delay_for_lower_refresh(self, expire_date: float):
+        return (
+            expire_date
+            - self._config.get_lower_refresh_bound_millis()
+            - (datetime.now(timezone.utc).timestamp() * 1000)
+        )
+
+    def _delay_for_ratio_refresh(self, expire_date: float, issue_date: float):
+        token_ttl = expire_date - issue_date
+        refresh_before = token_ttl - (
+            token_ttl * self._config.get_expiration_refresh_ratio()
+        )
+
+        return (
+            expire_date
+            - refresh_before
+            - (datetime.now(timezone.utc).timestamp() * 1000)
+        )
+
```
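A worked example of the delay arithmetic above, assuming the current time equals `issue_date`; all timestamps are in milliseconds and the result is converted to seconds, as in `_calculate_renewal_delay`:

```python
# Token issued at t=0, expiring at t=3_600_000 ms (1 h TTL);
# config: expiration_refresh_ratio=0.75, lower_refresh_bound_millis=10_000.
now, issue, expire = 0, 0, 3_600_000
token_ttl = expire - issue                      # 3_600_000 ms
refresh_before = token_ttl - token_ttl * 0.75   # 900_000 ms before expiry
ratio_delay = expire - refresh_before - now     # 2_700_000 ms (75% of lifetime)
lower_delay = expire - 10_000 - now             # 3_590_000 ms (10 s before expiry)
delay_s = min(ratio_delay, lower_delay) / 1000  # 2700.0 s -> the ratio bound wins
```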
```diff
+    def _renew_token(
+        self, skip_initial: bool = False, init_event: asyncio.Event = None
+    ):
+        """
+        Task to renew token from identity provider.
+        Schedules renewal tasks based on token TTL.
+        """
+
+        try:
+            token_res = self.acquire_token(force_refresh=True)
+            delay = self._calculate_renewal_delay(
+                token_res.get_token().get_expires_at_ms(),
+                token_res.get_token().get_received_at_ms(),
+            )
+
+            if token_res.get_token().is_expired():
+                raise TokenRenewalErr("Requested token is expired")
+
+            if self._listener.on_next is None:
+                logger.warning(
+                    "No registered callback for token renewal task. Renewal cancelled"
+                )
+                return
+
+            if not skip_initial:
+                try:
+                    self._listener.on_next(token_res.get_token())
+                except Exception as e:
+                    raise TokenRenewalErr(e)
+
+            if delay <= 0:
+                return
+
+            loop = asyncio.get_running_loop()
+            self._next_timer = loop.call_later(delay, self._renew_token)
+            logger.info(f"Next token renewal scheduled in {delay} seconds")
+            return token_res
+        except Exception as e:
+            if self._listener.on_error is None:
+                raise e
+
+            self._listener.on_error(e)
+        finally:
+            if init_event:
+                init_event.set()
+
+    async def _renew_token_async(
+        self, skip_initial: bool = False, init_event: asyncio.Event = None
+    ):
+        """
+        Async task to renew tokens from identity provider.
+        Schedules renewal tasks based on token TTL.
+        """
+
+        try:
+            token_res = await self.acquire_token_async(force_refresh=True)
+            delay = self._calculate_renewal_delay(
+                token_res.get_token().get_expires_at_ms(),
+                token_res.get_token().get_received_at_ms(),
+            )
+
+            if token_res.get_token().is_expired():
+                raise TokenRenewalErr("Requested token is expired")
+
+            if self._listener.on_next is None:
+                logger.warning(
+                    "No registered callback for token renewal task. Renewal cancelled"
+                )
+                return
+
+            if not skip_initial:
+                try:
+                    await self._listener.on_next(token_res.get_token())
+                except Exception as e:
+                    raise TokenRenewalErr(e)
+
+            if delay <= 0:
+                return
+
+            loop = asyncio.get_running_loop()
+            wrapped = _async_to_sync_wrapper(loop, self._renew_token_async)
+            logger.info(f"Next token renewal scheduled in {delay} seconds")
+            loop.call_later(delay, wrapped)
+        except Exception as e:
+            if self._listener.on_error is None:
+                raise e
+
+            await self._listener.on_error(e)
+        finally:
+            if init_event:
+                init_event.set()
+
+
```
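Both renewal methods form a self-rescheduling chain: each successful run computes the next delay and re-arms `loop.call_later` for itself. A stripped-down sketch of that scheduling pattern, with hypothetical names, not the library API:

```python
import asyncio

def make_renewal_chain(loop: asyncio.AbstractEventLoop, renew, next_delay_s):
    """Re-arm a callback after each run, mirroring _renew_token's scheduling."""
    def run():
        renew()                          # acquire a token, notify the listener
        delay = next_delay_s()           # seconds until the next refresh point
        if delay > 0:
            loop.call_later(delay, run)  # chain the next renewal
    return run
```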
```diff
+def _async_to_sync_wrapper(loop, coro_func, *args, **kwargs):
+    """
+    Wraps an asynchronous function so it can be used with loop.call_later.
+
+    :param loop: The event loop in which the coroutine will be executed.
+    :param coro_func: The coroutine function to wrap.
+    :param args: Positional arguments to pass to the coroutine function.
+    :param kwargs: Keyword arguments to pass to the coroutine function.
+    :return: A regular function suitable for loop.call_later.
+    """
+
+    def wrapped():
+        # Schedule the coroutine in the event loop
+        asyncio.ensure_future(coro_func(*args, **kwargs), loop=loop)
+
+    return wrapped
+
+
+def _start_event_loop_in_thread(event_loop: asyncio.AbstractEventLoop):
+    """
+    Starts an event loop in a thread.
+    Used to be able to schedule tasks using loop.call_later.
+
+    :param event_loop: The loop to run forever in the spawned thread.
+    """
+    asyncio.set_event_loop(event_loop)
+    event_loop.run_forever()
```
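Putting the pieces together, a hedged end-to-end sketch: `MyProvider` and `fetch_token_from_idp` are hypothetical, and the returned token must expose the `get_expires_at_ms()` / `get_received_at_ms()` / `is_expired()` accessors that `_renew_token` calls above.

```python
from redis.auth.idp import IdentityProviderInterface

class MyProvider(IdentityProviderInterface):  # hypothetical IdP adapter
    def request_token(self, force_refresh=False):
        return fetch_token_from_idp()  # placeholder: call your identity provider

listener = CredentialsListener()
listener.on_next = lambda token: print("received token, re-authenticating")
listener.on_error = lambda exc: print("renewal failed:", exc)

manager = TokenManager(MyProvider(), config)  # `config` from the earlier sketch
stop = manager.start(listener)  # blocks until the initial token is acquired
# ... later, cancel all scheduled renewals:
stop()
```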
redis/client.py CHANGED

```diff
@@ -27,6 +27,13 @@ from redis.connection import (
     UnixDomainSocketConnection,
 )
 from redis.credentials import CredentialProvider
+from redis.event import (
+    AfterPooledConnectionsInstantiationEvent,
+    AfterPubSubConnectionInstantiationEvent,
+    AfterSingleConnectionInstantiationEvent,
+    ClientType,
+    EventDispatcher,
+)
 from redis.exceptions import (
     ConnectionError,
     ExecAbortError,
@@ -213,6 +220,7 @@ class Redis(RedisModuleCommands, CoreCommands, SentinelCommands):
         protocol: Optional[int] = 2,
         cache: Optional[CacheInterface] = None,
         cache_config: Optional[CacheConfig] = None,
+        event_dispatcher: Optional[EventDispatcher] = None,
     ) -> None:
         """
         Initialize a new Redis client.
@@ -227,6 +235,10 @@ class Redis(RedisModuleCommands, CoreCommands, SentinelCommands):
         if `True`, connection pool is not used. In that case `Redis`
         instance use is not thread safe.
         """
+        if event_dispatcher is None:
+            self._event_dispatcher = EventDispatcher()
+        else:
+            self._event_dispatcher = event_dispatcher
         if not connection_pool:
             if charset is not None:
                 warnings.warn(
@@ -313,9 +325,19 @@ class Redis(RedisModuleCommands, CoreCommands, SentinelCommands):
                 }
             )
             connection_pool = ConnectionPool(**kwargs)
+            self._event_dispatcher.dispatch(
+                AfterPooledConnectionsInstantiationEvent(
+                    [connection_pool], ClientType.SYNC, credential_provider
+                )
+            )
             self.auto_close_connection_pool = True
         else:
             self.auto_close_connection_pool = False
+            self._event_dispatcher.dispatch(
+                AfterPooledConnectionsInstantiationEvent(
+                    [connection_pool], ClientType.SYNC, credential_provider
+                )
+            )
 
         self.connection_pool = connection_pool
 
@@ -325,9 +347,16 @@ class Redis(RedisModuleCommands, CoreCommands, SentinelCommands):
         ]:
             raise RedisError("Client caching is only supported with RESP version 3")
 
+        self.single_connection_lock = threading.Lock()
         self.connection = None
-        if single_connection_client:
+        self._single_connection_client = single_connection_client
+        if self._single_connection_client:
             self.connection = self.connection_pool.get_connection("_")
+            self._event_dispatcher.dispatch(
+                AfterSingleConnectionInstantiationEvent(
+                    self.connection, ClientType.SYNC, self.single_connection_lock
+                )
+            )
 
         self.response_callbacks = CaseInsensitiveDict(_RedisCallbacks)
 
@@ -500,7 +529,9 @@ class Redis(RedisModuleCommands, CoreCommands, SentinelCommands):
         subscribe to channels and listen for messages that get published to
         them.
         """
-        return PubSub(self.connection_pool, **kwargs)
+        return PubSub(
+            self.connection_pool, event_dispatcher=self._event_dispatcher, **kwargs
+        )
 
     def monitor(self):
         return Monitor(self.connection_pool)
@@ -563,6 +594,9 @@ class Redis(RedisModuleCommands, CoreCommands, SentinelCommands):
         pool = self.connection_pool
         command_name = args[0]
         conn = self.connection or pool.get_connection(command_name, **options)
+
+        if self._single_connection_client:
+            self.single_connection_lock.acquire()
         try:
             return conn.retry.call_with_retry(
                 lambda: self._send_command_parse_response(
@@ -571,6 +605,8 @@ class Redis(RedisModuleCommands, CoreCommands, SentinelCommands):
                 lambda error: self._disconnect_raise(conn, error),
             )
         finally:
+            if self._single_connection_client:
+                self.single_connection_lock.release()
             if not self.connection:
                 pool.release(conn)
 
@@ -691,6 +727,7 @@ class PubSub:
         ignore_subscribe_messages: bool = False,
         encoder: Optional["Encoder"] = None,
         push_handler_func: Union[None, Callable[[str], None]] = None,
+        event_dispatcher: Optional["EventDispatcher"] = None,
     ):
         self.connection_pool = connection_pool
         self.shard_hint = shard_hint
@@ -701,6 +738,11 @@ class PubSub:
         # to lookup channel and pattern names for callback handlers.
         self.encoder = encoder
         self.push_handler_func = push_handler_func
+        if event_dispatcher is None:
+            self._event_dispatcher = EventDispatcher()
+        else:
+            self._event_dispatcher = event_dispatcher
+        self._lock = threading.Lock()
         if self.encoder is None:
             self.encoder = self.connection_pool.get_encoder()
         self.health_check_response_b = self.encoder.encode(self.HEALTH_CHECK_MESSAGE)
@@ -791,11 +833,17 @@ class PubSub:
         self.connection.register_connect_callback(self.on_connect)
         if self.push_handler_func is not None and not HIREDIS_AVAILABLE:
             self.connection._parser.set_pubsub_push_handler(self.push_handler_func)
+        self._event_dispatcher.dispatch(
+            AfterPubSubConnectionInstantiationEvent(
+                self.connection, self.connection_pool, ClientType.SYNC, self._lock
+            )
+        )
         connection = self.connection
         kwargs = {"check_health": not self.subscribed}
         if not self.subscribed:
             self.clean_health_check_responses()
-        self._execute(connection, connection.send_command, *args, **kwargs)
+        with self._lock:
+            self._execute(connection, connection.send_command, *args, **kwargs)
 
     def clean_health_check_responses(self) -> None:
         """
```
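Taken together, these hunks thread one `EventDispatcher` through the sync client: the constructor creates a dispatcher when none is given, pool creation and single-connection setup dispatch instantiation events on it, and `pubsub()` hands the same dispatcher to `PubSub`. A hedged sketch of the new surface; only the `event_dispatcher` keyword is taken from this diff, and listener registration is not shown here:

```python
import redis
from redis.event import EventDispatcher

dispatcher = EventDispatcher()
r = redis.Redis(host="localhost", port=6379, event_dispatcher=dispatcher)
p = r.pubsub()  # inherits the client's dispatcher via event_dispatcher=...
```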
redis/cluster.py CHANGED

```diff
@@ -15,6 +15,12 @@ from redis.commands import READ_COMMANDS, RedisClusterCommands
 from redis.commands.helpers import list_or_args
 from redis.connection import ConnectionPool, DefaultParser, parse_url
 from redis.crc import REDIS_CLUSTER_HASH_SLOTS, key_slot
+from redis.event import (
+    AfterPooledConnectionsInstantiationEvent,
+    AfterPubSubConnectionInstantiationEvent,
+    ClientType,
+    EventDispatcher,
+)
 from redis.exceptions import (
     AskError,
     AuthenticationError,
@@ -505,6 +511,7 @@ class RedisCluster(AbstractRedisCluster, RedisClusterCommands):
         address_remap: Optional[Callable[[Tuple[str, int]], Tuple[str, int]]] = None,
         cache: Optional[CacheInterface] = None,
         cache_config: Optional[CacheConfig] = None,
+        event_dispatcher: Optional[EventDispatcher] = None,
         **kwargs,
     ):
         """
@@ -638,6 +645,10 @@ class RedisCluster(AbstractRedisCluster, RedisClusterCommands):
         self.read_from_replicas = read_from_replicas
         self.reinitialize_counter = 0
         self.reinitialize_steps = reinitialize_steps
+        if event_dispatcher is None:
+            self._event_dispatcher = EventDispatcher()
+        else:
+            self._event_dispatcher = event_dispatcher
         self.nodes_manager = NodesManager(
             startup_nodes=startup_nodes,
             from_url=from_url,
@@ -646,6 +657,7 @@ class RedisCluster(AbstractRedisCluster, RedisClusterCommands):
             address_remap=address_remap,
             cache=cache,
             cache_config=cache_config,
+            event_dispatcher=self._event_dispatcher,
             **kwargs,
         )
 
@@ -1332,6 +1344,7 @@ class NodesManager:
         cache: Optional[CacheInterface] = None,
         cache_config: Optional[CacheConfig] = None,
         cache_factory: Optional[CacheFactoryInterface] = None,
+        event_dispatcher: Optional[EventDispatcher] = None,
         **kwargs,
     ):
         self.nodes_cache = {}
@@ -1353,6 +1366,13 @@ class NodesManager:
         if lock is None:
             lock = threading.Lock()
         self._lock = lock
+        if event_dispatcher is None:
+            self._event_dispatcher = EventDispatcher()
+        else:
+            self._event_dispatcher = event_dispatcher
+        self._credential_provider = self.connection_kwargs.get(
+            "credential_provider", None
+        )
         self.initialize()
 
     def get_node(self, host=None, port=None, node_name=None):
@@ -1479,11 +1499,19 @@ class NodesManager:
         """
         This function will create a redis connection to all nodes in :nodes:
         """
+        connection_pools = []
         for node in nodes:
             if node.redis_connection is None:
                 node.redis_connection = self.create_redis_node(
                     host=node.host, port=node.port, **self.connection_kwargs
                 )
+                connection_pools.append(node.redis_connection.connection_pool)
+
+        self._event_dispatcher.dispatch(
+            AfterPooledConnectionsInstantiationEvent(
+                connection_pools, ClientType.SYNC, self._credential_provider
+            )
+        )
 
     def create_redis_node(self, host, port, **kwargs):
         if self.from_url:
@@ -1698,6 +1726,7 @@ class ClusterPubSub(PubSub):
         host=None,
         port=None,
         push_handler_func=None,
+        event_dispatcher: Optional["EventDispatcher"] = None,
         **kwargs,
     ):
         """
@@ -1723,10 +1752,15 @@ class ClusterPubSub(PubSub):
         self.cluster = redis_cluster
         self.node_pubsub_mapping = {}
         self._pubsubs_generator = self._pubsubs_generator()
+        if event_dispatcher is None:
+            self._event_dispatcher = EventDispatcher()
+        else:
+            self._event_dispatcher = event_dispatcher
         super().__init__(
             connection_pool=connection_pool,
             encoder=redis_cluster.encoder,
             push_handler_func=push_handler_func,
+            event_dispatcher=self._event_dispatcher,
             **kwargs,
         )
 
@@ -1813,6 +1847,11 @@ class ClusterPubSub(PubSub):
         self.connection.register_connect_callback(self.on_connect)
         if self.push_handler_func is not None and not HIREDIS_AVAILABLE:
             self.connection._parser.set_pubsub_push_handler(self.push_handler_func)
+        self._event_dispatcher.dispatch(
+            AfterPubSubConnectionInstantiationEvent(
+                self.connection, self.connection_pool, ClientType.SYNC, self._lock
+            )
+        )
         connection = self.connection
         self._execute(connection, connection.send_command, *args)
```
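The cluster paths mirror the sync client: `RedisCluster`, `NodesManager`, and `ClusterPubSub` all accept an `event_dispatcher` and share it downward. A hedged observability sketch that relies only on `dispatch()`, the one method these hunks call; `LoggingDispatcher` is hypothetical:

```python
from redis.cluster import RedisCluster
from redis.event import EventDispatcher

class LoggingDispatcher(EventDispatcher):
    def dispatch(self, event):
        # Trace every instantiation event before normal handling.
        print(f"event: {type(event).__name__}")
        super().dispatch(event)

rc = RedisCluster(host="localhost", port=16379,
                  event_dispatcher=LoggingDispatcher())
```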