redis 6.4.0__py3-none-any.whl → 7.0.0b2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
redis/connection.py CHANGED
@@ -8,7 +8,18 @@ import weakref
8
8
  from abc import abstractmethod
9
9
  from itertools import chain
10
10
  from queue import Empty, Full, LifoQueue
11
- from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union
11
+ from typing import (
12
+ Any,
13
+ Callable,
14
+ Dict,
15
+ Iterable,
16
+ List,
17
+ Literal,
18
+ Optional,
19
+ Type,
20
+ TypeVar,
21
+ Union,
22
+ )
12
23
  from urllib.parse import parse_qs, unquote, urlparse
13
24
 
14
25
  from redis.cache import (
@@ -36,6 +47,12 @@ from .exceptions import (
36
47
  ResponseError,
37
48
  TimeoutError,
38
49
  )
50
+ from .maint_notifications import (
51
+ MaintenanceState,
52
+ MaintNotificationsConfig,
53
+ MaintNotificationsConnectionHandler,
54
+ MaintNotificationsPoolHandler,
55
+ )
39
56
  from .retry import Retry
40
57
  from .utils import (
41
58
  CRYPTOGRAPHY_AVAILABLE,
@@ -51,8 +68,10 @@ from .utils import (
51
68
 
52
69
  if SSL_AVAILABLE:
53
70
  import ssl
71
+ from ssl import VerifyFlags
54
72
  else:
55
73
  ssl = None
74
+ VerifyFlags = None
56
75
 
57
76
  if HIREDIS_AVAILABLE:
58
77
  import hiredis
@@ -159,6 +178,10 @@ class ConnectionInterface:
159
178
  def set_parser(self, parser_class):
160
179
  pass
161
180
 
181
+ @abstractmethod
182
+ def set_maint_notifications_pool_handler(self, maint_notifications_pool_handler):
183
+ pass
184
+
162
185
  @abstractmethod
163
186
  def get_protocol(self):
164
187
  pass
@@ -222,6 +245,80 @@ class ConnectionInterface:
222
245
  def re_auth(self):
223
246
  pass
224
247
 
248
+ @property
249
+ @abstractmethod
250
+ def maintenance_state(self) -> MaintenanceState:
251
+ """
252
+ Returns the current maintenance state of the connection.
253
+ """
254
+ pass
255
+
256
+ @maintenance_state.setter
257
+ @abstractmethod
258
+ def maintenance_state(self, state: "MaintenanceState"):
259
+ """
260
+ Sets the current maintenance state of the connection.
261
+ """
262
+ pass
263
+
264
+ @abstractmethod
265
+ def getpeername(self):
266
+ """
267
+ Returns the peer name of the connection.
268
+ """
269
+ pass
270
+
271
+ @abstractmethod
272
+ def mark_for_reconnect(self):
273
+ """
274
+ Mark the connection to be reconnected on the next command.
275
+ This is useful when a connection is moved to a different node.
276
+ """
277
+ pass
278
+
279
+ @abstractmethod
280
+ def should_reconnect(self):
281
+ """
282
+ Returns True if the connection should be reconnected.
283
+ """
284
+ pass
285
+
286
+ @abstractmethod
287
+ def get_resolved_ip(self):
288
+ """
289
+ Get resolved ip address for the connection.
290
+ """
291
+ pass
292
+
293
+ @abstractmethod
294
+ def update_current_socket_timeout(self, relaxed_timeout: Optional[float] = None):
295
+ """
296
+ Update the timeout for the current socket.
297
+ """
298
+ pass
299
+
300
+ @abstractmethod
301
+ def set_tmp_settings(
302
+ self,
303
+ tmp_host_address: Optional[str] = None,
304
+ tmp_relaxed_timeout: Optional[float] = None,
305
+ ):
306
+ """
307
+ Updates temporary host address and timeout settings for the connection.
308
+ """
309
+ pass
310
+
311
+ @abstractmethod
312
+ def reset_tmp_settings(
313
+ self,
314
+ reset_host_address: bool = False,
315
+ reset_relaxed_timeout: bool = False,
316
+ ):
317
+ """
318
+ Resets temporary host address and timeout settings for the connection.
319
+ """
320
+ pass
321
+
225
322
 
226
323
  class AbstractConnection(ConnectionInterface):
227
324
  "Manages communication to and from a Redis server"
@@ -233,7 +330,7 @@ class AbstractConnection(ConnectionInterface):
233
330
  socket_timeout: Optional[float] = None,
234
331
  socket_connect_timeout: Optional[float] = None,
235
332
  retry_on_timeout: bool = False,
236
- retry_on_error=SENTINEL,
333
+ retry_on_error: Union[Iterable[Type[Exception]], object] = SENTINEL,
237
334
  encoding: str = "utf-8",
238
335
  encoding_errors: str = "strict",
239
336
  decode_responses: bool = False,
@@ -250,6 +347,15 @@ class AbstractConnection(ConnectionInterface):
250
347
  protocol: Optional[int] = 2,
251
348
  command_packer: Optional[Callable[[], None]] = None,
252
349
  event_dispatcher: Optional[EventDispatcher] = None,
350
+ maint_notifications_pool_handler: Optional[
351
+ MaintNotificationsPoolHandler
352
+ ] = None,
353
+ maint_notifications_config: Optional[MaintNotificationsConfig] = None,
354
+ maintenance_state: "MaintenanceState" = MaintenanceState.NONE,
355
+ maintenance_notification_hash: Optional[int] = None,
356
+ orig_host_address: Optional[str] = None,
357
+ orig_socket_timeout: Optional[float] = None,
358
+ orig_socket_connect_timeout: Optional[float] = None,
253
359
  ):
254
360
  """
255
361
  Initialize a new Connection.
@@ -283,19 +389,22 @@ class AbstractConnection(ConnectionInterface):
283
389
  self.socket_connect_timeout = socket_connect_timeout
284
390
  self.retry_on_timeout = retry_on_timeout
285
391
  if retry_on_error is SENTINEL:
286
- retry_on_error = []
392
+ retry_on_errors_list = []
393
+ else:
394
+ retry_on_errors_list = list(retry_on_error)
287
395
  if retry_on_timeout:
288
396
  # Add TimeoutError to the errors list to retry on
289
- retry_on_error.append(TimeoutError)
290
- self.retry_on_error = retry_on_error
291
- if retry or retry_on_error:
397
+ retry_on_errors_list.append(TimeoutError)
398
+ self.retry_on_error = retry_on_errors_list
399
+ if retry or self.retry_on_error:
292
400
  if retry is None:
293
401
  self.retry = Retry(NoBackoff(), 1)
294
402
  else:
295
403
  # deep-copy the Retry object as it is mutable
296
404
  self.retry = copy.deepcopy(retry)
297
- # Update the retry's supported errors with the specified errors
298
- self.retry.update_supported_errors(retry_on_error)
405
+ if self.retry_on_error:
406
+ # Update the retry's supported errors with the specified errors
407
+ self.retry.update_supported_errors(self.retry_on_error)
299
408
  else:
300
409
  self.retry = Retry(NoBackoff(), 0)
301
410
  self.health_check_interval = health_check_interval
@@ -305,7 +414,6 @@ class AbstractConnection(ConnectionInterface):
305
414
  self.handshake_metadata = None
306
415
  self._sock = None
307
416
  self._socket_read_size = socket_read_size
308
- self.set_parser(parser_class)
309
417
  self._connect_callbacks = []
310
418
  self._buffer_cutoff = 6000
311
419
  self._re_auth_token: Optional[TokenInterface] = None
@@ -320,6 +428,24 @@ class AbstractConnection(ConnectionInterface):
320
428
  raise ConnectionError("protocol must be either 2 or 3")
321
429
  # p = DEFAULT_RESP_VERSION
322
430
  self.protocol = p
431
+ if self.protocol == 3 and parser_class == DefaultParser:
432
+ parser_class = _RESP3Parser
433
+ self.set_parser(parser_class)
434
+
435
+ self.maint_notifications_config = maint_notifications_config
436
+
437
+ # Set up maintenance notifications if enabled
438
+ self._configure_maintenance_notifications(
439
+ maint_notifications_pool_handler,
440
+ orig_host_address,
441
+ orig_socket_timeout,
442
+ orig_socket_connect_timeout,
443
+ )
444
+
445
+ self._should_reconnect = False
446
+ self.maintenance_state = maintenance_state
447
+ self.maintenance_notification_hash = maintenance_notification_hash
448
+
323
449
  self._command_packer = self._construct_command_packer(command_packer)
324
450
 
325
451
  def __repr__(self):
@@ -375,6 +501,69 @@ class AbstractConnection(ConnectionInterface):
375
501
  """
376
502
  self._parser = parser_class(socket_read_size=self._socket_read_size)
377
503
 
504
+ def _configure_maintenance_notifications(
505
+ self,
506
+ maint_notifications_pool_handler=None,
507
+ orig_host_address=None,
508
+ orig_socket_timeout=None,
509
+ orig_socket_connect_timeout=None,
510
+ ):
511
+ """Enable maintenance notifications by setting up handlers and storing original connection parameters."""
512
+ if (
513
+ not self.maint_notifications_config
514
+ or not self.maint_notifications_config.enabled
515
+ ):
516
+ self._maint_notifications_connection_handler = None
517
+ return
518
+
519
+ # Set up pool handler if available
520
+ if maint_notifications_pool_handler:
521
+ self._parser.set_node_moving_push_handler(
522
+ maint_notifications_pool_handler.handle_notification
523
+ )
524
+
525
+ # Set up connection handler
526
+ self._maint_notifications_connection_handler = (
527
+ MaintNotificationsConnectionHandler(self, self.maint_notifications_config)
528
+ )
529
+ self._parser.set_maintenance_push_handler(
530
+ self._maint_notifications_connection_handler.handle_notification
531
+ )
532
+
533
+ # Store original connection parameters
534
+ self.orig_host_address = orig_host_address if orig_host_address else self.host
535
+ self.orig_socket_timeout = (
536
+ orig_socket_timeout if orig_socket_timeout else self.socket_timeout
537
+ )
538
+ self.orig_socket_connect_timeout = (
539
+ orig_socket_connect_timeout
540
+ if orig_socket_connect_timeout
541
+ else self.socket_connect_timeout
542
+ )
543
+
544
+ def set_maint_notifications_pool_handler(
545
+ self, maint_notifications_pool_handler: MaintNotificationsPoolHandler
546
+ ):
547
+ maint_notifications_pool_handler.set_connection(self)
548
+ self._parser.set_node_moving_push_handler(
549
+ maint_notifications_pool_handler.handle_notification
550
+ )
551
+
552
+ # Update maintenance notification connection handler if it doesn't exist
553
+ if not self._maint_notifications_connection_handler:
554
+ self._maint_notifications_connection_handler = (
555
+ MaintNotificationsConnectionHandler(
556
+ self, maint_notifications_pool_handler.config
557
+ )
558
+ )
559
+ self._parser.set_maintenance_push_handler(
560
+ self._maint_notifications_connection_handler.handle_notification
561
+ )
562
+ else:
563
+ self._maint_notifications_connection_handler.config = (
564
+ maint_notifications_pool_handler.config
565
+ )
566
+
378
567
  def connect(self):
379
568
  "Connects to the Redis server if not already connected"
380
569
  self.connect_check_health(check_health=True)
@@ -499,6 +688,49 @@ class AbstractConnection(ConnectionInterface):
499
688
  ):
500
689
  raise ConnectionError("Invalid RESP version")
501
690
 
691
+ # Send maintenance notifications handshake if RESP3 is active
692
+ # and maintenance notifications are enabled
693
+ # and we have a host to determine the endpoint type from
694
+ # When the maint_notifications_config enabled mode is "auto",
695
+ # we just log a warning if the handshake fails
696
+ # When the mode is enabled=True, we raise an exception in case of failure
697
+ if (
698
+ self.protocol not in [2, "2"]
699
+ and self.maint_notifications_config
700
+ and self.maint_notifications_config.enabled
701
+ and self._maint_notifications_connection_handler
702
+ and hasattr(self, "host")
703
+ ):
704
+ try:
705
+ endpoint_type = self.maint_notifications_config.get_endpoint_type(
706
+ self.host, self
707
+ )
708
+ self.send_command(
709
+ "CLIENT",
710
+ "MAINT_NOTIFICATIONS",
711
+ "ON",
712
+ "moving-endpoint-type",
713
+ endpoint_type.value,
714
+ check_health=check_health,
715
+ )
716
+ response = self.read_response()
717
+ if str_if_bytes(response) != "OK":
718
+ raise ResponseError(
719
+ "The server doesn't support maintenance notifications"
720
+ )
721
+ except Exception as e:
722
+ if (
723
+ isinstance(e, ResponseError)
724
+ and self.maint_notifications_config.enabled == "auto"
725
+ ):
726
+ # Log warning but don't fail the connection
727
+ import logging
728
+
729
+ logger = logging.getLogger(__name__)
730
+ logger.warning(f"Failed to enable maintenance notifications: {e}")
731
+ else:
732
+ raise
733
+
502
734
  # if a client_name is given, set it
503
735
  if self.client_name:
504
736
  self.send_command(
@@ -549,6 +781,8 @@ class AbstractConnection(ConnectionInterface):
549
781
 
550
782
  conn_sock = self._sock
551
783
  self._sock = None
784
+ # reset the reconnect flag
785
+ self._should_reconnect = False
552
786
  if conn_sock is None:
553
787
  return
554
788
 
@@ -626,6 +860,7 @@ class AbstractConnection(ConnectionInterface):
626
860
 
627
861
  try:
628
862
  return self._parser.can_read(timeout)
863
+
629
864
  except OSError as e:
630
865
  self.disconnect()
631
866
  raise ConnectionError(f"Error while reading from {host_error}: {e.args}")
@@ -732,6 +967,110 @@ class AbstractConnection(ConnectionInterface):
732
967
  self.read_response()
733
968
  self._re_auth_token = None
734
969
 
970
+ def get_resolved_ip(self) -> Optional[str]:
971
+ """
972
+ Extract the resolved IP address from an
973
+ established connection or resolve it from the host.
974
+
975
+ First tries to get the actual IP from the socket (most accurate),
976
+ then falls back to DNS resolution if needed.
977
+
978
+ Args:
979
+ connection: The connection object to extract the IP from
980
+
981
+ Returns:
982
+ str: The resolved IP address, or None if it cannot be determined
983
+ """
984
+
985
+ # Method 1: Try to get the actual IP from the established socket connection
986
+ # This is most accurate as it shows the exact IP being used
987
+ try:
988
+ if self._sock is not None:
989
+ peer_addr = self._sock.getpeername()
990
+ if peer_addr and len(peer_addr) >= 1:
991
+ # For TCP sockets, peer_addr is typically (host, port) tuple
992
+ # Return just the host part
993
+ return peer_addr[0]
994
+ except (AttributeError, OSError):
995
+ # Socket might not be connected or getpeername() might fail
996
+ pass
997
+
998
+ # Method 2: Fallback to DNS resolution of the host
999
+ # This is less accurate but works when socket is not available
1000
+ try:
1001
+ host = getattr(self, "host", "localhost")
1002
+ port = getattr(self, "port", 6379)
1003
+ if host:
1004
+ # Use getaddrinfo to resolve the hostname to IP
1005
+ # This mimics what the connection would do during _connect()
1006
+ addr_info = socket.getaddrinfo(
1007
+ host, port, socket.AF_UNSPEC, socket.SOCK_STREAM
1008
+ )
1009
+ if addr_info:
1010
+ # Return the IP from the first result
1011
+ # addr_info[0] is (family, socktype, proto, canonname, sockaddr)
1012
+ # sockaddr[0] is the IP address
1013
+ return addr_info[0][4][0]
1014
+ except (AttributeError, OSError, socket.gaierror):
1015
+ # DNS resolution might fail
1016
+ pass
1017
+
1018
+ return None
1019
+
1020
+ @property
1021
+ def maintenance_state(self) -> MaintenanceState:
1022
+ return self._maintenance_state
1023
+
1024
+ @maintenance_state.setter
1025
+ def maintenance_state(self, state: "MaintenanceState"):
1026
+ self._maintenance_state = state
1027
+
1028
+ def getpeername(self):
1029
+ if not self._sock:
1030
+ return None
1031
+ return self._sock.getpeername()[0]
1032
+
1033
+ def mark_for_reconnect(self):
1034
+ self._should_reconnect = True
1035
+
1036
+ def should_reconnect(self):
1037
+ return self._should_reconnect
1038
+
1039
+ def update_current_socket_timeout(self, relaxed_timeout: Optional[float] = None):
1040
+ if self._sock:
1041
+ timeout = relaxed_timeout if relaxed_timeout != -1 else self.socket_timeout
1042
+ self._sock.settimeout(timeout)
1043
+ self.update_parser_buffer_timeout(timeout)
1044
+
1045
+ def update_parser_buffer_timeout(self, timeout: Optional[float] = None):
1046
+ if self._parser and self._parser._buffer:
1047
+ self._parser._buffer.socket_timeout = timeout
1048
+
1049
+ def set_tmp_settings(
1050
+ self,
1051
+ tmp_host_address: Optional[Union[str, object]] = SENTINEL,
1052
+ tmp_relaxed_timeout: Optional[float] = None,
1053
+ ):
1054
+ """
1055
+ The value of SENTINEL is used to indicate that the property should not be updated.
1056
+ """
1057
+ if tmp_host_address is not SENTINEL:
1058
+ self.host = tmp_host_address
1059
+ if tmp_relaxed_timeout != -1:
1060
+ self.socket_timeout = tmp_relaxed_timeout
1061
+ self.socket_connect_timeout = tmp_relaxed_timeout
1062
+
1063
+ def reset_tmp_settings(
1064
+ self,
1065
+ reset_host_address: bool = False,
1066
+ reset_relaxed_timeout: bool = False,
1067
+ ):
1068
+ if reset_host_address:
1069
+ self.host = self.orig_host_address
1070
+ if reset_relaxed_timeout:
1071
+ self.socket_timeout = self.orig_socket_timeout
1072
+ self.socket_connect_timeout = self.orig_socket_connect_timeout
1073
+
735
1074
 
736
1075
  class Connection(AbstractConnection):
737
1076
  "Manages TCP communication to and from a Redis server"
@@ -764,6 +1103,7 @@ class Connection(AbstractConnection):
764
1103
  # ipv4/ipv6, but we want to set options prior to calling
765
1104
  # socket.connect()
766
1105
  err = None
1106
+
767
1107
  for res in socket.getaddrinfo(
768
1108
  self.host, self.port, self.socket_type, socket.SOCK_STREAM
769
1109
  ):
@@ -1032,6 +1372,8 @@ class SSLConnection(Connection):
1032
1372
  ssl_keyfile=None,
1033
1373
  ssl_certfile=None,
1034
1374
  ssl_cert_reqs="required",
1375
+ ssl_include_verify_flags: Optional[List["VerifyFlags"]] = None,
1376
+ ssl_exclude_verify_flags: Optional[List["VerifyFlags"]] = None,
1035
1377
  ssl_ca_certs=None,
1036
1378
  ssl_ca_data=None,
1037
1379
  ssl_check_hostname=True,
@@ -1050,10 +1392,13 @@ class SSLConnection(Connection):
1050
1392
  Args:
1051
1393
  ssl_keyfile: Path to an ssl private key. Defaults to None.
1052
1394
  ssl_certfile: Path to an ssl certificate. Defaults to None.
1053
- ssl_cert_reqs: The string value for the SSLContext.verify_mode (none, optional, required), or an ssl.VerifyMode. Defaults to "required".
1395
+ ssl_cert_reqs: The string value for the SSLContext.verify_mode (none, optional, required),
1396
+ or an ssl.VerifyMode. Defaults to "required".
1397
+ ssl_include_verify_flags: A list of flags to be included in the SSLContext.verify_flags. Defaults to None.
1398
+ ssl_exclude_verify_flags: A list of flags to be excluded from the SSLContext.verify_flags. Defaults to None.
1054
1399
  ssl_ca_certs: The path to a file of concatenated CA certificates in PEM format. Defaults to None.
1055
1400
  ssl_ca_data: Either an ASCII string of one or more PEM-encoded certificates or a bytes-like object of DER-encoded certificates.
1056
- ssl_check_hostname: If set, match the hostname during the SSL handshake. Defaults to False.
1401
+ ssl_check_hostname: If set, match the hostname during the SSL handshake. Defaults to True.
1057
1402
  ssl_ca_path: The path to a directory containing several CA certificates in PEM format. Defaults to None.
1058
1403
  ssl_password: Password for unlocking an encrypted private key. Defaults to None.
1059
1404
 
@@ -1086,6 +1431,8 @@ class SSLConnection(Connection):
1086
1431
  )
1087
1432
  ssl_cert_reqs = CERT_REQS[ssl_cert_reqs]
1088
1433
  self.cert_reqs = ssl_cert_reqs
1434
+ self.ssl_include_verify_flags = ssl_include_verify_flags
1435
+ self.ssl_exclude_verify_flags = ssl_exclude_verify_flags
1089
1436
  self.ca_certs = ssl_ca_certs
1090
1437
  self.ca_data = ssl_ca_data
1091
1438
  self.ca_path = ssl_ca_path
@@ -1125,6 +1472,12 @@ class SSLConnection(Connection):
1125
1472
  context = ssl.create_default_context()
1126
1473
  context.check_hostname = self.check_hostname
1127
1474
  context.verify_mode = self.cert_reqs
1475
+ if self.ssl_include_verify_flags:
1476
+ for flag in self.ssl_include_verify_flags:
1477
+ context.verify_flags |= flag
1478
+ if self.ssl_exclude_verify_flags:
1479
+ for flag in self.ssl_exclude_verify_flags:
1480
+ context.verify_flags &= ~flag
1128
1481
  if self.certfile or self.keyfile:
1129
1482
  context.load_cert_chain(
1130
1483
  certfile=self.certfile,
@@ -1238,6 +1591,20 @@ def to_bool(value):
1238
1591
  return bool(value)
1239
1592
 
1240
1593
 
1594
+ def parse_ssl_verify_flags(value):
1595
+ # flags are passed in as a string representation of a list,
1596
+ # e.g. VERIFY_X509_STRICT, VERIFY_X509_PARTIAL_CHAIN
1597
+ verify_flags_str = value.replace("[", "").replace("]", "")
1598
+
1599
+ verify_flags = []
1600
+ for flag in verify_flags_str.split(","):
1601
+ flag = flag.strip()
1602
+ if not hasattr(VerifyFlags, flag):
1603
+ raise ValueError(f"Invalid ssl verify flag: {flag}")
1604
+ verify_flags.append(getattr(VerifyFlags, flag))
1605
+ return verify_flags
1606
+
1607
+
1241
1608
  URL_QUERY_ARGUMENT_PARSERS = {
1242
1609
  "db": int,
1243
1610
  "socket_timeout": float,
@@ -1248,6 +1615,8 @@ URL_QUERY_ARGUMENT_PARSERS = {
1248
1615
  "max_connections": int,
1249
1616
  "health_check_interval": int,
1250
1617
  "ssl_check_hostname": to_bool,
1618
+ "ssl_include_verify_flags": parse_ssl_verify_flags,
1619
+ "ssl_exclude_verify_flags": parse_ssl_verify_flags,
1251
1620
  "timeout": float,
1252
1621
  }
1253
1622
 
@@ -1394,7 +1763,7 @@ class ConnectionPool:
1394
1763
  self._cache_factory = cache_factory
1395
1764
 
1396
1765
  if connection_kwargs.get("cache_config") or connection_kwargs.get("cache"):
1397
- if connection_kwargs.get("protocol") not in [3, "3"]:
1766
+ if self.connection_kwargs.get("protocol") not in [3, "3"]:
1398
1767
  raise RedisError("Client caching is only supported with RESP version 3")
1399
1768
 
1400
1769
  cache = self.connection_kwargs.get("cache")
@@ -1415,6 +1784,22 @@ class ConnectionPool:
1415
1784
  connection_kwargs.pop("cache", None)
1416
1785
  connection_kwargs.pop("cache_config", None)
1417
1786
 
1787
+ if self.connection_kwargs.get(
1788
+ "maint_notifications_pool_handler"
1789
+ ) or self.connection_kwargs.get("maint_notifications_config"):
1790
+ if self.connection_kwargs.get("protocol") not in [3, "3"]:
1791
+ raise RedisError(
1792
+ "Push handlers on connection are only supported with RESP version 3"
1793
+ )
1794
+ config = self.connection_kwargs.get("maint_notifications_config", None) or (
1795
+ self.connection_kwargs.get("maint_notifications_pool_handler").config
1796
+ if self.connection_kwargs.get("maint_notifications_pool_handler")
1797
+ else None
1798
+ )
1799
+
1800
+ if config and config.enabled:
1801
+ self._update_connection_kwargs_for_maint_notifications()
1802
+
1418
1803
  self._event_dispatcher = self.connection_kwargs.get("event_dispatcher", None)
1419
1804
  if self._event_dispatcher is None:
1420
1805
  self._event_dispatcher = EventDispatcher()
@@ -1449,6 +1834,69 @@ class ConnectionPool:
1449
1834
  """
1450
1835
  return self.connection_kwargs.get("protocol", None)
1451
1836
 
1837
+ def maint_notifications_pool_handler_enabled(self):
1838
+ """
1839
+ Returns:
1840
+ True if the maintenance notifications pool handler is enabled, False otherwise.
1841
+ """
1842
+ maint_notifications_config = self.connection_kwargs.get(
1843
+ "maint_notifications_config", None
1844
+ )
1845
+
1846
+ return maint_notifications_config and maint_notifications_config.enabled
1847
+
1848
+ def set_maint_notifications_pool_handler(
1849
+ self, maint_notifications_pool_handler: MaintNotificationsPoolHandler
1850
+ ):
1851
+ self.connection_kwargs.update(
1852
+ {
1853
+ "maint_notifications_pool_handler": maint_notifications_pool_handler,
1854
+ "maint_notifications_config": maint_notifications_pool_handler.config,
1855
+ }
1856
+ )
1857
+ self._update_connection_kwargs_for_maint_notifications()
1858
+
1859
+ self._update_maint_notifications_configs_for_connections(
1860
+ maint_notifications_pool_handler
1861
+ )
1862
+
1863
+ def _update_maint_notifications_configs_for_connections(
1864
+ self, maint_notifications_pool_handler
1865
+ ):
1866
+ """Update the maintenance notifications config for all connections in the pool."""
1867
+ with self._lock:
1868
+ for conn in self._available_connections:
1869
+ conn.set_maint_notifications_pool_handler(
1870
+ maint_notifications_pool_handler
1871
+ )
1872
+ conn.maint_notifications_config = (
1873
+ maint_notifications_pool_handler.config
1874
+ )
1875
+ for conn in self._in_use_connections:
1876
+ conn.set_maint_notifications_pool_handler(
1877
+ maint_notifications_pool_handler
1878
+ )
1879
+ conn.maint_notifications_config = (
1880
+ maint_notifications_pool_handler.config
1881
+ )
1882
+
1883
+ def _update_connection_kwargs_for_maint_notifications(self):
1884
+ """Store original connection parameters for maintenance notifications."""
1885
+ if self.connection_kwargs.get("orig_host_address", None) is None:
1886
+ # If orig_host_address is None it means we haven't
1887
+ # configured the original values yet
1888
+ self.connection_kwargs.update(
1889
+ {
1890
+ "orig_host_address": self.connection_kwargs.get("host"),
1891
+ "orig_socket_timeout": self.connection_kwargs.get(
1892
+ "socket_timeout", None
1893
+ ),
1894
+ "orig_socket_connect_timeout": self.connection_kwargs.get(
1895
+ "socket_connect_timeout", None
1896
+ ),
1897
+ }
1898
+ )
1899
+
1452
1900
  def reset(self) -> None:
1453
1901
  self._created_connections = 0
1454
1902
  self._available_connections = []
@@ -1536,7 +1984,11 @@ class ConnectionPool:
1536
1984
  # pool before all data has been read or the socket has been
1537
1985
  # closed. either way, reconnect and verify everything is good.
1538
1986
  try:
1539
- if connection.can_read() and self.cache is None:
1987
+ if (
1988
+ connection.can_read()
1989
+ and self.cache is None
1990
+ and not self.maint_notifications_pool_handler_enabled()
1991
+ ):
1540
1992
  raise ConnectionError("Connection has data")
1541
1993
  except (ConnectionError, TimeoutError, OSError):
1542
1994
  connection.disconnect()
@@ -1548,7 +2000,6 @@ class ConnectionPool:
1548
2000
  # leak it
1549
2001
  self.release(connection)
1550
2002
  raise
1551
-
1552
2003
  return connection
1553
2004
 
1554
2005
  def get_encoder(self) -> Encoder:
@@ -1566,12 +2017,13 @@ class ConnectionPool:
1566
2017
  raise MaxConnectionsError("Too many connections")
1567
2018
  self._created_connections += 1
1568
2019
 
2020
+ kwargs = dict(self.connection_kwargs)
2021
+
1569
2022
  if self.cache is not None:
1570
2023
  return CacheProxyConnection(
1571
- self.connection_class(**self.connection_kwargs), self.cache, self._lock
2024
+ self.connection_class(**kwargs), self.cache, self._lock
1572
2025
  )
1573
-
1574
- return self.connection_class(**self.connection_kwargs)
2026
+ return self.connection_class(**kwargs)
1575
2027
 
1576
2028
  def release(self, connection: "Connection") -> None:
1577
2029
  "Releases the connection back to the pool"
@@ -1585,6 +2037,8 @@ class ConnectionPool:
1585
2037
  return
1586
2038
 
1587
2039
  if self.owns_connection(connection):
2040
+ if connection.should_reconnect():
2041
+ connection.disconnect()
1588
2042
  self._available_connections.append(connection)
1589
2043
  self._event_dispatcher.dispatch(
1590
2044
  AfterConnectionReleasedEvent(connection)
@@ -1646,6 +2100,186 @@ class ConnectionPool:
1646
2100
  for conn in self._in_use_connections:
1647
2101
  conn.set_re_auth_token(token)
1648
2102
 
2103
+ def _should_update_connection(
2104
+ self,
2105
+ conn: "Connection",
2106
+ matching_pattern: Literal[
2107
+ "connected_address", "configured_address", "notification_hash"
2108
+ ] = "connected_address",
2109
+ matching_address: Optional[str] = None,
2110
+ matching_notification_hash: Optional[int] = None,
2111
+ ) -> bool:
2112
+ """
2113
+ Check if the connection should be updated based on the matching criteria.
2114
+ """
2115
+ if matching_pattern == "connected_address":
2116
+ if matching_address and conn.getpeername() != matching_address:
2117
+ return False
2118
+ elif matching_pattern == "configured_address":
2119
+ if matching_address and conn.host != matching_address:
2120
+ return False
2121
+ elif matching_pattern == "notification_hash":
2122
+ if (
2123
+ matching_notification_hash
2124
+ and conn.maintenance_notification_hash != matching_notification_hash
2125
+ ):
2126
+ return False
2127
+ return True
2128
+
2129
+ def update_connection_settings(
2130
+ self,
2131
+ conn: "Connection",
2132
+ state: Optional["MaintenanceState"] = None,
2133
+ maintenance_notification_hash: Optional[int] = None,
2134
+ host_address: Optional[str] = None,
2135
+ relaxed_timeout: Optional[float] = None,
2136
+ update_notification_hash: bool = False,
2137
+ reset_host_address: bool = False,
2138
+ reset_relaxed_timeout: bool = False,
2139
+ ):
2140
+ """
2141
+ Update the settings for a single connection.
2142
+ """
2143
+ if state:
2144
+ conn.maintenance_state = state
2145
+
2146
+ if update_notification_hash:
2147
+ # update the notification hash only if requested
2148
+ conn.maintenance_notification_hash = maintenance_notification_hash
2149
+
2150
+ if host_address is not None:
2151
+ conn.set_tmp_settings(tmp_host_address=host_address)
2152
+
2153
+ if relaxed_timeout is not None:
2154
+ conn.set_tmp_settings(tmp_relaxed_timeout=relaxed_timeout)
2155
+
2156
+ if reset_relaxed_timeout or reset_host_address:
2157
+ conn.reset_tmp_settings(
2158
+ reset_host_address=reset_host_address,
2159
+ reset_relaxed_timeout=reset_relaxed_timeout,
2160
+ )
2161
+
2162
+ conn.update_current_socket_timeout(relaxed_timeout)
2163
+
2164
+ def update_connections_settings(
2165
+ self,
2166
+ state: Optional["MaintenanceState"] = None,
2167
+ maintenance_notification_hash: Optional[int] = None,
2168
+ host_address: Optional[str] = None,
2169
+ relaxed_timeout: Optional[float] = None,
2170
+ matching_address: Optional[str] = None,
2171
+ matching_notification_hash: Optional[int] = None,
2172
+ matching_pattern: Literal[
2173
+ "connected_address", "configured_address", "notification_hash"
2174
+ ] = "connected_address",
2175
+ update_notification_hash: bool = False,
2176
+ reset_host_address: bool = False,
2177
+ reset_relaxed_timeout: bool = False,
2178
+ include_free_connections: bool = True,
2179
+ ):
2180
+ """
2181
+ Update the settings for all matching connections in the pool.
2182
+
2183
+ This method does not create new connections.
2184
+ This method does not affect the connection kwargs.
2185
+
2186
+ :param state: The maintenance state to set for the connection.
2187
+ :param maintenance_notification_hash: The hash of the maintenance notification
2188
+ to set for the connection.
2189
+ :param host_address: The host address to set for the connection.
2190
+ :param relaxed_timeout: The relaxed timeout to set for the connection.
2191
+ :param matching_address: The address to match for the connection.
2192
+ :param matching_notification_hash: The notification hash to match for the connection.
2193
+ :param matching_pattern: The pattern to match for the connection.
2194
+ :param update_notification_hash: Whether to update the notification hash for the connection.
2195
+ :param reset_host_address: Whether to reset the host address to the original address.
2196
+ :param reset_relaxed_timeout: Whether to reset the relaxed timeout to the original timeout.
2197
+ :param include_free_connections: Whether to include free/available connections.
2198
+ """
2199
+ with self._lock:
2200
+ for conn in self._in_use_connections:
2201
+ if self._should_update_connection(
2202
+ conn,
2203
+ matching_pattern,
2204
+ matching_address,
2205
+ matching_notification_hash,
2206
+ ):
2207
+ self.update_connection_settings(
2208
+ conn,
2209
+ state=state,
2210
+ maintenance_notification_hash=maintenance_notification_hash,
2211
+ host_address=host_address,
2212
+ relaxed_timeout=relaxed_timeout,
2213
+ update_notification_hash=update_notification_hash,
2214
+ reset_host_address=reset_host_address,
2215
+ reset_relaxed_timeout=reset_relaxed_timeout,
2216
+ )
2217
+
2218
+ if include_free_connections:
2219
+ for conn in self._available_connections:
2220
+ if self._should_update_connection(
2221
+ conn,
2222
+ matching_pattern,
2223
+ matching_address,
2224
+ matching_notification_hash,
2225
+ ):
2226
+ self.update_connection_settings(
2227
+ conn,
2228
+ state=state,
2229
+ maintenance_notification_hash=maintenance_notification_hash,
2230
+ host_address=host_address,
2231
+ relaxed_timeout=relaxed_timeout,
2232
+ update_notification_hash=update_notification_hash,
2233
+ reset_host_address=reset_host_address,
2234
+ reset_relaxed_timeout=reset_relaxed_timeout,
2235
+ )
2236
+
2237
+ def update_connection_kwargs(
2238
+ self,
2239
+ **kwargs,
2240
+ ):
2241
+ """
2242
+ Update the connection kwargs for all future connections.
2243
+
2244
+ This method updates the connection kwargs for all future connections created by the pool.
2245
+ Existing connections are not affected.
2246
+ """
2247
+ self.connection_kwargs.update(kwargs)
2248
+
2249
+ def update_active_connections_for_reconnect(
2250
+ self,
2251
+ moving_address_src: Optional[str] = None,
2252
+ ):
2253
+ """
2254
+ Mark all active connections for reconnect.
2255
+ This is used when a cluster node is migrated to a different address.
2256
+
2257
+ :param moving_address_src: The address of the node that is being moved.
2258
+ """
2259
+ with self._lock:
2260
+ for conn in self._in_use_connections:
2261
+ if self._should_update_connection(
2262
+ conn, "connected_address", moving_address_src
2263
+ ):
2264
+ conn.mark_for_reconnect()
2265
+
2266
+ def disconnect_free_connections(
2267
+ self,
2268
+ moving_address_src: Optional[str] = None,
2269
+ ):
2270
+ """
2271
+ Disconnect all free/available connections.
2272
+ This is used when a cluster node is migrated to a different address.
2273
+
2274
+ :param moving_address_src: The address of the node that is being moved.
2275
+ """
2276
+ with self._lock:
2277
+ for conn in self._available_connections:
2278
+ if self._should_update_connection(
2279
+ conn, "connected_address", moving_address_src
2280
+ ):
2281
+ conn.disconnect()
2282
+
1649
2283
  async def _mock(self, error: RedisError):
1650
2284
  """
1651
2285
  Dummy functions, needs to be passed as error callback to retry object.
@@ -1699,6 +2333,8 @@ class BlockingConnectionPool(ConnectionPool):
1699
2333
  ):
1700
2334
  self.queue_class = queue_class
1701
2335
  self.timeout = timeout
2336
+ self._in_maintenance = False
2337
+ self._locked = False
1702
2338
  super().__init__(
1703
2339
  connection_class=connection_class,
1704
2340
  max_connections=max_connections,
@@ -1707,16 +2343,27 @@ class BlockingConnectionPool(ConnectionPool):
1707
2343
 
1708
2344
  def reset(self):
1709
2345
  # Create and fill up a thread safe queue with ``None`` values.
1710
- self.pool = self.queue_class(self.max_connections)
1711
- while True:
1712
- try:
1713
- self.pool.put_nowait(None)
1714
- except Full:
1715
- break
2346
+ try:
2347
+ if self._in_maintenance:
2348
+ self._lock.acquire()
2349
+ self._locked = True
2350
+ self.pool = self.queue_class(self.max_connections)
2351
+ while True:
2352
+ try:
2353
+ self.pool.put_nowait(None)
2354
+ except Full:
2355
+ break
1716
2356
 
1717
- # Keep a list of actual connection instances so that we can
1718
- # disconnect them later.
1719
- self._connections = []
2357
+ # Keep a list of actual connection instances so that we can
2358
+ # disconnect them later.
2359
+ self._connections = []
2360
+ finally:
2361
+ if self._locked:
2362
+ try:
2363
+ self._lock.release()
2364
+ except Exception:
2365
+ pass
2366
+ self._locked = False
1720
2367
 
1721
2368
  # this must be the last operation in this method. while reset() is
1722
2369
  # called when holding _fork_lock, other threads in this process
@@ -1731,14 +2378,28 @@ class BlockingConnectionPool(ConnectionPool):
1731
2378
 
1732
2379
def make_connection(self):
    """
    Make a fresh connection and register it with the pool.

    The connection is built from ``self.connection_class`` and
    ``self.connection_kwargs``. When ``self.cache`` is set, it is wrapped in
    a ``CacheProxyConnection`` together with the cache and ``self._lock``.
    The result is appended to ``self._connections`` and returned.

    While ``self._in_maintenance`` is set, the whole operation runs with
    ``self._lock`` held.

    :return: The newly created connection.
    """
    # Lock ownership is tracked in a local, not in the shared
    # ``self._locked`` attribute. This method is called from
    # ``get_connection`` while that caller already holds the lock; clearing
    # a shared flag here would make the caller skip its own release and
    # leave the lock held. NOTE(review): the nested acquisition assumes
    # ``self._lock`` is reentrant - confirm against the pool constructor.
    lock_acquired = False
    if self._in_maintenance:
        self._lock.acquire()
        lock_acquired = True
    try:
        if self.cache is not None:
            connection = CacheProxyConnection(
                self.connection_class(**self.connection_kwargs),
                self.cache,
                self._lock,
            )
        else:
            connection = self.connection_class(**self.connection_kwargs)
        self._connections.append(connection)
        return connection
    finally:
        # Release only what this call acquired; the owner's release cannot
        # fail, so nothing needs to be swallowed here.
        if lock_acquired:
            self._lock.release()
1742
2403
 
1743
2404
  @deprecated_args(
1744
2405
  args_to_warn=["*"],
@@ -1764,16 +2425,27 @@ class BlockingConnectionPool(ConnectionPool):
1764
2425
  # self.timeout then raise a ``ConnectionError``.
1765
2426
  connection = None
1766
2427
  try:
1767
- connection = self.pool.get(block=True, timeout=self.timeout)
1768
- except Empty:
1769
- # Note that this is not caught by the redis client and will be
1770
- # raised unless handled by application code. If you want never to
1771
- raise ConnectionError("No connection available.")
1772
-
1773
- # If the ``connection`` is actually ``None`` then that's a cue to make
1774
- # a new connection to add to the pool.
1775
- if connection is None:
1776
- connection = self.make_connection()
2428
+ if self._in_maintenance:
2429
+ self._lock.acquire()
2430
+ self._locked = True
2431
+ try:
2432
+ connection = self.pool.get(block=True, timeout=self.timeout)
2433
+ except Empty:
2434
+ # Note that this is not caught by the redis client and will be
2435
+ # raised unless handled by application code.
2436
+ raise ConnectionError("No connection available.")
2437
+
2438
+ # If the ``connection`` is actually ``None`` then that's a cue to make
2439
+ # a new connection to add to the pool.
2440
+ if connection is None:
2441
+ connection = self.make_connection()
2442
+ finally:
2443
+ if self._locked:
2444
+ try:
2445
+ self._lock.release()
2446
+ except Exception:
2447
+ pass
2448
+ self._locked = False
1777
2449
 
1778
2450
  try:
1779
2451
  # ensure this connection is connected to Redis
@@ -1801,25 +2473,177 @@ class BlockingConnectionPool(ConnectionPool):
1801
2473
  "Releases the connection back to the pool."
1802
2474
  # Make sure we haven't changed process.
1803
2475
  self._checkpid()
1804
- if not self.owns_connection(connection):
1805
- # pool doesn't own this connection. do not add it back
1806
- # to the pool. instead add a None value which is a placeholder
1807
- # that will cause the pool to recreate the connection if
1808
- # its needed.
1809
- connection.disconnect()
1810
- self.pool.put_nowait(None)
1811
- return
1812
2476
 
1813
- # Put the connection back into the pool.
1814
2477
  try:
1815
- self.pool.put_nowait(connection)
1816
- except Full:
1817
- # perhaps the pool has been reset() after a fork? regardless,
1818
- # we don't want this connection
1819
- pass
2478
+ if self._in_maintenance:
2479
+ self._lock.acquire()
2480
+ self._locked = True
2481
+ if not self.owns_connection(connection):
2482
+ # pool doesn't own this connection. do not add it back
2483
+ # to the pool. instead add a None value which is a placeholder
2484
+ # that will cause the pool to recreate the connection if
2485
+ # it's needed.
2486
+ connection.disconnect()
2487
+ self.pool.put_nowait(None)
2488
+ return
2489
+ if connection.should_reconnect():
2490
+ connection.disconnect()
2491
+ # Put the connection back into the pool.
2492
+ try:
2493
+ self.pool.put_nowait(connection)
2494
+ except Full:
2495
+ # perhaps the pool has been reset() after a fork? regardless,
2496
+ # we don't want this connection
2497
+ pass
2498
+ finally:
2499
+ if self._locked:
2500
+ try:
2501
+ self._lock.release()
2502
+ except Exception:
2503
+ pass
2504
+ self._locked = False
1820
2505
 
1821
2506
def disconnect(self):
    """
    Disconnects all connections in the pool.

    Calls ``self._checkpid()`` first, then ``disconnect()`` on every
    connection in ``self._connections``. While ``self._in_maintenance`` is
    set, the loop runs with ``self._lock`` held.
    """
    self._checkpid()
    # Lock ownership is tracked in a local, not in the shared
    # ``self._locked`` attribute. With a shared flag, a call that never took
    # the lock could see the flag set by another thread, attempt a release
    # it does not own and clear the flag, after which the real owner would
    # never release the lock.
    lock_acquired = False
    if self._in_maintenance:
        self._lock.acquire()
        lock_acquired = True
    try:
        for connection in self._connections:
            connection.disconnect()
    finally:
        if lock_acquired:
            self._lock.release()
2522
+
2523
def update_connections_settings(
    self,
    state: Optional["MaintenanceState"] = None,
    maintenance_notification_hash: Optional[int] = None,
    relaxed_timeout: Optional[float] = None,
    host_address: Optional[str] = None,
    matching_address: Optional[str] = None,
    matching_notification_hash: Optional[int] = None,
    matching_pattern: Literal[
        "connected_address", "configured_address", "notification_hash"
    ] = "connected_address",
    update_notification_hash: bool = False,
    reset_host_address: bool = False,
    reset_relaxed_timeout: bool = False,
    include_free_connections: bool = True,
):
    """
    Apply settings to the matching connections of this pool.

    Override of the base class method that works with
    BlockingConnectionPool's structure: the ``self._connections`` list plus
    the ``self.pool`` queue. The whole operation runs with ``self._lock``
    held.

    :param state: Forwarded unchanged to ``update_connection_settings``.
    :param maintenance_notification_hash: Forwarded unchanged to
        ``update_connection_settings``.
    :param relaxed_timeout: Forwarded unchanged to
        ``update_connection_settings``.
    :param host_address: Forwarded unchanged to
        ``update_connection_settings``.
    :param matching_address: Passed to ``_should_update_connection`` to
        select connections.
    :param matching_notification_hash: Passed to
        ``_should_update_connection`` to select connections.
    :param matching_pattern: Passed to ``_should_update_connection`` to
        select connections.
    :param update_notification_hash: Forwarded unchanged to
        ``update_connection_settings``.
    :param reset_host_address: Forwarded unchanged to
        ``update_connection_settings``.
    :param reset_relaxed_timeout: Forwarded unchanged to
        ``update_connection_settings``.
    :param include_free_connections: When true, every connection in
        ``self._connections`` is considered. When false, connections
        currently sitting in the ``self.pool`` queue are skipped.
    """
    with self._lock:
        if include_free_connections:
            candidates = tuple(self._connections)
        else:
            # The queue holds ``None`` placeholders next to real
            # connections; only real ones count as free.
            queued = {conn for conn in self.pool.queue if conn is not None}
            candidates = tuple(
                conn for conn in self._connections if conn not in queued
            )

        # Both modes share one loop over a snapshot, so the selection and
        # update logic lives in exactly one place.
        for conn in candidates:
            if not self._should_update_connection(
                conn,
                matching_pattern,
                matching_address,
                matching_notification_hash,
            ):
                continue
            self.update_connection_settings(
                conn,
                state=state,
                maintenance_notification_hash=maintenance_notification_hash,
                host_address=host_address,
                relaxed_timeout=relaxed_timeout,
                update_notification_hash=update_notification_hash,
                reset_host_address=reset_host_address,
                reset_relaxed_timeout=reset_relaxed_timeout,
            )
2581
+
2582
def update_active_connections_for_reconnect(
    self,
    moving_address_src: Optional[str] = None,
):
    """
    Flag the pool's active connections so that they reconnect.

    A connection counts as active here when it is listed in
    ``self._connections`` but is not currently sitting in the ``self.pool``
    queue. Every active connection accepted by
    ``self._should_update_connection`` for the ``"connected_address"``
    pattern gets ``mark_for_reconnect()`` called on it. Runs with
    ``self._lock`` held. Used when a cluster node is migrated to a
    different address.

    :param moving_address_src: The address of the node that is being moved.
    """
    with self._lock:
        # Collect the connections that are parked in the queue (falsy
        # placeholder entries are ignored).
        parked = set()
        for queued_item in self.pool.queue:
            if queued_item:
                parked.add(queued_item)

        for candidate in self._connections:
            if candidate in parked:
                continue
            is_match = self._should_update_connection(
                candidate,
                matching_pattern="connected_address",
                matching_address=moving_address_src,
            )
            if is_match:
                candidate.mark_for_reconnect()
2602
+
2603
def disconnect_free_connections(
    self,
    moving_address_src: Optional[str] = None,
):
    """
    Disconnect the free connections that match an address.

    Free connections are the truthy entries of the ``self.pool`` queue.
    Each one accepted by ``self._should_update_connection`` for the
    ``"connected_address"`` pattern gets ``disconnect()`` called on it.
    Runs with ``self._lock`` held. Used when a cluster node is migrated to
    a different address.

    :param moving_address_src: The address of the node that is being moved.
    """
    with self._lock:
        for queued_item in self.pool.queue:
            # Falsy entries are empty slots, not connections.
            if not queued_item:
                continue
            is_match = self._should_update_connection(
                queued_item, "connected_address", moving_address_src
            )
            if is_match:
                queued_item.disconnect()
2622
+
2623
+ def _update_maint_notifications_config_for_connections(
2624
+ self, maint_notifications_config
2625
+ ):
2626
+ for conn in tuple(self._connections):
2627
+ conn.maint_notifications_config = maint_notifications_config
2628
+
2629
+ def _update_maint_notifications_configs_for_connections(
2630
+ self, maint_notifications_pool_handler
2631
+ ):
2632
+ """Update the maintenance notifications config for all connections in the pool."""
2633
+ with self._lock:
2634
+ for conn in tuple(self._connections):
2635
+ conn.set_maint_notifications_pool_handler(
2636
+ maint_notifications_pool_handler
2637
+ )
2638
+ conn.maint_notifications_config = (
2639
+ maint_notifications_pool_handler.config
2640
+ )
2641
+
2642
def set_in_maintenance(self, in_maintenance: bool):
    """
    Switch this Blocking ConnectionPool's maintenance-mode flag on or off.

    The value is stored in ``self._in_maintenance``. Per the original
    authors, the flag is used to prevent new connections from being created
    while in maintenance mode, and the pool is in maintenance mode only
    while a MOVING notification is being processed.

    :param in_maintenance: The new value of the maintenance-mode flag.
    """
    self._in_maintenance = in_maintenance