redis 5.2.1__py3-none-any.whl → 5.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
redis/auth/token.py ADDED
@@ -0,0 +1,126 @@
1
+ from abc import ABC, abstractmethod
2
+ from datetime import datetime, timezone
3
+
4
+ import jwt
5
+ from redis.auth.err import InvalidTokenSchemaErr
6
+
7
+
8
class TokenInterface(ABC):
    """
    Abstract contract for an authentication token.

    Concrete tokens expose their expiry state, remaining lifetime, claim
    lookup, raw value, and the expiry/receipt timestamps.
    """

    @abstractmethod
    def is_expired(self) -> bool:
        """Report whether the token has expired."""

    @abstractmethod
    def ttl(self) -> float:
        """Return the remaining lifetime of the token."""

    @abstractmethod
    def try_get(self, key: str) -> str:
        """Look up the value stored under ``key`` in the token."""

    @abstractmethod
    def get_value(self) -> str:
        """Return the raw token value."""

    @abstractmethod
    def get_expires_at_ms(self) -> float:
        """Return the expiry timestamp in milliseconds."""

    @abstractmethod
    def get_received_at_ms(self) -> float:
        """Return the receipt timestamp in milliseconds."""
32
+
33
+
34
class TokenResponse:
    """Holds a token and exposes the lifetime derived from its timestamps."""

    def __init__(self, token: TokenInterface):
        self._token = token

    def get_token(self) -> TokenInterface:
        """Return the wrapped token."""
        return self._token

    def get_ttl_ms(self) -> float:
        """
        Return the token's expiry timestamp minus its receipt timestamp,
        both in milliseconds.
        """
        token = self._token
        expires_at_ms = token.get_expires_at_ms()
        received_at_ms = token.get_received_at_ms()
        return expires_at_ms - received_at_ms
43
+
44
+
45
class SimpleToken(TokenInterface):
    """
    Token built from explicitly supplied value, timestamps and claims.

    An ``expires_at_ms`` of ``-1`` marks a token that never expires.
    """

    def __init__(
        self, value: str, expires_at_ms: float, received_at_ms: float, claims: dict
    ) -> None:
        self.value = value
        self.expires_at = expires_at_ms
        self.received_at = received_at_ms
        self.claims = claims

    def ttl(self) -> float:
        """Return the milliseconds left until expiry, or -1 if it never expires."""
        if self.expires_at == -1:
            return -1

        now_ms = datetime.now(timezone.utc).timestamp() * 1000
        return self.expires_at - now_ms

    def is_expired(self) -> bool:
        """Return True once the remaining lifetime has dropped to zero or below."""
        # A never-expiring token (-1) short-circuits before ttl() is consulted.
        return self.expires_at != -1 and self.ttl() <= 0

    def try_get(self, key: str) -> str:
        """Return the claim stored under ``key`` (``None`` when absent)."""
        return self.claims.get(key)

    def get_value(self) -> str:
        """Return the raw token value."""
        return self.value

    def get_expires_at_ms(self) -> float:
        """Return the expiry timestamp supplied at construction."""
        return self.expires_at

    def get_received_at_ms(self) -> float:
        """Return the receipt timestamp supplied at construction."""
        return self.received_at
