redis 5.2.1__py3-none-any.whl → 5.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
redis/connection.py CHANGED
@@ -9,7 +9,7 @@ from abc import abstractmethod
9
9
  from itertools import chain
10
10
  from queue import Empty, Full, LifoQueue
11
11
  from time import time
12
- from typing import Any, Callable, Dict, List, Optional, Type, Union
12
+ from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union
13
13
  from urllib.parse import parse_qs, unquote, urlparse
14
14
 
15
15
  from redis.cache import (
@@ -22,8 +22,10 @@ from redis.cache import (
22
22
  )
23
23
 
24
24
  from ._parsers import Encoder, _HiredisParser, _RESP2Parser, _RESP3Parser
25
+ from .auth.token import TokenInterface
25
26
  from .backoff import NoBackoff
26
27
  from .credentials import CredentialProvider, UsernamePasswordCredentialProvider
28
+ from .event import AfterConnectionReleasedEvent, EventDispatcher
27
29
  from .exceptions import (
28
30
  AuthenticationError,
29
31
  AuthenticationWrongNumberOfArgsError,
@@ -40,6 +42,7 @@ from .utils import (
40
42
  HIREDIS_AVAILABLE,
41
43
  SSL_AVAILABLE,
42
44
  compare_versions,
45
+ deprecated_args,
43
46
  ensure_string,
44
47
  format_error_message,
45
48
  get_lib_version,
@@ -151,6 +154,10 @@ class ConnectionInterface:
151
154
  def set_parser(self, parser_class):
152
155
  pass
153
156
 
157
+ @abstractmethod
158
+ def get_protocol(self):
159
+ pass
160
+
154
161
  @abstractmethod
155
162
  def connect(self):
156
163
  pass
@@ -202,6 +209,14 @@ class ConnectionInterface:
202
209
  def handshake_metadata(self) -> Union[Dict[bytes, bytes], Dict[str, str]]:
203
210
  pass
204
211
 
212
+ @abstractmethod
213
+ def set_re_auth_token(self, token: TokenInterface):
214
+ pass
215
+
216
+ @abstractmethod
217
+ def re_auth(self):
218
+ pass
219
+
205
220
 
206
221
  class AbstractConnection(ConnectionInterface):
207
222
  "Manages communication to and from a Redis server"
@@ -229,6 +244,7 @@ class AbstractConnection(ConnectionInterface):
229
244
  credential_provider: Optional[CredentialProvider] = None,
230
245
  protocol: Optional[int] = 2,
231
246
  command_packer: Optional[Callable[[], None]] = None,
247
+ event_dispatcher: Optional[EventDispatcher] = None,
232
248
  ):
233
249
  """
234
250
  Initialize a new Connection.
@@ -244,6 +260,10 @@ class AbstractConnection(ConnectionInterface):
244
260
  "1. 'password' and (optional) 'username'\n"
245
261
  "2. 'credential_provider'"
246
262
  )
263
+ if event_dispatcher is None:
264
+ self._event_dispatcher = EventDispatcher()
265
+ else:
266
+ self._event_dispatcher = event_dispatcher
247
267
  self.pid = os.getpid()
248
268
  self.db = db
249
269
  self.client_name = client_name
@@ -283,6 +303,7 @@ class AbstractConnection(ConnectionInterface):
283
303
  self.set_parser(parser_class)
284
304
  self._connect_callbacks = []
285
305
  self._buffer_cutoff = 6000
306
+ self._re_auth_token: Optional[TokenInterface] = None
286
307
  try:
287
308
  p = int(protocol)
288
309
  except TypeError:
@@ -663,6 +684,19 @@ class AbstractConnection(ConnectionInterface):
663
684
  def handshake_metadata(self, value: Union[Dict[bytes, bytes], Dict[str, str]]):
664
685
  self._handshake_metadata = value
665
686
 
687
+ def set_re_auth_token(self, token: TokenInterface):
688
+ self._re_auth_token = token
689
+
690
+ def re_auth(self):
691
+ if self._re_auth_token is not None:
692
+ self.send_command(
693
+ "AUTH",
694
+ self._re_auth_token.try_get("oid"),
695
+ self._re_auth_token.get_value(),
696
+ )
697
+ self.read_response()
698
+ self._re_auth_token = None
699
+
666
700
 
667
701
  class Connection(AbstractConnection):
668
702
  "Manages TCP communication to and from a Redis server"
@@ -750,6 +784,7 @@ class CacheProxyConnection(ConnectionInterface):
750
784
  self.retry = self._conn.retry
751
785
  self.host = self._conn.host
752
786
  self.port = self._conn.port
787
+ self.credential_provider = conn.credential_provider
753
788
  self._pool_lock = pool_lock
754
789
  self._cache = cache
755
790
  self._cache_lock = threading.Lock()
@@ -870,9 +905,11 @@ class CacheProxyConnection(ConnectionInterface):
870
905
  and self._cache.get(self._current_command_cache_key).status
871
906
  != CacheEntryStatus.IN_PROGRESS
872
907
  ):
873
- return copy.deepcopy(
908
+ res = copy.deepcopy(
874
909
  self._cache.get(self._current_command_cache_key).cache_value
875
910
  )
911
+ self._current_command_cache_key = None
912
+ return res
876
913
 
877
914
  response = self._conn.read_response(
878
915
  disable_decoding=disable_decoding,
@@ -898,6 +935,8 @@ class CacheProxyConnection(ConnectionInterface):
898
935
  cache_entry.cache_value = response
899
936
  self._cache.set(cache_entry)
900
937
 
938
+ self._current_command_cache_key = None
939
+
901
940
  return response
902
941
 
903
942
  def pack_command(self, *args):
@@ -933,6 +972,15 @@ class CacheProxyConnection(ConnectionInterface):
933
972
  else:
934
973
  self._cache.delete_by_redis_keys(data[1])
935
974
 
975
+ def get_protocol(self):
976
+ return self._conn.get_protocol()
977
+
978
+ def set_re_auth_token(self, token: TokenInterface):
979
+ self._conn.set_re_auth_token(token)
980
+
981
+ def re_auth(self):
982
+ self._conn.re_auth()
983
+
936
984
 
937
985
  class SSLConnection(Connection):
938
986
  """Manages SSL connections to and from the Redis server(s).
@@ -1216,6 +1264,9 @@ def parse_url(url):
1216
1264
  return kwargs
1217
1265
 
1218
1266
 
1267
+ _CP = TypeVar("_CP", bound="ConnectionPool")
1268
+
1269
+
1219
1270
  class ConnectionPool:
1220
1271
  """
1221
1272
  Create a connection pool. ``If max_connections`` is set, then this
@@ -1231,7 +1282,7 @@ class ConnectionPool:
1231
1282
  """
1232
1283
 
1233
1284
  @classmethod
1234
- def from_url(cls, url, **kwargs):
1285
+ def from_url(cls: Type[_CP], url: str, **kwargs) -> _CP:
1235
1286
  """
1236
1287
  Return a connection pool configured from the given URL.
1237
1288
 
@@ -1318,6 +1369,10 @@ class ConnectionPool:
1318
1369
  connection_kwargs.pop("cache", None)
1319
1370
  connection_kwargs.pop("cache_config", None)
1320
1371
 
1372
+ self._event_dispatcher = self.connection_kwargs.get("event_dispatcher", None)
1373
+ if self._event_dispatcher is None:
1374
+ self._event_dispatcher = EventDispatcher()
1375
+
1321
1376
  # a lock to protect the critical section in _checkpid().
1322
1377
  # this lock is acquired when the process id changes, such as
1323
1378
  # after a fork. during this time, multiple threads in the child
@@ -1327,6 +1382,7 @@ class ConnectionPool:
1327
1382
  # will notice the first thread already did the work and simply
1328
1383
  # release the lock.
1329
1384
  self._fork_lock = threading.Lock()
1385
+ self._lock = threading.Lock()
1330
1386
  self.reset()
1331
1387
 
1332
1388
  def __repr__(self) -> (str, str):
@@ -1344,7 +1400,6 @@ class ConnectionPool:
1344
1400
  return self.connection_kwargs.get("protocol", None)
1345
1401
 
1346
1402
  def reset(self) -> None:
1347
- self._lock = threading.Lock()
1348
1403
  self._created_connections = 0
1349
1404
  self._available_connections = []
1350
1405
  self._in_use_connections = set()
@@ -1407,8 +1462,14 @@ class ConnectionPool:
1407
1462
  finally:
1408
1463
  self._fork_lock.release()
1409
1464
 
1410
- def get_connection(self, command_name: str, *keys, **options) -> "Connection":
1465
+ @deprecated_args(
1466
+ args_to_warn=["*"],
1467
+ reason="Use get_connection() without args instead",
1468
+ version="5.3.0",
1469
+ )
1470
+ def get_connection(self, command_name=None, *keys, **options) -> "Connection":
1411
1471
  "Get a connection from the pool"
1472
+
1412
1473
  self._checkpid()
1413
1474
  with self._lock:
1414
1475
  try:
@@ -1471,15 +1532,18 @@ class ConnectionPool:
1471
1532
  except KeyError:
1472
1533
  # Gracefully fail when a connection is returned to this pool
1473
1534
  # that the pool doesn't actually own
1474
- pass
1535
+ return
1475
1536
 
1476
1537
  if self.owns_connection(connection):
1477
1538
  self._available_connections.append(connection)
1539
+ self._event_dispatcher.dispatch(
1540
+ AfterConnectionReleasedEvent(connection)
1541
+ )
1478
1542
  else:
1479
- # pool doesn't own this connection. do not add it back
1480
- # to the pool and decrement the count so that another
1481
- # connection can take its place if needed
1482
- self._created_connections -= 1
1543
+ # Pool doesn't own this connection, do not add it back
1544
+ # to the pool.
1545
+ # The created connections count should not be changed,
1546
+ # because the connection was not created by the pool.
1483
1547
  connection.disconnect()
1484
1548
  return
1485
1549
 
@@ -1517,6 +1581,29 @@ class ConnectionPool:
1517
1581
  for conn in self._in_use_connections:
1518
1582
  conn.retry = retry
1519
1583
 
1584
+ def re_auth_callback(self, token: TokenInterface):
1585
+ with self._lock:
1586
+ for conn in self._available_connections:
1587
+ conn.retry.call_with_retry(
1588
+ lambda: conn.send_command(
1589
+ "AUTH", token.try_get("oid"), token.get_value()
1590
+ ),
1591
+ lambda error: self._mock(error),
1592
+ )
1593
+ conn.retry.call_with_retry(
1594
+ lambda: conn.read_response(), lambda error: self._mock(error)
1595
+ )
1596
+ for conn in self._in_use_connections:
1597
+ conn.set_re_auth_token(token)
1598
+
1599
+ async def _mock(self, error: RedisError):
1600
+ """
1601
+ Dummy functions, needs to be passed as error callback to retry object.
1602
+ :param error:
1603
+ :return:
1604
+ """
1605
+ pass
1606
+
1520
1607
 
1521
1608
  class BlockingConnectionPool(ConnectionPool):
1522
1609
  """
@@ -1603,7 +1690,12 @@ class BlockingConnectionPool(ConnectionPool):
1603
1690
  self._connections.append(connection)
1604
1691
  return connection
1605
1692
 
1606
- def get_connection(self, command_name, *keys, **options):
1693
+ @deprecated_args(
1694
+ args_to_warn=["*"],
1695
+ reason="Use get_connection() without args instead",
1696
+ version="5.3.0",
1697
+ )
1698
+ def get_connection(self, command_name=None, *keys, **options):
1607
1699
  """
1608
1700
  Get a connection, blocking for ``self.timeout`` until a connection
1609
1701
  is available from the pool.
redis/credentials.py CHANGED
@@ -1,4 +1,8 @@
1
- from typing import Optional, Tuple, Union
1
+ import logging
2
+ from abc import ABC, abstractmethod
3
+ from typing import Any, Callable, Optional, Tuple, Union
4
+
5
+ logger = logging.getLogger(__name__)
2
6
 
3
7
 
4
8
  class CredentialProvider:
@@ -9,6 +13,38 @@ class CredentialProvider:
9
13
  def get_credentials(self) -> Union[Tuple[str], Tuple[str, str]]:
10
14
  raise NotImplementedError("get_credentials must be implemented")
11
15
 
16
+ async def get_credentials_async(self) -> Union[Tuple[str], Tuple[str, str]]:
17
+ logger.warning(
18
+ "This method is added for backward compatability. "
19
+ "Please override it in your implementation."
20
+ )
21
+ return self.get_credentials()
22
+
23
+
24
+ class StreamingCredentialProvider(CredentialProvider, ABC):
25
+ """
26
+ Credential provider that streams credentials in the background.
27
+ """
28
+
29
+ @abstractmethod
30
+ def on_next(self, callback: Callable[[Any], None]):
31
+ """
32
+ Specifies the callback that should be invoked
33
+ when the next credentials will be retrieved.
34
+
35
+ :param callback: Callback with
36
+ :return:
37
+ """
38
+ pass
39
+
40
+ @abstractmethod
41
+ def on_error(self, callback: Callable[[Exception], None]):
42
+ pass
43
+
44
+ @abstractmethod
45
+ def is_streaming(self) -> bool:
46
+ pass
47
+
12
48
 
13
49
  class UsernamePasswordCredentialProvider(CredentialProvider):
14
50
  """
@@ -24,3 +60,6 @@ class UsernamePasswordCredentialProvider(CredentialProvider):
24
60
  if self.username:
25
61
  return self.username, self.password
26
62
  return (self.password,)
63
+
64
+ async def get_credentials_async(self) -> Union[Tuple[str], Tuple[str, str]]:
65
+ return self.get_credentials()