redis 6.3.0__py3-none-any.whl → 7.0.0b1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
redis/connection.py CHANGED
@@ -8,7 +8,18 @@ import weakref
8
8
  from abc import abstractmethod
9
9
  from itertools import chain
10
10
  from queue import Empty, Full, LifoQueue
11
- from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union
11
+ from typing import (
12
+ Any,
13
+ Callable,
14
+ Dict,
15
+ Iterable,
16
+ List,
17
+ Literal,
18
+ Optional,
19
+ Type,
20
+ TypeVar,
21
+ Union,
22
+ )
12
23
  from urllib.parse import parse_qs, unquote, urlparse
13
24
 
14
25
  from redis.cache import (
@@ -36,6 +47,12 @@ from .exceptions import (
36
47
  ResponseError,
37
48
  TimeoutError,
38
49
  )
50
+ from .maintenance_events import (
51
+ MaintenanceEventConnectionHandler,
52
+ MaintenanceEventPoolHandler,
53
+ MaintenanceEventsConfig,
54
+ MaintenanceState,
55
+ )
39
56
  from .retry import Retry
40
57
  from .utils import (
41
58
  CRYPTOGRAPHY_AVAILABLE,
@@ -159,6 +176,10 @@ class ConnectionInterface:
159
176
  def set_parser(self, parser_class):
160
177
  pass
161
178
 
179
+ @abstractmethod
180
+ def set_maintenance_event_pool_handler(self, maintenance_event_pool_handler):
181
+ pass
182
+
162
183
  @abstractmethod
163
184
  def get_protocol(self):
164
185
  pass
@@ -222,6 +243,80 @@ class ConnectionInterface:
222
243
  def re_auth(self):
223
244
  pass
224
245
 
246
+ @property
247
+ @abstractmethod
248
+ def maintenance_state(self) -> MaintenanceState:
249
+ """
250
+ Returns the current maintenance state of the connection.
251
+ """
252
+ pass
253
+
254
+ @maintenance_state.setter
255
+ @abstractmethod
256
+ def maintenance_state(self, state: "MaintenanceState"):
257
+ """
258
+ Sets the current maintenance state of the connection.
259
+ """
260
+ pass
261
+
262
+ @abstractmethod
263
+ def getpeername(self):
264
+ """
265
+ Returns the peer name of the connection.
266
+ """
267
+ pass
268
+
269
+ @abstractmethod
270
+ def mark_for_reconnect(self):
271
+ """
272
+ Mark the connection to be reconnected on the next command.
273
+ This is useful when a connection is moved to a different node.
274
+ """
275
+ pass
276
+
277
+ @abstractmethod
278
+ def should_reconnect(self):
279
+ """
280
+ Returns True if the connection should be reconnected.
281
+ """
282
+ pass
283
+
284
+ @abstractmethod
285
+ def get_resolved_ip(self):
286
+ """
287
+ Get resolved ip address for the connection.
288
+ """
289
+ pass
290
+
291
+ @abstractmethod
292
+ def update_current_socket_timeout(self, relax_timeout: Optional[float] = None):
293
+ """
294
+ Update the timeout for the current socket.
295
+ """
296
+ pass
297
+
298
+ @abstractmethod
299
+ def set_tmp_settings(
300
+ self,
301
+ tmp_host_address: Optional[str] = None,
302
+ tmp_relax_timeout: Optional[float] = None,
303
+ ):
304
+ """
305
+ Updates temporary host address and timeout settings for the connection.
306
+ """
307
+ pass
308
+
309
+ @abstractmethod
310
+ def reset_tmp_settings(
311
+ self,
312
+ reset_host_address: bool = False,
313
+ reset_relax_timeout: bool = False,
314
+ ):
315
+ """
316
+ Resets temporary host address and timeout settings for the connection.
317
+ """
318
+ pass
319
+
225
320
 
226
321
  class AbstractConnection(ConnectionInterface):
227
322
  "Manages communication to and from a Redis server"
@@ -233,7 +328,7 @@ class AbstractConnection(ConnectionInterface):
233
328
  socket_timeout: Optional[float] = None,
234
329
  socket_connect_timeout: Optional[float] = None,
235
330
  retry_on_timeout: bool = False,
236
- retry_on_error=SENTINEL,
331
+ retry_on_error: Union[Iterable[Type[Exception]], object] = SENTINEL,
237
332
  encoding: str = "utf-8",
238
333
  encoding_errors: str = "strict",
239
334
  decode_responses: bool = False,
@@ -250,6 +345,13 @@ class AbstractConnection(ConnectionInterface):
250
345
  protocol: Optional[int] = 2,
251
346
  command_packer: Optional[Callable[[], None]] = None,
252
347
  event_dispatcher: Optional[EventDispatcher] = None,
348
+ maintenance_events_pool_handler: Optional[MaintenanceEventPoolHandler] = None,
349
+ maintenance_events_config: Optional[MaintenanceEventsConfig] = None,
350
+ maintenance_state: "MaintenanceState" = MaintenanceState.NONE,
351
+ maintenance_event_hash: Optional[int] = None,
352
+ orig_host_address: Optional[str] = None,
353
+ orig_socket_timeout: Optional[float] = None,
354
+ orig_socket_connect_timeout: Optional[float] = None,
253
355
  ):
254
356
  """
255
357
  Initialize a new Connection.
@@ -283,19 +385,22 @@ class AbstractConnection(ConnectionInterface):
283
385
  self.socket_connect_timeout = socket_connect_timeout
284
386
  self.retry_on_timeout = retry_on_timeout
285
387
  if retry_on_error is SENTINEL:
286
- retry_on_error = []
388
+ retry_on_errors_list = []
389
+ else:
390
+ retry_on_errors_list = list(retry_on_error)
287
391
  if retry_on_timeout:
288
392
  # Add TimeoutError to the errors list to retry on
289
- retry_on_error.append(TimeoutError)
290
- self.retry_on_error = retry_on_error
291
- if retry or retry_on_error:
393
+ retry_on_errors_list.append(TimeoutError)
394
+ self.retry_on_error = retry_on_errors_list
395
+ if retry or self.retry_on_error:
292
396
  if retry is None:
293
397
  self.retry = Retry(NoBackoff(), 1)
294
398
  else:
295
399
  # deep-copy the Retry object as it is mutable
296
400
  self.retry = copy.deepcopy(retry)
297
- # Update the retry's supported errors with the specified errors
298
- self.retry.update_supported_errors(retry_on_error)
401
+ if self.retry_on_error:
402
+ # Update the retry's supported errors with the specified errors
403
+ self.retry.update_supported_errors(self.retry_on_error)
299
404
  else:
300
405
  self.retry = Retry(NoBackoff(), 0)
301
406
  self.health_check_interval = health_check_interval
@@ -305,7 +410,6 @@ class AbstractConnection(ConnectionInterface):
305
410
  self.handshake_metadata = None
306
411
  self._sock = None
307
412
  self._socket_read_size = socket_read_size
308
- self.set_parser(parser_class)
309
413
  self._connect_callbacks = []
310
414
  self._buffer_cutoff = 6000
311
415
  self._re_auth_token: Optional[TokenInterface] = None
@@ -320,6 +424,24 @@ class AbstractConnection(ConnectionInterface):
320
424
  raise ConnectionError("protocol must be either 2 or 3")
321
425
  # p = DEFAULT_RESP_VERSION
322
426
  self.protocol = p
427
+ if self.protocol == 3 and parser_class == DefaultParser:
428
+ parser_class = _RESP3Parser
429
+ self.set_parser(parser_class)
430
+
431
+ self.maintenance_events_config = maintenance_events_config
432
+
433
+ # Set up maintenance events if enabled
434
+ self._configure_maintenance_events(
435
+ maintenance_events_pool_handler,
436
+ orig_host_address,
437
+ orig_socket_timeout,
438
+ orig_socket_connect_timeout,
439
+ )
440
+
441
+ self._should_reconnect = False
442
+ self.maintenance_state = maintenance_state
443
+ self.maintenance_event_hash = maintenance_event_hash
444
+
323
445
  self._command_packer = self._construct_command_packer(command_packer)
324
446
 
325
447
  def __repr__(self):
@@ -375,6 +497,69 @@ class AbstractConnection(ConnectionInterface):
375
497
  """
376
498
  self._parser = parser_class(socket_read_size=self._socket_read_size)
377
499
 
500
+ def _configure_maintenance_events(
501
+ self,
502
+ maintenance_events_pool_handler=None,
503
+ orig_host_address=None,
504
+ orig_socket_timeout=None,
505
+ orig_socket_connect_timeout=None,
506
+ ):
507
+ """Enable maintenance events by setting up handlers and storing original connection parameters."""
508
+ if (
509
+ not self.maintenance_events_config
510
+ or not self.maintenance_events_config.enabled
511
+ ):
512
+ self._maintenance_event_connection_handler = None
513
+ return
514
+
515
+ # Set up pool handler if available
516
+ if maintenance_events_pool_handler:
517
+ self._parser.set_node_moving_push_handler(
518
+ maintenance_events_pool_handler.handle_event
519
+ )
520
+
521
+ # Set up connection handler
522
+ self._maintenance_event_connection_handler = MaintenanceEventConnectionHandler(
523
+ self, self.maintenance_events_config
524
+ )
525
+ self._parser.set_maintenance_push_handler(
526
+ self._maintenance_event_connection_handler.handle_event
527
+ )
528
+
529
+ # Store original connection parameters
530
+ self.orig_host_address = orig_host_address if orig_host_address else self.host
531
+ self.orig_socket_timeout = (
532
+ orig_socket_timeout if orig_socket_timeout else self.socket_timeout
533
+ )
534
+ self.orig_socket_connect_timeout = (
535
+ orig_socket_connect_timeout
536
+ if orig_socket_connect_timeout
537
+ else self.socket_connect_timeout
538
+ )
539
+
540
+ def set_maintenance_event_pool_handler(
541
+ self, maintenance_event_pool_handler: MaintenanceEventPoolHandler
542
+ ):
543
+ maintenance_event_pool_handler.set_connection(self)
544
+ self._parser.set_node_moving_push_handler(
545
+ maintenance_event_pool_handler.handle_event
546
+ )
547
+
548
+ # Update maintenance event connection handler if it doesn't exist
549
+ if not self._maintenance_event_connection_handler:
550
+ self._maintenance_event_connection_handler = (
551
+ MaintenanceEventConnectionHandler(
552
+ self, maintenance_event_pool_handler.config
553
+ )
554
+ )
555
+ self._parser.set_maintenance_push_handler(
556
+ self._maintenance_event_connection_handler.handle_event
557
+ )
558
+ else:
559
+ self._maintenance_event_connection_handler.config = (
560
+ maintenance_event_pool_handler.config
561
+ )
562
+
378
563
  def connect(self):
379
564
  "Connects to the Redis server if not already connected"
380
565
  self.connect_check_health(check_health=True)
@@ -499,6 +684,39 @@ class AbstractConnection(ConnectionInterface):
499
684
  ):