77
+
78
+
79
class JWToken(TokenInterface):
    """
    Token backed by a JSON Web Token string.

    The token is decoded WITHOUT signature verification; only the presence
    of the required claims is checked. An ``exp`` claim of ``-1`` is treated
    as "never expires" by ``is_expired()`` and ``ttl()``.
    """

    REQUIRED_FIELDS = {"exp"}

    def __init__(self, token: str):
        """
        Decode ``token`` and validate that the required claims are present.

        :param token: Encoded JWT string.
        :raises InvalidTokenSchemaErr: If a required claim is missing.
        """
        self._value = token
        self._decoded = jwt.decode(
            self._value,
            options={"verify_signature": False},
            algorithms=[jwt.get_unverified_header(self._value).get("alg")],
        )
        self._validate_token()
        # Capture the receipt time once so that get_received_at_ms() is
        # stable; recomputing "now" on every call made the reported receipt
        # time (and any lifetime derived from it) drift as the token aged.
        self._received_at_ms = datetime.now(timezone.utc).timestamp() * 1000

    def is_expired(self) -> bool:
        """Return True when the ``exp`` claim is at or before the current time."""
        exp = self._decoded["exp"]
        if exp == -1:
            return False

        return exp * 1000 <= datetime.now(timezone.utc).timestamp() * 1000

    def ttl(self) -> float:
        """Return the milliseconds left until expiry, or -1 if it never expires."""
        exp = self._decoded["exp"]
        if exp == -1:
            return -1

        return exp * 1000 - datetime.now(timezone.utc).timestamp() * 1000

    def try_get(self, key: str) -> str:
        """Return the decoded claim stored under ``key`` (``None`` when absent)."""
        return self._decoded.get(key)

    def get_value(self) -> str:
        """Return the encoded token string."""
        return self._value

    def get_expires_at_ms(self) -> float:
        """
        Return the ``exp`` claim converted from seconds to milliseconds.

        NOTE(review): the ``-1`` sentinel is not special-cased here, so a
        never-expiring token reports ``-1000.0`` - confirm callers expect that.
        """
        return float(self._decoded["exp"] * 1000)

    def get_received_at_ms(self) -> float:
        """Return the timestamp, in milliseconds, at which this token was created."""
        return self._received_at_ms

    def _validate_token(self):
        """Raise ``InvalidTokenSchemaErr`` listing any missing required claims."""
        missing_fields = self.REQUIRED_FIELDS.difference(self._decoded)

        if missing_fields:
            raise InvalidTokenSchemaErr(missing_fields)
@@ -0,0 +1,370 @@
1
+ import asyncio
2
+ import logging
3
+ import threading
4
+ from datetime import datetime, timezone
5
+ from time import sleep
6
+ from typing import Any, Awaitable, Callable, Union
7
+
8
+ from redis.auth.err import RequestTokenErr, TokenRenewalErr
9
+ from redis.auth.idp import IdentityProviderInterface
10
+ from redis.auth.token import TokenResponse
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
class CredentialsListener:
    """
    Holder for the callbacks notified about credential events.

    ``on_next`` and ``on_error`` both start as ``None`` and may be assigned
    plain callbacks or awaitable callbacks.
    """

    def __init__(self):
        self._on_next = None
        self._on_error = None

    @property
    def on_next(self) -> Union[Callable[[Any], None], Awaitable]:
        """Callback registered for new credentials, or ``None`` if unset."""
        return self._on_next

    @on_next.setter
    def on_next(self, callback: Union[Callable[[Any], None], Awaitable]) -> None:
        self._on_next = callback

    @property
    def on_error(self) -> Union[Callable[[Exception], None], Awaitable]:
        """Callback registered for errors, or ``None`` if unset."""
        return self._on_error

    @on_error.setter
    def on_error(self, callback: Union[Callable[[Exception], None], Awaitable]) -> None:
        self._on_error = callback
40
+
41
+
42
class RetryPolicy:
    """Retry settings: how many attempts to make and how long to wait between them."""

    def __init__(self, max_attempts: int, delay_in_ms: float):
        self.max_attempts, self.delay_in_ms = max_attempts, delay_in_ms

    def get_max_attempts(self) -> int:
        """
        Retry attempts before exception will be thrown.

        :return: int
        """
        attempts = self.max_attempts
        return attempts

    def get_delay_in_ms(self) -> float:
        """
        Delay between retries in milliseconds.

        :return: float
        """
        delay = self.delay_in_ms
        return delay