500
685
  raise ConnectionError("Invalid RESP version")
501
686
 
687
+ # Send maintenance notifications handshake if RESP3 is active and maintenance events are enabled
688
+ # and we have a host to determine the endpoint type from
689
+ if (
690
+ self.protocol not in [2, "2"]
691
+ and self.maintenance_events_config
692
+ and self.maintenance_events_config.enabled
693
+ and self._maintenance_event_connection_handler
694
+ and hasattr(self, "host")
695
+ ):
696
+ try:
697
+ endpoint_type = self.maintenance_events_config.get_endpoint_type(
698
+ self.host, self
699
+ )
700
+ self.send_command(
701
+ "CLIENT",
702
+ "MAINT_NOTIFICATIONS",
703
+ "ON",
704
+ "moving-endpoint-type",
705
+ endpoint_type.value,
706
+ check_health=check_health,
707
+ )
708
+ response = self.read_response()
709
+ if str_if_bytes(response) != "OK":
710
+ raise ConnectionError(
711
+ "The server doesn't support maintenance notifications"
712
+ )
713
+ except Exception as e:
714
+ # Log warning but don't fail the connection
715
+ import logging
716
+
717
+ logger = logging.getLogger(__name__)
718
+ logger.warning(f"Failed to enable maintenance notifications: {e}")
719
+
502
720
  # if a client_name is given, set it