62
+
63
+
64
class TokenManagerConfig:
    """Immutable-by-convention bundle of settings read by the token manager."""

    def __init__(
        self,
        expiration_refresh_ratio: float,
        lower_refresh_bound_millis: int,
        token_request_execution_timeout_in_ms: int,
        retry_policy: RetryPolicy,
    ):
        self._retry_policy = retry_policy
        self._token_request_execution_timeout_in_ms = (
            token_request_execution_timeout_in_ms
        )
        self._lower_refresh_bound_millis = lower_refresh_bound_millis
        self._expiration_refresh_ratio = expiration_refresh_ratio

    def get_expiration_refresh_ratio(self) -> float:
        """
        Fraction of a token's lifetime after which a refresh is triggered.

        For example, 0.75 means the token is refreshed once 75% of its
        lifetime has elapsed (i.e. when 25% of it remains).

        :return: float
        """
        return self._expiration_refresh_ratio

    def get_lower_refresh_bound_millis(self) -> int:
        """
        Minimum time, in milliseconds, before token expiration at which a
        refresh is triggered.

        This is a fixed lower bound that applies regardless of the token's
        total lifetime. With a value of 0 there is no lower bound and the
        refresh is driven by the expiration refresh ratio only.

        :return: int
        """
        return self._lower_refresh_bound_millis

    def get_token_request_execution_timeout_in_ms(self) -> int:
        """
        Maximum time, in milliseconds, to wait for a token request to complete.

        :return: int
        """
        return self._token_request_execution_timeout_in_ms

    def get_retry_policy(self) -> RetryPolicy:
        """
        Retry policy applied to token requests.

        :return: RetryPolicy
        """
        return self._retry_policy