503
721
  if self.client_name:
504
722
  self.send_command(
@@ -549,6 +767,8 @@ class AbstractConnection(ConnectionInterface):
549
767
 
550
768
  conn_sock = self._sock
551
769
  self._sock = None
770
+ # reset the reconnect flag
771
+ self._should_reconnect = False
552
772
  if conn_sock is None:
553
773
  return
554
774
 
@@ -626,6 +846,7 @@ class AbstractConnection(ConnectionInterface):
626
846
 
627
847
  try:
628
848
  return self._parser.can_read(timeout)
849
+
629
850
  except OSError as e:
630
851
  self.disconnect()
631
852
  raise ConnectionError(f"Error while reading from {host_error}: {e.args}")
@@ -732,6 +953,110 @@ class AbstractConnection(ConnectionInterface):
732
953
  self.read_response()
733
954
  self._re_auth_token = None
734
955
 
956
+ def get_resolved_ip(self) -> Optional[str]:
957
+ """
958
+ Extract the resolved IP address from an
959
+ established connection or resolve it from the host.
960
+
961
+ First tries to get the actual IP from the socket (most accurate),
962
+ then falls back to DNS resolution if needed.
963
+
964
+ Args:
965
+ connection: The connection object to extract the IP from
966
+
967
+ Returns:
968
+ str: The resolved IP address, or None if it cannot be determined
969
+ """
970
+
971
+ # Method 1: Try to get the actual IP from the established socket connection
972
+ # This is most accurate as it shows the exact IP being used
973
+ try:
974
+ if self._sock is not None:
975
+ peer_addr = self._sock.getpeername()
976
+ if peer_addr and len(peer_addr) >= 1:
977
+ # For TCP sockets, peer_addr is typically (host, port) tuple
978
+ # Return just the host part
979
+ return peer_addr[0]
980
+ except (AttributeError, OSError):
981
+ # Socket might not be connected or getpeername() might fail
982
+ pass
983
+
984
+ # Method 2: Fallback to DNS resolution of the host
985
+ # This is less accurate but works when socket is not available
986
+ try:
987
+ host = getattr(self, "host", "localhost")
988
+ port = getattr(self, "port", 6379)
989
+ if host:
990
+ # Use getaddrinfo to resolve the hostname to IP
991
+ # This mimics what the connection would do during _connect()
992
+ addr_info = socket.getaddrinfo(
993
+ host, port, socket.AF_UNSPEC, socket.SOCK_STREAM
994
+ )
995
+ if addr_info:
996
+ # Return the IP from the first result
997
+ # addr_info[0] is (family, socktype, proto, canonname, sockaddr)
998
+ # sockaddr[0] is the IP address
999
+ return addr_info[0][4][0]
1000
+ except (AttributeError, OSError, socket.gaierror):
1001
+ # DNS resolution might fail
1002
+ pass
1003
+
1004
+ return None
1005
+
1006
+ @property
1007
+ def maintenance_state(self) -> MaintenanceState:
1008
+ return self._maintenance_state
1009
+
1010
+ @maintenance_state.setter
1011
+ def maintenance_state(self, state: "MaintenanceState"):
1012
+ self._maintenance_state = state
1013
+
1014
+ def getpeername(self):
1015
+ if not self._sock:
1016
+ return None
1017
+ return self._sock.getpeername()[0]
1018
+
1019
+ def mark_for_reconnect(self):
1020
+ self._should_reconnect = True
1021
+
1022
+ def should_reconnect(self):
1023
+ return self._should_reconnect
1024
+
1025
+ def update_current_socket_timeout(self, relax_timeout: Optional[float] = None):
1026
+ if self._sock:
1027
+ timeout = relax_timeout if relax_timeout != -1 else self.socket_timeout
1028
+ self._sock.settimeout(timeout)
1029
+ self.update_parser_buffer_timeout(timeout)
1030
+
1031
+ def update_parser_buffer_timeout(self, timeout: Optional[float] = None):
1032
+ if self._parser and self._parser._buffer:
1033
+ self._parser._buffer.socket_timeout = timeout
1034
+
1035
+ def set_tmp_settings(
1036
+ self,
1037
+ tmp_host_address: Optional[Union[str, object]] = SENTINEL,
1038
+ tmp_relax_timeout: Optional[float] = None,
1039
+ ):
1040
+ """
1041
+ The value of SENTINEL is used to indicate that the property should not be updated.
1042
+ """
1043
+ if tmp_host_address is not SENTINEL:
1044
+ self.host = tmp_host_address
1045
+ if tmp_relax_timeout != -1:
1046
+ self.socket_timeout = tmp_relax_timeout
1047
+ self.socket_connect_timeout = tmp_relax_timeout
1048
+
1049
+ def reset_tmp_settings(
1050
+ self,
1051
+ reset_host_address: bool = False,
1052
+ reset_relax_timeout: bool = False,
1053
+ ):
1054
+ if reset_host_address:
1055
+ self.host = self.orig_host_address
1056
+ if reset_relax_timeout:
1057
+ self.socket_timeout = self.orig_socket_timeout
1058
+ self.socket_connect_timeout = self.orig_socket_connect_timeout
1059
+
735
1060
 
736
1061
  class Connection(AbstractConnection):
737
1062
  "Manages TCP communication to and from a Redis server"
@@ -764,6 +1089,7 @@ class Connection(AbstractConnection):
764
1089
  # ipv4/ipv6, but we want to set options prior to calling
765
1090
  # socket.connect()
766
1091
  err = None
1092
+
767
1093
  for res in socket.getaddrinfo(
768
1094
  self.host, self.port, self.socket_type, socket.SOCK_STREAM
769
1095
  ):
@@ -816,7 +1142,7 @@ class CacheProxyConnection(ConnectionInterface):
816
1142
  self,
817
1143
  conn: ConnectionInterface,
818
1144
  cache: CacheInterface,
819
- pool_lock: threading.Lock,
1145
+ pool_lock: threading.RLock,
820
1146
  ):