119
+
120
+
121
class TokenManager:
    """
    Requests tokens from an identity provider and schedules their renewal.

    Renewed tokens are handed to the listener's ``on_next`` callback and
    failures to its ``on_error`` callback. Renewal is scheduled on an asyncio
    event loop via ``loop.call_later``.
    """

    def __init__(
        self, identity_provider: IdentityProviderInterface, config: TokenManagerConfig
    ):
        self._idp = identity_provider
        self._config = config
        # Handle of the next scheduled renewal (sync or async path).
        self._next_timer = None
        self._listener = None
        # Handle of the initial renewal scheduled by start()/start_async().
        self._init_timer = None
        # Consecutive failed token requests in the current acquisition.
        self._retries = 0

    def __del__(self):
        logger.info("Token manager are disposed")
        self.stop()

    def start(
        self,
        listener: CredentialsListener,
        skip_initial: bool = False,
    ) -> Callable[[], None]:
        """
        Start renewing tokens and block until the first renewal has finished.

        When no event loop is running in the calling thread, a new loop is
        started in a daemon thread and used for scheduling.

        NOTE(review): when a loop IS running in the calling thread, the
        blocking ``.result()`` below waits on that same thread - confirm this
        method is never called from inside a running loop.

        :param listener: Receives renewed tokens and errors.
        :param skip_initial: When True, the first token is not passed to
            ``on_next``.
        :return: The ``stop`` method, for cancelling scheduled renewals.
        """
        self._listener = listener

        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            # Run loop in a separate thread to unblock main thread.
            loop = asyncio.new_event_loop()
            thread = threading.Thread(
                target=_start_event_loop_in_thread, args=(loop,), daemon=True
            )
            thread.start()

        # Event to block for initial execution.
        init_event = asyncio.Event()
        self._init_timer = loop.call_later(
            0, self._renew_token, skip_initial, init_event
        )
        logger.info("Token manager started")

        # Blocks in thread-safe manner.
        asyncio.run_coroutine_threadsafe(init_event.wait(), loop).result()
        return self.stop

    async def start_async(
        self,
        listener: CredentialsListener,
        block_for_initial: bool = False,
        initial_delay_in_ms: float = 0,
        skip_initial: bool = False,
    ) -> Callable[[], None]:
        """
        Start renewing tokens on the currently running event loop.

        :param listener: Receives renewed tokens and errors.
        :param block_for_initial: When True, wait for the first renewal to
            finish before returning.
        :param initial_delay_in_ms: Delay before the first renewal, in
            milliseconds.
        :param skip_initial: When True, the first token is not passed to
            ``on_next``.
        :return: The ``stop`` method, for cancelling scheduled renewals.
        """
        self._listener = listener

        loop = asyncio.get_running_loop()
        init_event = asyncio.Event()

        # Wraps the async callback with async wrapper to schedule with loop.call_later()
        wrapped = _async_to_sync_wrapper(
            loop, self._renew_token_async, skip_initial, init_event
        )
        self._init_timer = loop.call_later(initial_delay_in_ms / 1000, wrapped)
        logger.info("Token manager started")

        if block_for_initial:
            await init_event.wait()

        return self.stop

    def stop(self):
        """Cancel the initial and the next scheduled renewal, if any."""
        if self._init_timer is not None:
            self._init_timer.cancel()
        if self._next_timer is not None:
            self._next_timer.cancel()

    def acquire_token(self, force_refresh=False) -> TokenResponse:
        """
        Request a token, retrying on ``RequestTokenErr`` per the retry policy.

        :param force_refresh: Forwarded to the identity provider.
        :return: The acquired token wrapped in a ``TokenResponse``.
        :raises RequestTokenErr: When the request still fails after the
            configured number of retries.
        """
        retry_policy = self._config.get_retry_policy()

        while True:
            try:
                token = self._idp.request_token(force_refresh)
            except RequestTokenErr:
                if self._retries >= retry_policy.get_max_attempts():
                    # Reset so the next acquisition gets a full retry budget
                    # instead of failing on its very first error.
                    self._retries = 0
                    raise

                self._retries += 1
                sleep(retry_policy.get_delay_in_ms() / 1000)
            else:
                self._retries = 0
                return TokenResponse(token)

    async def acquire_token_async(self, force_refresh=False) -> TokenResponse:
        """
        Async variant of ``acquire_token``; waits between retries without
        blocking the event loop.

        :param force_refresh: Forwarded to the identity provider.
        :return: The acquired token wrapped in a ``TokenResponse``.
        :raises RequestTokenErr: When the request still fails after the
            configured number of retries.
        """
        retry_policy = self._config.get_retry_policy()

        while True:
            try:
                token = self._idp.request_token(force_refresh)
            except RequestTokenErr:
                if self._retries >= retry_policy.get_max_attempts():
                    # Reset so the next acquisition gets a full retry budget
                    # instead of failing on its very first error.
                    self._retries = 0
                    raise

                self._retries += 1
                await asyncio.sleep(retry_policy.get_delay_in_ms() / 1000)
            else:
                self._retries = 0
                return TokenResponse(token)

    def _calculate_renewal_delay(self, expire_date: float, issue_date: float) -> float:
        """
        Return the delay until the next renewal, in seconds.

        The earlier of the lower-bound and ratio based refresh points is
        used; a point that is already in the past yields 0.

        :param expire_date: Token expiry timestamp in milliseconds.
        :param issue_date: Token receipt timestamp in milliseconds.
        """
        delay_for_lower_refresh = self._delay_for_lower_refresh(expire_date)
        delay_for_ratio_refresh = self._delay_for_ratio_refresh(expire_date, issue_date)
        delay = min(delay_for_ratio_refresh, delay_for_lower_refresh)

        return 0 if delay < 0 else delay / 1000

    def _delay_for_lower_refresh(self, expire_date: float):
        """Milliseconds from now until the lower refresh bound before expiry."""
        return (
            expire_date
            - self._config.get_lower_refresh_bound_millis()
            - (datetime.now(timezone.utc).timestamp() * 1000)
        )

    def _delay_for_ratio_refresh(self, expire_date: float, issue_date: float):
        """Milliseconds from now until the refresh-ratio point of the lifetime."""
        token_ttl = expire_date - issue_date
        refresh_before = token_ttl - (
            token_ttl * self._config.get_expiration_refresh_ratio()
        )

        return (
            expire_date
            - refresh_before
            - (datetime.now(timezone.utc).timestamp() * 1000)
        )

    def _renew_token(
        self, skip_initial: bool = False, init_event: asyncio.Event = None
    ):
        """
        Task to renew token from identity provider.
        Schedules renewal tasks based on token TTL.

        Errors are passed to the listener's ``on_error`` callback when one is
        registered and re-raised otherwise. ``init_event`` (if given) is set
        on every exit path so a waiting ``start()`` is always released.
        """

        try:
            token_res = self.acquire_token(force_refresh=True)
            delay = self._calculate_renewal_delay(
                token_res.get_token().get_expires_at_ms(),
                token_res.get_token().get_received_at_ms(),
            )

            if token_res.get_token().is_expired():
                raise TokenRenewalErr("Requested token is expired")

            if self._listener.on_next is None:
                logger.warning(
                    "No registered callback for token renewal task. Renewal cancelled"
                )
                return

            if not skip_initial:
                try:
                    self._listener.on_next(token_res.get_token())
                except Exception as e:
                    raise TokenRenewalErr(e) from e

            if delay <= 0:
                return

            loop = asyncio.get_running_loop()
            self._next_timer = loop.call_later(delay, self._renew_token)
            logger.info(f"Next token renewal scheduled in {delay} seconds")
            return token_res
        except Exception as e:
            if self._listener.on_error is None:
                raise

            self._listener.on_error(e)
        finally:
            if init_event:
                init_event.set()

    async def _renew_token_async(
        self, skip_initial: bool = False, init_event: asyncio.Event = None
    ):
        """
        Async task to renew tokens from identity provider.
        Schedules renewal tasks based on token TTL.

        Errors are passed to the listener's ``on_error`` callback when one is
        registered and re-raised otherwise. ``init_event`` (if given) is set
        on every exit path so a waiting ``start_async()`` is always released.
        """

        try:
            token_res = await self.acquire_token_async(force_refresh=True)
            delay = self._calculate_renewal_delay(
                token_res.get_token().get_expires_at_ms(),
                token_res.get_token().get_received_at_ms(),
            )

            if token_res.get_token().is_expired():
                raise TokenRenewalErr("Requested token is expired")

            if self._listener.on_next is None:
                logger.warning(
                    "No registered callback for token renewal task. Renewal cancelled"
                )
                return

            if not skip_initial:
                try:
                    await self._listener.on_next(token_res.get_token())
                except Exception as e:
                    raise TokenRenewalErr(e) from e

            if delay <= 0:
                return

            loop = asyncio.get_running_loop()
            wrapped = _async_to_sync_wrapper(loop, self._renew_token_async)
            logger.info(f"Next token renewal scheduled in {delay} seconds")
            # Keep the handle so stop() can cancel the pending renewal, as
            # the synchronous path does.
            self._next_timer = loop.call_later(delay, wrapped)
        except Exception as e:
            if self._listener.on_error is None:
                raise

            await self._listener.on_error(e)
        finally:
            if init_event:
                init_event.set()