821
1147
  self.pid = os.getpid()
822
1148
  self._conn = conn
@@ -1053,7 +1379,7 @@ class SSLConnection(Connection):
1053
1379
  ssl_cert_reqs: The string value for the SSLContext.verify_mode (none, optional, required), or an ssl.VerifyMode. Defaults to "required".
1054
1380
  ssl_ca_certs: The path to a file of concatenated CA certificates in PEM format. Defaults to None.
1055
1381
  ssl_ca_data: Either an ASCII string of one or more PEM-encoded certificates or a bytes-like object of DER-encoded certificates.
1056
- ssl_check_hostname: If set, match the hostname during the SSL handshake. Defaults to False.
1382
+ ssl_check_hostname: If set, match the hostname during the SSL handshake. Defaults to True.
1057
1383
  ssl_ca_path: The path to a directory containing several CA certificates in PEM format. Defaults to None.
1058
1384
  ssl_password: Password for unlocking an encrypted private key. Defaults to None.
1059
1385
 
@@ -1394,7 +1720,7 @@ class ConnectionPool:
1394
1720
  self._cache_factory = cache_factory
1395
1721
 
1396
1722
  if connection_kwargs.get("cache_config") or connection_kwargs.get("cache"):
1397
- if connection_kwargs.get("protocol") not in [3, "3"]:
1723
+ if self.connection_kwargs.get("protocol") not in [3, "3"]:
1398
1724
  raise RedisError("Client caching is only supported with RESP version 3")
1399
1725
 
1400
1726
  cache = self.connection_kwargs.get("cache")
@@ -1415,6 +1741,22 @@ class ConnectionPool:
1415
1741
  connection_kwargs.pop("cache", None)
1416
1742
  connection_kwargs.pop("cache_config", None)
1417
1743
 
1744
+ if self.connection_kwargs.get(
1745
+ "maintenance_events_pool_handler"
1746
+ ) or self.connection_kwargs.get("maintenance_events_config"):
1747
+ if self.connection_kwargs.get("protocol") not in [3, "3"]:
1748
+ raise RedisError(
1749
+ "Push handlers on connection are only supported with RESP version 3"
1750
+ )
1751
+ config = self.connection_kwargs.get("maintenance_events_config", None) or (
1752
+ self.connection_kwargs.get("maintenance_events_pool_handler").config
1753
+ if self.connection_kwargs.get("maintenance_events_pool_handler")
1754
+ else None
1755
+ )
1756
+
1757
+ if config and config.enabled:
1758
+ self._update_connection_kwargs_for_maintenance_events()
1759
+
1418
1760
  self._event_dispatcher = self.connection_kwargs.get("event_dispatcher", None)
1419
1761
  if self._event_dispatcher is None:
1420
1762
  self._event_dispatcher = EventDispatcher()
@@ -1429,13 +1771,7 @@ class ConnectionPool:
1429
1771
  # release the lock.
1430
1772
 
1431
1773
  self._fork_lock = threading.RLock()
1432
-
1433
- if self.cache is None:
1434
- self._lock = threading.RLock()
1435
- else:
1436
- # TODO: To avoid breaking changes during the bug fix, we have to keep non-reentrant lock.
1437
- # TODO: Remove this before next major version (7.0.0)
1438
- self._lock = threading.Lock()
1774
+ self._lock = threading.RLock()
1439
1775
 
1440
1776
  self.reset()
1441
1777
 
@@ -1455,6 +1791,61 @@ class ConnectionPool:
1455
1791
  """
1456
1792
  return self.connection_kwargs.get("protocol", None)
1457
1793
 
1794
+ def maintenance_events_pool_handler_enabled(self):
1795
+ """
1796
+ Returns:
1797
+ True if the maintenance events pool handler is enabled, False otherwise.
1798
+ """
1799
+ maintenance_events_config = self.connection_kwargs.get(
1800
+ "maintenance_events_config", None
1801
+ )
1802
+
1803
+ return maintenance_events_config and maintenance_events_config.enabled
1804
+
1805
+ def set_maintenance_events_pool_handler(
1806
+ self, maintenance_events_pool_handler: MaintenanceEventPoolHandler
1807
+ ):
1808
+ self.connection_kwargs.update(
1809
+ {
1810
+ "maintenance_events_pool_handler": maintenance_events_pool_handler,
1811
+ "maintenance_events_config": maintenance_events_pool_handler.config,
1812
+ }
1813
+ )
1814
+ self._update_connection_kwargs_for_maintenance_events()
1815
+
1816
+ self._update_maintenance_events_configs_for_connections(
1817
+ maintenance_events_pool_handler
1818
+ )
1819
+
1820
+ def _update_maintenance_events_configs_for_connections(
1821
+ self, maintenance_events_pool_handler
1822
+ ):
1823
+ """Update the maintenance events config for all connections in the pool."""
1824
+ with self._lock:
1825
+ for conn in self._available_connections:
1826
+ conn.set_maintenance_event_pool_handler(maintenance_events_pool_handler)
1827
+ conn.maintenance_events_config = maintenance_events_pool_handler.config
1828
+ for conn in self._in_use_connections:
1829
+ conn.set_maintenance_event_pool_handler(maintenance_events_pool_handler)
1830
+ conn.maintenance_events_config = maintenance_events_pool_handler.config
1831
+
1832
+ def _update_connection_kwargs_for_maintenance_events(self):
1833
+ """Store original connection parameters for maintenance events."""
1834
+ if self.connection_kwargs.get("orig_host_address", None) is None:
1835
+ # If orig_host_address is None it means we haven't
1836
+ # configured the original values yet
1837
+ self.connection_kwargs.update(
1838
+ {
1839
+ "orig_host_address": self.connection_kwargs.get("host"),
1840
+ "orig_socket_timeout": self.connection_kwargs.get(
1841
+ "socket_timeout", None
1842
+ ),
1843
+ "orig_socket_connect_timeout": self.connection_kwargs.get(
1844
+ "socket_connect_timeout", None
1845
+ ),
1846
+ }
1847
+ )
1848
+
1458
1849
  def reset(self) -> None:
1459
1850
  self._created_connections = 0
1460
1851
  self._available_connections = []
@@ -1542,7 +1933,11 @@ class ConnectionPool:
1542
1933
  # pool before all data has been read or the socket has been
1543
1934
  # closed. either way, reconnect and verify everything is good.
1544
1935
  try:
1545
- if connection.can_read() and self.cache is None:
1936
+ if (
1937
+ connection.can_read()
1938
+ and self.cache is None
1939
+ and not self.maintenance_events_pool_handler_enabled()
1940
+ ):
1546
1941
  raise ConnectionError("Connection has data")
1547
1942
  except (ConnectionError, TimeoutError, OSError):
1548
1943
  connection.disconnect()
@@ -1554,7 +1949,6 @@ class ConnectionPool:
1554
1949
  # leak it
1555
1950
  self.release(connection)
1556
1951
  raise
1557
-
1558
1952
  return connection
1559
1953
 
1560
1954
  def get_encoder(self) -> Encoder:
@@ -1572,12 +1966,13 @@ class ConnectionPool:
1572
1966
  raise MaxConnectionsError("Too many connections")
1573
1967
  self._created_connections += 1
1574
1968
 
1969
+ kwargs = dict(self.connection_kwargs)
1970
+
1575
1971
  if self.cache is not None:
1576
1972
  return CacheProxyConnection(
1577
- self.connection_class(**self.connection_kwargs), self.cache, self._lock
1973
+ self.connection_class(**kwargs), self.cache, self._lock
1578
1974
  )
1579
-
1580
- return self.connection_class(**self.connection_kwargs)
1975
+ return self.connection_class(**kwargs)
1581
1976
 
1582
1977
  def release(self, connection: "Connection") -> None:
1583
1978
  "Releases the connection back to the pool"
@@ -1591,6 +1986,8 @@ class ConnectionPool:
1591
1986
  return
1592
1987
 
1593
1988
  if self.owns_connection(connection):
1989
+ if connection.should_reconnect():
1990
+ connection.disconnect()
1594
1991
  self._available_connections.append(connection)
1595
1992
  self._event_dispatcher.dispatch(
1596
1993
  AfterConnectionReleasedEvent(connection)
@@ -1652,6 +2049,186 @@ class ConnectionPool:
1652
2049
  for conn in self._in_use_connections:
1653
2050
  conn.set_re_auth_token(token)
1654
2051
 
2052
+ def _should_update_connection(
2053
+ self,
2054
+ conn: "Connection",
2055
+ matching_pattern: Literal[
2056
+ "connected_address", "configured_address", "event_hash"
2057
+ ] = "connected_address",
2058
+ matching_address: Optional[str] = None,
2059
+ matching_event_hash: Optional[int] = None,
2060
+ ) -> bool:
2061
+ """
2062
+ Check if the connection should be updated based on the matching address.
2063
+ """
2064
+ if matching_pattern == "connected_address":
2065
+ if matching_address and conn.getpeername() != matching_address:
2066
+ return False
2067
+ elif matching_pattern == "configured_address":
2068
+ if matching_address and conn.host != matching_address:
2069
+ return False
2070
+ elif matching_pattern == "event_hash":
2071
+ if (
2072
+ matching_event_hash
2073
+ and conn.maintenance_event_hash != matching_event_hash
2074
+ ):
2075
+ return False
2076
+ return True
2077
+
2078
+ def update_connection_settings(
2079
+ self,
2080
+ conn: "Connection",
2081
+ state: Optional["MaintenanceState"] = None,
2082
+ maintenance_event_hash: Optional[int] = None,
2083
+ host_address: Optional[str] = None,
2084
+ relax_timeout: Optional[float] = None,
2085
+ update_event_hash: bool = False,
2086
+ reset_host_address: bool = False,
2087
+ reset_relax_timeout: bool = False,
2088
+ ):
2089
+ """
2090
+ Update the settings for a single connection.
2091
+ """
2092
+ if state:
2093
+ conn.maintenance_state = state
2094
+
2095
+ if update_event_hash:
2096
+ # update the event hash only if requested
2097
+ conn.maintenance_event_hash = maintenance_event_hash
2098
+
2099
+ if host_address is not None:
2100
+ conn.set_tmp_settings(tmp_host_address=host_address)
2101
+
2102
+ if relax_timeout is not None:
2103
+ conn.set_tmp_settings(tmp_relax_timeout=relax_timeout)
2104
+
2105
+ if reset_relax_timeout or reset_host_address:
2106
+ conn.reset_tmp_settings(
2107
+ reset_host_address=reset_host_address,
2108
+ reset_relax_timeout=reset_relax_timeout,
2109
+ )
2110
+
2111
+ conn.update_current_socket_timeout(relax_timeout)
2112
+
2113
+ def update_connections_settings(
2114
+ self,
2115
+ state: Optional["MaintenanceState"] = None,
2116
+ maintenance_event_hash: Optional[int] = None,
2117
+ host_address: Optional[str] = None,
2118
+ relax_timeout: Optional[float] = None,
2119
+ matching_address: Optional[str] = None,
2120
+ matching_event_hash: Optional[int] = None,
2121
+ matching_pattern: Literal[
2122
+ "connected_address", "configured_address", "event_hash"
2123
+ ] = "connected_address",
2124
+ update_event_hash: bool = False,
2125
+ reset_host_address: bool = False,
2126
+ reset_relax_timeout: bool = False,
2127
+ include_free_connections: bool = True,
2128
+ ):
2129
+ """
2130
+ Update the settings for all matching connections in the pool.
2131
+
2132
+ This method does not create new connections.
2133
+ This method does not affect the connection kwargs.
2134
+
2135
+ :param state: The maintenance state to set for the connection.
2136
+ :param maintenance_event_hash: The hash of the maintenance event
2137
+ to set for the connection.
2138
+ :param host_address: The host address to set for the connection.
2139
+ :param relax_timeout: The relax timeout to set for the connection.
2140
+ :param matching_address: The address to match for the connection.
2141
+ :param matching_event_hash: The event hash to match for the connection.
2142
+ :param matching_pattern: The pattern to match for the connection.
2143
+ :param update_event_hash: Whether to update the event hash for the connection.
2144
+ :param reset_host_address: Whether to reset the host address to the original address.
2145
+ :param reset_relax_timeout: Whether to reset the relax timeout to the original timeout.
2146
+ :param include_free_connections: Whether to include free/available connections.
2147
+ """
2148
+ with self._lock:
2149
+ for conn in self._in_use_connections:
2150
+ if self._should_update_connection(
2151
+ conn,
2152
+ matching_pattern,
2153
+ matching_address,
2154
+ matching_event_hash,
2155
+ ):
2156
+ self.update_connection_settings(
2157
+ conn,
2158
+ state=state,
2159
+ maintenance_event_hash=maintenance_event_hash,
2160
+ host_address=host_address,
2161
+ relax_timeout=relax_timeout,
2162
+ update_event_hash=update_event_hash,
2163
+ reset_host_address=reset_host_address,
2164
+ reset_relax_timeout=reset_relax_timeout,
2165
+ )
2166
+
2167
+ if include_free_connections:
2168
+ for conn in self._available_connections:
2169
+ if self._should_update_connection(
2170
+ conn,
2171
+ matching_pattern,
2172
+ matching_address,
2173
+ matching_event_hash,
2174
+ ):
2175
+ self.update_connection_settings(
2176
+ conn,
2177
+ state=state,
2178
+ maintenance_event_hash=maintenance_event_hash,
2179
+ host_address=host_address,
2180
+ relax_timeout=relax_timeout,
2181
+ update_event_hash=update_event_hash,
2182
+ reset_host_address=reset_host_address,
2183
+ reset_relax_timeout=reset_relax_timeout,
2184
+ )
2185
+
2186
+ def update_connection_kwargs(
2187
+ self,
2188
+ **kwargs,
2189
+ ):
2190
+ """
2191
+ Update the connection kwargs for all future connections.
2192
+
2193
+ This method updates the connection kwargs for all future connections created by the pool.
2194
+ Existing connections are not affected.
2195
+ """
2196
+ self.connection_kwargs.update(kwargs)
2197
+
2198
+ def update_active_connections_for_reconnect(
2199
+ self,
2200
+ moving_address_src: Optional[str] = None,
2201
+ ):
2202
+ """
2203
+ Mark all active connections for reconnect.
2204
+ This is used when a cluster node is migrated to a different address.
2205
+
2206
+ :param moving_address_src: The address of the node that is being moved.
2207
+ """
2208
+ with self._lock:
2209
+ for conn in self._in_use_connections:
2210
+ if self._should_update_connection(
2211
+ conn, "connected_address", moving_address_src
2212
+ ):
2213
+ conn.mark_for_reconnect()
2214
+
2215
+ def disconnect_free_connections(
2216
+ self,
2217
+ moving_address_src: Optional[str] = None,
2218
+ ):
2219
+ """
2220
+ Disconnect all free/available connections.
2221
+ This is used when a cluster node is migrated to a different address.
2222
+
2223
+ :param moving_address_src: The address of the node that is being moved.
2224
+ """
2225
+ with self._lock:
2226
+ for conn in self._available_connections:
2227
+ if self._should_update_connection(
2228
+ conn, "connected_address", moving_address_src
2229
+ ):
2230
+ conn.disconnect()
2231
+
1655
2232
  async def _mock(self, error: RedisError):
1656
2233
  """