341
+
342
+
343
def _async_to_sync_wrapper(loop, coro_func, *args, **kwargs):
    """
    Build a plain callable that schedules a coroutine on ``loop``.

    ``loop.call_later`` only accepts regular callables, so the coroutine
    function is wrapped in one that creates the coroutine and hands it to
    the loop when invoked.

    :param loop: The event loop in which the coroutine will be executed.
    :param coro_func: The coroutine function to wrap.
    :param args: Positional arguments to pass to the coroutine function.
    :param kwargs: Keyword arguments to pass to the coroutine function.
    :return: A regular function suitable for loop.call_later.
    """

    def wrapped():
        # The coroutine object is created only when the callback fires.
        pending = coro_func(*args, **kwargs)
        asyncio.ensure_future(pending, loop=loop)

    return wrapped
359
+
360
+
361
def _start_event_loop_in_thread(event_loop: asyncio.AbstractEventLoop):
    """
    Thread target that runs ``event_loop`` until it is stopped.

    The loop is first installed as the current loop of the calling thread,
    then run forever so that tasks can be scheduled on it with
    loop.call_later.

    :param event_loop: The loop to install and run.
    :return: None, once the loop has been stopped.
    """
    asyncio.set_event_loop(event_loop)
    event_loop.run_forever()
redis/backoff.py CHANGED
@@ -110,5 +110,20 @@ class DecorrelatedJitterBackoff(AbstractBackoff):
110
110
  return self._previous_backoff
111
111
 
112
112
 
113
class ExponentialWithJitterBackoff(AbstractBackoff):
    """Exponential backoff upon failure, with jitter"""

    def __init__(self, cap: float = DEFAULT_CAP, base: float = DEFAULT_BASE) -> None:
        """
        `cap`: maximum backoff time in seconds
        `base`: base backoff time in seconds
        """
        self._cap = cap
        self._base = base

    def compute(self, failures: int) -> float:
        """
        Return a random fraction of ``base * 2**failures``, limited to ``cap``.
        """
        jittered = random.random() * self._base * 2**failures
        return min(self._cap, jittered)
126
+
127
+
113
128
  def default_backoff():
114
129
  return EqualJitterBackoff()