1657
2234
  Dummy functions, needs to be passed as error callback to retry object.
@@ -1705,6 +2282,8 @@ class BlockingConnectionPool(ConnectionPool):
1705
2282
  ):
1706
2283
  self.queue_class = queue_class
1707
2284
  self.timeout = timeout
2285
+ self._in_maintenance = False
2286
+ self._locked = False
1708
2287
  super().__init__(
1709
2288
  connection_class=connection_class,
1710
2289
  max_connections=max_connections,
@@ -1713,16 +2292,27 @@ class BlockingConnectionPool(ConnectionPool):
1713
2292
 
1714
2293
  def reset(self):
1715
2294
  # Create and fill up a thread safe queue with ``None`` values.
1716
- self.pool = self.queue_class(self.max_connections)
1717
- while True:
1718
- try:
1719
- self.pool.put_nowait(None)
1720
- except Full:
1721
- break
2295
+ try:
2296
+ if self._in_maintenance:
2297
+ self._lock.acquire()
2298
+ self._locked = True
2299
+ self.pool = self.queue_class(self.max_connections)
2300
+ while True:
2301
+ try:
2302
+ self.pool.put_nowait(None)
2303
+ except Full:
2304
+ break
1722
2305
 
1723
- # Keep a list of actual connection instances so that we can
1724
- # disconnect them later.
1725
- self._connections = []
2306
+ # Keep a list of actual connection instances so that we can
2307
+ # disconnect them later.
2308
+ self._connections = []
2309
+ finally:
2310
+ if self._locked:
2311
+ try:
2312
+ self._lock.release()
2313
+ except Exception:
2314
+ pass
2315
+ self._locked = False
1726
2316
 
1727
2317
  # this must be the last operation in this method. while reset() is
1728
2318
  # called when holding _fork_lock, other threads in this process
@@ -1737,14 +2327,28 @@ class BlockingConnectionPool(ConnectionPool):
1737
2327
 
1738
2328
  def make_connection(self):
1739
2329
  "Make a fresh connection."
1740
- if self.cache is not None:
1741
- connection = CacheProxyConnection(
1742
- self.connection_class(**self.connection_kwargs), self.cache, self._lock
1743
- )
1744
- else:
1745
- connection = self.connection_class(**self.connection_kwargs)
1746
- self._connections.append(connection)
1747
- return connection
2330
+ try:
2331
+ if self._in_maintenance:
2332
+ self._lock.acquire()
2333
+ self._locked = True
2334
+
2335
+ if self.cache is not None:
2336
+ connection = CacheProxyConnection(
2337
+ self.connection_class(**self.connection_kwargs),
2338
+ self.cache,
2339
+ self._lock,
2340
+ )
2341
+ else:
2342
+ connection = self.connection_class(**self.connection_kwargs)
2343
+ self._connections.append(connection)
2344
+ return connection
2345
+ finally:
2346
+ if self._locked:
2347
+ try:
2348
+ self._lock.release()
2349
+ except Exception:
2350
+ pass
2351
+ self._locked = False
1748
2352
 
1749
2353
  @deprecated_args(
1750
2354
  args_to_warn=["*"],
@@ -1770,16 +2374,27 @@ class BlockingConnectionPool(ConnectionPool):
1770
2374
  # self.timeout then raise a ``ConnectionError``.
1771
2375
  connection = None
1772
2376
  try:
1773
- connection = self.pool.get(block=True, timeout=self.timeout)
1774
- except Empty:
1775
- # Note that this is not caught by the redis client and will be
1776
- # raised unless handled by application code. If you want never to
1777
- raise ConnectionError("No connection available.")
1778
-
1779
- # If the ``connection`` is actually ``None`` then that's a cue to make
1780
- # a new connection to add to the pool.
1781
- if connection is None:
1782
- connection = self.make_connection()
2377
+ if self._in_maintenance:
2378
+ self._lock.acquire()
2379
+ self._locked = True
2380
+ try:
2381
+ connection = self.pool.get(block=True, timeout=self.timeout)
2382
+ except Empty:
2383
+ # Note that this is not caught by the redis client and will be
2384
+ # raised unless handled by application code. If you want never to
2385
+ raise ConnectionError("No connection available.")
2386
+
2387
+ # If the ``connection`` is actually ``None`` then that's a cue to make
2388
+ # a new connection to add to the pool.
2389
+ if connection is None:
2390
+ connection = self.make_connection()
2391
+ finally:
2392
+ if self._locked:
2393
+ try:
2394
+ self._lock.release()
2395
+ except Exception:
2396
+ pass
2397
+ self._locked = False
1783
2398
 
1784
2399
  try:
1785
2400
  # ensure this connection is connected to Redis
@@ -1807,25 +2422,173 @@ class BlockingConnectionPool(ConnectionPool):
1807
2422
  "Releases the connection back to the pool."
1808
2423
  # Make sure we haven't changed process.
1809
2424
  self._checkpid()
1810
- if not self.owns_connection(connection):
1811
- # pool doesn't own this connection. do not add it back
1812
- # to the pool. instead add a None value which is a placeholder
1813
- # that will cause the pool to recreate the connection if
1814
- # its needed.
1815
- connection.disconnect()
1816
- self.pool.put_nowait(None)
1817
- return
1818
2425
 
1819
- # Put the connection back into the pool.
1820
2426
  try:
1821
- self.pool.put_nowait(connection)
1822
- except Full:
1823
- # perhaps the pool has been reset() after a fork? regardless,
1824
- # we don't want this connection
1825
- pass
2427
+ if self._in_maintenance:
2428
+ self._lock.acquire()
2429
+ self._locked = True
2430
+ if not self.owns_connection(connection):
2431
+ # pool doesn't own this connection. do not add it back
2432
+ # to the pool. instead add a None value which is a placeholder
2433
+ # that will cause the pool to recreate the connection if
2434
+ # its needed.
2435
+ connection.disconnect()
2436
+ self.pool.put_nowait(None)
2437
+ return
2438
+ if connection.should_reconnect():
2439
+ connection.disconnect()
2440
+ # Put the connection back into the pool.
2441
+ try:
2442
+ self.pool.put_nowait(connection)
2443
+ except Full:
2444
+ # perhaps the pool has been reset() after a fork? regardless,
2445
+ # we don't want this connection
2446
+ pass
2447
+ finally:
2448
+ if self._locked:
2449
+ try:
2450
+ self._lock.release()
2451
+ except Exception:
2452
+ pass
2453
+ self._locked = False
1826
2454
 
1827
2455
  def disconnect(self):
1828
2456
  "Disconnects all connections in the pool."
1829
2457
  self._checkpid()
1830
- for connection in self._connections:
1831
- connection.disconnect()
2458
+ try:
2459
+ if self._in_maintenance:
2460
+ self._lock.acquire()
2461
+ self._locked = True
2462
+ for connection in self._connections:
2463
+ connection.disconnect()
2464
+ finally:
2465
+ if self._locked:
2466
+ try:
2467
+ self._lock.release()
2468
+ except Exception:
2469
+ pass
2470
+ self._locked = False
2471
+
2472
+ def update_connections_settings(
2473
+ self,
2474
+ state: Optional["MaintenanceState"] = None,
2475
+ maintenance_event_hash: Optional[int] = None,
2476
+ relax_timeout: Optional[float] = None,
2477
+ host_address: Optional[str] = None,
2478
+ matching_address: Optional[str] = None,
2479
+ matching_event_hash: Optional[int] = None,
2480
+ matching_pattern: Literal[
2481
+ "connected_address", "configured_address", "event_hash"
2482
+ ] = "connected_address",
2483
+ update_event_hash: bool = False,
2484
+ reset_host_address: bool = False,
2485
+ reset_relax_timeout: bool = False,
2486
+ include_free_connections: bool = True,
2487
+ ):
2488
+ """
2489
+ Override base class method to work with BlockingConnectionPool's structure.
2490
+ """
2491
+ with self._lock:
2492
+ if include_free_connections:
2493
+ for conn in tuple(self._connections):
2494
+ if self._should_update_connection(
2495
+ conn,
2496
+ matching_pattern,
2497
+ matching_address,
2498
+ matching_event_hash,
2499
+ ):
2500
+ self.update_connection_settings(
2501
+ conn,
2502
+ state=state,
2503
+ maintenance_event_hash=maintenance_event_hash,
2504
+ host_address=host_address,
2505
+ relax_timeout=relax_timeout,
2506
+ update_event_hash=update_event_hash,
2507
+ reset_host_address=reset_host_address,
2508
+ reset_relax_timeout=reset_relax_timeout,
2509
+ )
2510
+ else:
2511
+ connections_in_queue = {conn for conn in self.pool.queue if conn}
2512
+ for conn in self._connections:
2513
+ if conn not in connections_in_queue:
2514
+ if self._should_update_connection(
2515
+ conn,
2516
+ matching_pattern,
2517
+ matching_address,
2518
+ matching_event_hash,
2519
+ ):
2520
+ self.update_connection_settings(
2521
+ conn,
2522
+ state=state,
2523
+ maintenance_event_hash=maintenance_event_hash,
2524
+ host_address=host_address,
2525
+ relax_timeout=relax_timeout,
2526
+ update_event_hash=update_event_hash,
2527
+ reset_host_address=reset_host_address,
2528
+ reset_relax_timeout=reset_relax_timeout,
2529
+ )
2530
+
2531
+ def update_active_connections_for_reconnect(
2532
+ self,
2533
+ moving_address_src: Optional[str] = None,
2534
+ ):
2535
+ """
2536
+ Mark all active connections for reconnect.
2537
+ This is used when a cluster node is migrated to a different address.
2538
+
2539
+ :param moving_address_src: The address of the node that is being moved.
2540
+ """
2541
+ with self._lock:
2542
+ connections_in_queue = {conn for conn in self.pool.queue if conn}
2543
+ for conn in self._connections:
2544
+ if conn not in connections_in_queue:
2545
+ if self._should_update_connection(
2546
+ conn,
2547
+ matching_pattern="connected_address",
2548
+ matching_address=moving_address_src,
2549
+ ):
2550
+ conn.mark_for_reconnect()
2551
+
2552
+ def disconnect_free_connections(
2553
+ self,
2554
+ moving_address_src: Optional[str] = None,
2555
+ ):
2556
+ """
2557
+ Disconnect all free/available connections.
2558
+ This is used when a cluster node is migrated to a different address.
2559
+
2560
+ :param moving_address_src: The address of the node that is being moved.
2561
+ """
2562
+ with self._lock:
2563
+ existing_connections = self.pool.queue
2564
+
2565
+ for conn in existing_connections:
2566
+ if conn:
2567
+ if self._should_update_connection(
2568
+ conn, "connected_address", moving_address_src
2569
+ ):
2570
+ conn.disconnect()
2571
+
2572
+ def _update_maintenance_events_config_for_connections(
2573
+ self, maintenance_events_config
2574
+ ):
2575
+ for conn in tuple(self._connections):
2576
+ conn.maintenance_events_config = maintenance_events_config
2577
+
2578
+ def _update_maintenance_events_configs_for_connections(
2579
+ self, maintenance_events_pool_handler
2580
+ ):
2581
+ """Update the maintenance events config for all connections in the pool."""
2582
+ with self._lock:
2583
+ for conn in tuple(self._connections):
2584
+ conn.set_maintenance_event_pool_handler(maintenance_events_pool_handler)
2585
+ conn.maintenance_events_config = maintenance_events_pool_handler.config
2586
+
2587
+ def set_in_maintenance(self, in_maintenance: bool):
2588
+ """
2589
+ Sets a flag that this Blocking ConnectionPool is in maintenance mode.
2590
+
2591
+ This is used to prevent new connections from being created while we are in maintenance mode.
2592
+ The pool will be in maintenance mode only when we are processing a MOVING event.
2593
+ """
2594
+ self._in_maintenance = in_maintenance