redis 6.4.0__py3-none-any.whl → 7.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59)
  1. redis/__init__.py +1 -1
  2. redis/_parsers/base.py +193 -8
  3. redis/_parsers/helpers.py +64 -6
  4. redis/_parsers/hiredis.py +16 -10
  5. redis/_parsers/resp3.py +11 -5
  6. redis/asyncio/client.py +65 -8
  7. redis/asyncio/cluster.py +57 -5
  8. redis/asyncio/connection.py +62 -2
  9. redis/asyncio/http/__init__.py +0 -0
  10. redis/asyncio/http/http_client.py +265 -0
  11. redis/asyncio/multidb/__init__.py +0 -0
  12. redis/asyncio/multidb/client.py +530 -0
  13. redis/asyncio/multidb/command_executor.py +339 -0
  14. redis/asyncio/multidb/config.py +210 -0
  15. redis/asyncio/multidb/database.py +69 -0
  16. redis/asyncio/multidb/event.py +84 -0
  17. redis/asyncio/multidb/failover.py +125 -0
  18. redis/asyncio/multidb/failure_detector.py +38 -0
  19. redis/asyncio/multidb/healthcheck.py +285 -0
  20. redis/background.py +204 -0
  21. redis/cache.py +1 -0
  22. redis/client.py +97 -16
  23. redis/cluster.py +14 -3
  24. redis/commands/core.py +348 -313
  25. redis/commands/helpers.py +0 -20
  26. redis/commands/json/commands.py +2 -2
  27. redis/commands/search/__init__.py +2 -2
  28. redis/commands/search/aggregation.py +24 -26
  29. redis/commands/search/commands.py +10 -10
  30. redis/commands/search/field.py +2 -2
  31. redis/commands/search/query.py +23 -23
  32. redis/commands/vectorset/__init__.py +1 -1
  33. redis/commands/vectorset/commands.py +43 -25
  34. redis/commands/vectorset/utils.py +40 -4
  35. redis/connection.py +1257 -83
  36. redis/data_structure.py +81 -0
  37. redis/event.py +84 -10
  38. redis/exceptions.py +8 -0
  39. redis/http/__init__.py +0 -0
  40. redis/http/http_client.py +425 -0
  41. redis/maint_notifications.py +810 -0
  42. redis/multidb/__init__.py +0 -0
  43. redis/multidb/circuit.py +144 -0
  44. redis/multidb/client.py +526 -0
  45. redis/multidb/command_executor.py +350 -0
  46. redis/multidb/config.py +207 -0
  47. redis/multidb/database.py +130 -0
  48. redis/multidb/event.py +89 -0
  49. redis/multidb/exception.py +17 -0
  50. redis/multidb/failover.py +125 -0
  51. redis/multidb/failure_detector.py +104 -0
  52. redis/multidb/healthcheck.py +282 -0
  53. redis/retry.py +14 -1
  54. redis/utils.py +34 -0
  55. {redis-6.4.0.dist-info → redis-7.0.0.dist-info}/METADATA +7 -4
  56. redis-7.0.0.dist-info/RECORD +105 -0
  57. redis-6.4.0.dist-info/RECORD +0 -78
  58. {redis-6.4.0.dist-info → redis-7.0.0.dist-info}/WHEEL +0 -0
  59. {redis-6.4.0.dist-info → redis-7.0.0.dist-info}/licenses/LICENSE +0 -0
redis/connection.py CHANGED
@@ -5,10 +5,21 @@ import sys
5
5
  import threading
6
6
  import time
7
7
  import weakref
8
- from abc import abstractmethod
8
+ from abc import ABC, abstractmethod
9
9
  from itertools import chain
10
10
  from queue import Empty, Full, LifoQueue
11
- from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union
11
+ from typing import (
12
+ Any,
13
+ Callable,
14
+ Dict,
15
+ Iterable,
16
+ List,
17
+ Literal,
18
+ Optional,
19
+ Type,
20
+ TypeVar,
21
+ Union,
22
+ )
12
23
  from urllib.parse import parse_qs, unquote, urlparse
13
24
 
14
25
  from redis.cache import (
@@ -36,6 +47,12 @@ from .exceptions import (
36
47
  ResponseError,
37
48
  TimeoutError,
38
49
  )
50
+ from .maint_notifications import (
51
+ MaintenanceState,
52
+ MaintNotificationsConfig,
53
+ MaintNotificationsConnectionHandler,
54
+ MaintNotificationsPoolHandler,
55
+ )
39
56
  from .retry import Retry
40
57
  from .utils import (
41
58
  CRYPTOGRAPHY_AVAILABLE,
@@ -51,8 +68,10 @@ from .utils import (
51
68
 
52
69
  if SSL_AVAILABLE:
53
70
  import ssl
71
+ from ssl import VerifyFlags
54
72
  else:
55
73
  ssl = None
74
+ VerifyFlags = None
56
75
 
57
76
  if HIREDIS_AVAILABLE:
58
77
  import hiredis
@@ -222,8 +241,418 @@ class ConnectionInterface:
222
241
  def re_auth(self):
223
242
  pass
224
243
 
244
+ @abstractmethod
245
+ def mark_for_reconnect(self):
246
+ """
247
+ Mark the connection to be reconnected on the next command.
248
+ This is useful when a connection is moved to a different node.
249
+ """
250
+ pass
251
+
252
+ @abstractmethod
253
+ def should_reconnect(self):
254
+ """
255
+ Returns True if the connection should be reconnected.
256
+ """
257
+ pass
258
+
259
+ @abstractmethod
260
+ def reset_should_reconnect(self):
261
+ """
262
+ Reset the internal flag to False.
263
+ """
264
+ pass
265
+
266
+
267
+ class MaintNotificationsAbstractConnection:
268
+ """
269
+ Abstract class for handling maintenance notifications logic.
270
+ This class is expected to be used as base class together with ConnectionInterface.
271
+
272
+ This class is intended to be used with multiple inheritance!
273
+
274
+ All logic related to maintenance notifications is encapsulated in this class.
275
+ """
276
+
277
+ def __init__(
278
+ self,
279
+ maint_notifications_config: Optional[MaintNotificationsConfig],
280
+ maint_notifications_pool_handler: Optional[
281
+ MaintNotificationsPoolHandler
282
+ ] = None,
283
+ maintenance_state: "MaintenanceState" = MaintenanceState.NONE,
284
+ maintenance_notification_hash: Optional[int] = None,
285
+ orig_host_address: Optional[str] = None,
286
+ orig_socket_timeout: Optional[float] = None,
287
+ orig_socket_connect_timeout: Optional[float] = None,
288
+ parser: Optional[Union[_HiredisParser, _RESP3Parser]] = None,
289
+ ):
290
+ """
291
+ Initialize the maintenance notifications for the connection.
292
+
293
+ Args:
294
+ maint_notifications_config (MaintNotificationsConfig): The configuration for maintenance notifications.
295
+ maint_notifications_pool_handler (Optional[MaintNotificationsPoolHandler]): The pool handler for maintenance notifications.
296
+ maintenance_state (MaintenanceState): The current maintenance state of the connection.
297
+ maintenance_notification_hash (Optional[int]): The current maintenance notification hash of the connection.
298
+ orig_host_address (Optional[str]): The original host address of the connection.
299
+ orig_socket_timeout (Optional[float]): The original socket timeout of the connection.
300
+ orig_socket_connect_timeout (Optional[float]): The original socket connect timeout of the connection.
301
+ parser (Optional[Union[_HiredisParser, _RESP3Parser]]): The parser to use for maintenance notifications.
302
+ If not provided, the parser from the connection is used.
303
+ This is useful when the parser is created after this object.
304
+ """
305
+ self.maint_notifications_config = maint_notifications_config
306
+ self.maintenance_state = maintenance_state
307
+ self.maintenance_notification_hash = maintenance_notification_hash
308
+ self._configure_maintenance_notifications(
309
+ maint_notifications_pool_handler,
310
+ orig_host_address,
311
+ orig_socket_timeout,
312
+ orig_socket_connect_timeout,
313
+ parser,
314
+ )
315
+
316
+ @abstractmethod
317
+ def _get_parser(self) -> Union[_HiredisParser, _RESP3Parser]:
318
+ pass
319
+
320
+ @abstractmethod
321
+ def _get_socket(self) -> Optional[socket.socket]:
322
+ pass
323
+
324
+ @abstractmethod
325
+ def get_protocol(self) -> Union[int, str]:
326
+ """
327
+ Returns:
328
+ The RESP protocol version, or ``None`` if the protocol is not specified,
329
+ in which case the server default will be used.
330
+ """
331
+ pass
332
+
333
+ @property
334
+ @abstractmethod
335
+ def host(self) -> str:
336
+ pass
337
+
338
+ @host.setter
339
+ @abstractmethod
340
+ def host(self, value: str):
341
+ pass
342
+
343
+ @property
344
+ @abstractmethod
345
+ def socket_timeout(self) -> Optional[Union[float, int]]:
346
+ pass
347
+
348
+ @socket_timeout.setter
349
+ @abstractmethod
350
+ def socket_timeout(self, value: Optional[Union[float, int]]):
351
+ pass
352
+
353
+ @property
354
+ @abstractmethod
355
+ def socket_connect_timeout(self) -> Optional[Union[float, int]]:
356
+ pass
357
+
358
+ @socket_connect_timeout.setter
359
+ @abstractmethod
360
+ def socket_connect_timeout(self, value: Optional[Union[float, int]]):
361
+ pass
362
+
363
+ @abstractmethod
364
+ def send_command(self, *args, **kwargs):
365
+ pass
366
+
367
+ @abstractmethod
368
+ def read_response(
369
+ self,
370
+ disable_decoding=False,
371
+ *,
372
+ disconnect_on_error=True,
373
+ push_request=False,
374
+ ):
375
+ pass
376
+
377
+ @abstractmethod
378
+ def disconnect(self, *args):
379
+ pass
380
+
381
+ def _configure_maintenance_notifications(
382
+ self,
383
+ maint_notifications_pool_handler: Optional[
384
+ MaintNotificationsPoolHandler
385
+ ] = None,
386
+ orig_host_address=None,
387
+ orig_socket_timeout=None,
388
+ orig_socket_connect_timeout=None,
389
+ parser: Optional[Union[_HiredisParser, _RESP3Parser]] = None,
390
+ ):
391
+ """
392
+ Enable maintenance notifications by setting up
393
+ handlers and storing original connection parameters.
394
+
395
+ Should be used ONLY with parsers that support push notifications.
396
+ """
397
+ if (
398
+ not self.maint_notifications_config
399
+ or not self.maint_notifications_config.enabled
400
+ ):
401
+ self._maint_notifications_pool_handler = None
402
+ self._maint_notifications_connection_handler = None
403
+ return
404
+
405
+ if not parser:
406
+ raise RedisError(
407
+ "To configure maintenance notifications, a parser must be provided!"
408
+ )
409
+
410
+ if not isinstance(parser, _HiredisParser) and not isinstance(
411
+ parser, _RESP3Parser
412
+ ):
413
+ raise RedisError(
414
+ "Maintenance notifications are only supported with hiredis and RESP3 parsers!"
415
+ )
416
+
417
+ if maint_notifications_pool_handler:
418
+ # Extract a reference to a new pool handler that copies all properties
419
+ # of the original one and has a different connection reference
420
+ # This is needed because when we attach the handler to the parser
421
+ # we need to make sure that the handler has a reference to the
422
+ # connection that the parser is attached to.
423
+ self._maint_notifications_pool_handler = (
424
+ maint_notifications_pool_handler.get_handler_for_connection()
425
+ )
426
+ self._maint_notifications_pool_handler.set_connection(self)
427
+ else:
428
+ self._maint_notifications_pool_handler = None
429
+
430
+ self._maint_notifications_connection_handler = (
431
+ MaintNotificationsConnectionHandler(self, self.maint_notifications_config)
432
+ )
433
+
434
+ # Set up pool handler if available
435
+ if self._maint_notifications_pool_handler:
436
+ parser.set_node_moving_push_handler(
437
+ self._maint_notifications_pool_handler.handle_notification
438
+ )
439
+
440
+ # Set up connection handler
441
+ parser.set_maintenance_push_handler(
442
+ self._maint_notifications_connection_handler.handle_notification
443
+ )
444
+
445
+ # Store original connection parameters
446
+ self.orig_host_address = orig_host_address if orig_host_address else self.host
447
+ self.orig_socket_timeout = (
448
+ orig_socket_timeout if orig_socket_timeout else self.socket_timeout
449
+ )
450
+ self.orig_socket_connect_timeout = (
451
+ orig_socket_connect_timeout
452
+ if orig_socket_connect_timeout
453
+ else self.socket_connect_timeout
454
+ )
455
+
456
+ def set_maint_notifications_pool_handler_for_connection(
457
+ self, maint_notifications_pool_handler: MaintNotificationsPoolHandler
458
+ ):
459
+ # Deep copy the pool handler to avoid sharing the same pool handler
460
+ # between multiple connections, because otherwise each connection will override
461
+ # the connection reference and the pool handler will only hold a reference
462
+ # to the last connection that was set.
463
+ maint_notifications_pool_handler_copy = (
464
+ maint_notifications_pool_handler.get_handler_for_connection()
465
+ )
466
+
467
+ maint_notifications_pool_handler_copy.set_connection(self)
468
+ self._get_parser().set_node_moving_push_handler(
469
+ maint_notifications_pool_handler_copy.handle_notification
470
+ )
471
+
472
+ self._maint_notifications_pool_handler = maint_notifications_pool_handler_copy
225
473
 
226
- class AbstractConnection(ConnectionInterface):
474
+ # Update maintenance notification connection handler if it doesn't exist
475
+ if not self._maint_notifications_connection_handler:
476
+ self._maint_notifications_connection_handler = (
477
+ MaintNotificationsConnectionHandler(
478
+ self, maint_notifications_pool_handler.config
479
+ )
480
+ )
481
+ self._get_parser().set_maintenance_push_handler(
482
+ self._maint_notifications_connection_handler.handle_notification
483
+ )
484
+ else:
485
+ self._maint_notifications_connection_handler.config = (
486
+ maint_notifications_pool_handler.config
487
+ )
488
+
489
+ def activate_maint_notifications_handling_if_enabled(self, check_health=True):
490
+ # Send maintenance notifications handshake if RESP3 is active
491
+ # and maintenance notifications are enabled
492
+ # and we have a host to determine the endpoint type from
493
+ # When the maint_notifications_config enabled mode is "auto",
494
+ # we just log a warning if the handshake fails
495
+ # When the mode is enabled=True, we raise an exception in case of failure
496
+ if (
497
+ self.get_protocol() not in [2, "2"]
498
+ and self.maint_notifications_config
499
+ and self.maint_notifications_config.enabled
500
+ and self._maint_notifications_connection_handler
501
+ and hasattr(self, "host")
502
+ ):
503
+ self._enable_maintenance_notifications(
504
+ maint_notifications_config=self.maint_notifications_config,
505
+ check_health=check_health,
506
+ )
507
+
508
+ def _enable_maintenance_notifications(
509
+ self, maint_notifications_config: MaintNotificationsConfig, check_health=True
510
+ ):
511
+ try:
512
+ host = getattr(self, "host", None)
513
+ if host is None:
514
+ raise ValueError(
515
+ "Cannot enable maintenance notifications for connection"
516
+ " object that doesn't have a host attribute."
517
+ )
518
+ else:
519
+ endpoint_type = maint_notifications_config.get_endpoint_type(host, self)
520
+ self.send_command(
521
+ "CLIENT",
522
+ "MAINT_NOTIFICATIONS",
523
+ "ON",
524
+ "moving-endpoint-type",
525
+ endpoint_type.value,
526
+ check_health=check_health,
527
+ )
528
+ response = self.read_response()
529
+ if not response or str_if_bytes(response) != "OK":
530
+ raise ResponseError(
531
+ "The server doesn't support maintenance notifications"
532
+ )
533
+ except Exception as e:
534
+ if (
535
+ isinstance(e, ResponseError)
536
+ and maint_notifications_config.enabled == "auto"
537
+ ):
538
+ # Log warning but don't fail the connection
539
+ import logging
540
+
541
+ logger = logging.getLogger(__name__)
542
+ logger.warning(f"Failed to enable maintenance notifications: {e}")
543
+ else:
544
+ raise
545
+
546
+ def get_resolved_ip(self) -> Optional[str]:
547
+ """
548
+ Extract the resolved IP address from an
549
+ established connection or resolve it from the host.
550
+
551
+ First tries to get the actual IP from the socket (most accurate),
552
+ then falls back to DNS resolution if needed.
553
+
554
+ Args:
555
+ connection: The connection object to extract the IP from
556
+
557
+ Returns:
558
+ str: The resolved IP address, or None if it cannot be determined
559
+ """
560
+
561
+ # Method 1: Try to get the actual IP from the established socket connection
562
+ # This is most accurate as it shows the exact IP being used
563
+ try:
564
+ conn_socket = self._get_socket()
565
+ if conn_socket is not None:
566
+ peer_addr = conn_socket.getpeername()
567
+ if peer_addr and len(peer_addr) >= 1:
568
+ # For TCP sockets, peer_addr is typically (host, port) tuple
569
+ # Return just the host part
570
+ return peer_addr[0]
571
+ except (AttributeError, OSError):
572
+ # Socket might not be connected or getpeername() might fail
573
+ pass
574
+
575
+ # Method 2: Fallback to DNS resolution of the host
576
+ # This is less accurate but works when socket is not available
577
+ try:
578
+ host = getattr(self, "host", "localhost")
579
+ port = getattr(self, "port", 6379)
580
+ if host:
581
+ # Use getaddrinfo to resolve the hostname to IP
582
+ # This mimics what the connection would do during _connect()
583
+ addr_info = socket.getaddrinfo(
584
+ host, port, socket.AF_UNSPEC, socket.SOCK_STREAM
585
+ )
586
+ if addr_info:
587
+ # Return the IP from the first result
588
+ # addr_info[0] is (family, socktype, proto, canonname, sockaddr)
589
+ # sockaddr[0] is the IP address
590
+ return str(addr_info[0][4][0])
591
+ except (AttributeError, OSError, socket.gaierror):
592
+ # DNS resolution might fail
593
+ pass
594
+
595
+ return None
596
+
597
+ @property
598
+ def maintenance_state(self) -> MaintenanceState:
599
+ return self._maintenance_state
600
+
601
+ @maintenance_state.setter
602
+ def maintenance_state(self, state: "MaintenanceState"):
603
+ self._maintenance_state = state
604
+
605
+ def getpeername(self):
606
+ """
607
+ Returns the peer name of the connection.
608
+ """
609
+ conn_socket = self._get_socket()
610
+ if conn_socket:
611
+ return conn_socket.getpeername()[0]
612
+ return None
613
+
614
+ def update_current_socket_timeout(self, relaxed_timeout: Optional[float] = None):
615
+ conn_socket = self._get_socket()
616
+ if conn_socket:
617
+ timeout = relaxed_timeout if relaxed_timeout != -1 else self.socket_timeout
618
+ conn_socket.settimeout(timeout)
619
+ self.update_parser_timeout(timeout)
620
+
621
+ def update_parser_timeout(self, timeout: Optional[float] = None):
622
+ parser = self._get_parser()
623
+ if parser and parser._buffer:
624
+ if isinstance(parser, _RESP3Parser) and timeout:
625
+ parser._buffer.socket_timeout = timeout
626
+ elif isinstance(parser, _HiredisParser):
627
+ parser._socket_timeout = timeout
628
+
629
+ def set_tmp_settings(
630
+ self,
631
+ tmp_host_address: Optional[Union[str, object]] = SENTINEL,
632
+ tmp_relaxed_timeout: Optional[float] = None,
633
+ ):
634
+ """
635
+ The value of SENTINEL is used to indicate that the property should not be updated.
636
+ """
637
+ if tmp_host_address and tmp_host_address != SENTINEL:
638
+ self.host = str(tmp_host_address)
639
+ if tmp_relaxed_timeout != -1:
640
+ self.socket_timeout = tmp_relaxed_timeout
641
+ self.socket_connect_timeout = tmp_relaxed_timeout
642
+
643
+ def reset_tmp_settings(
644
+ self,
645
+ reset_host_address: bool = False,
646
+ reset_relaxed_timeout: bool = False,
647
+ ):
648
+ if reset_host_address:
649
+ self.host = self.orig_host_address
650
+ if reset_relaxed_timeout:
651
+ self.socket_timeout = self.orig_socket_timeout
652
+ self.socket_connect_timeout = self.orig_socket_connect_timeout
653
+
654
+
655
+ class AbstractConnection(MaintNotificationsAbstractConnection, ConnectionInterface):
227
656
  "Manages communication to and from a Redis server"
228
657
 
229
658
  def __init__(
@@ -233,7 +662,7 @@ class AbstractConnection(ConnectionInterface):
233
662
  socket_timeout: Optional[float] = None,
234
663
  socket_connect_timeout: Optional[float] = None,
235
664
  retry_on_timeout: bool = False,
236
- retry_on_error=SENTINEL,
665
+ retry_on_error: Union[Iterable[Type[Exception]], object] = SENTINEL,
237
666
  encoding: str = "utf-8",
238
667
  encoding_errors: str = "strict",
239
668
  decode_responses: bool = False,
@@ -250,6 +679,15 @@ class AbstractConnection(ConnectionInterface):
250
679
  protocol: Optional[int] = 2,
251
680
  command_packer: Optional[Callable[[], None]] = None,
252
681
  event_dispatcher: Optional[EventDispatcher] = None,
682
+ maint_notifications_config: Optional[MaintNotificationsConfig] = None,
683
+ maint_notifications_pool_handler: Optional[
684
+ MaintNotificationsPoolHandler
685
+ ] = None,
686
+ maintenance_state: "MaintenanceState" = MaintenanceState.NONE,
687
+ maintenance_notification_hash: Optional[int] = None,
688
+ orig_host_address: Optional[str] = None,
689
+ orig_socket_timeout: Optional[float] = None,
690
+ orig_socket_connect_timeout: Optional[float] = None,
253
691
  ):
254
692
  """
255
693
  Initialize a new Connection.
@@ -277,25 +715,28 @@ class AbstractConnection(ConnectionInterface):
277
715
  self.credential_provider = credential_provider
278
716
  self.password = password
279
717
  self.username = username
280
- self.socket_timeout = socket_timeout
718
+ self._socket_timeout = socket_timeout
281
719
  if socket_connect_timeout is None:
282
720
  socket_connect_timeout = socket_timeout
283
- self.socket_connect_timeout = socket_connect_timeout
721
+ self._socket_connect_timeout = socket_connect_timeout
284
722
  self.retry_on_timeout = retry_on_timeout
285
723
  if retry_on_error is SENTINEL:
286
- retry_on_error = []
724
+ retry_on_errors_list = []
725
+ else:
726
+ retry_on_errors_list = list(retry_on_error)
287
727
  if retry_on_timeout:
288
728
  # Add TimeoutError to the errors list to retry on
289
- retry_on_error.append(TimeoutError)
290
- self.retry_on_error = retry_on_error
291
- if retry or retry_on_error:
729
+ retry_on_errors_list.append(TimeoutError)
730
+ self.retry_on_error = retry_on_errors_list
731
+ if retry or self.retry_on_error:
292
732
  if retry is None:
293
733
  self.retry = Retry(NoBackoff(), 1)
294
734
  else:
295
735
  # deep-copy the Retry object as it is mutable
296
736
  self.retry = copy.deepcopy(retry)
297
- # Update the retry's supported errors with the specified errors
298
- self.retry.update_supported_errors(retry_on_error)
737
+ if self.retry_on_error:
738
+ # Update the retry's supported errors with the specified errors
739
+ self.retry.update_supported_errors(self.retry_on_error)
299
740
  else:
300
741
  self.retry = Retry(NoBackoff(), 0)
301
742
  self.health_check_interval = health_check_interval
@@ -305,7 +746,6 @@ class AbstractConnection(ConnectionInterface):
305
746
  self.handshake_metadata = None
306
747
  self._sock = None
307
748
  self._socket_read_size = socket_read_size
308
- self.set_parser(parser_class)
309
749
  self._connect_callbacks = []
310
750
  self._buffer_cutoff = 6000
311
751
  self._re_auth_token: Optional[TokenInterface] = None
@@ -320,7 +760,30 @@ class AbstractConnection(ConnectionInterface):
320
760
  raise ConnectionError("protocol must be either 2 or 3")
321
761
  # p = DEFAULT_RESP_VERSION
322
762
  self.protocol = p
763
+ if self.protocol == 3 and parser_class == _RESP2Parser:
764
+ # If the protocol is 3 but the parser is RESP2, change it to RESP3
765
+ # This is needed because the parser might be set before the protocol
766
+ # or might be provided as a kwarg to the constructor
767
+ # We need to react on discrepancy only for RESP2 and RESP3
768
+ # as hiredis supports both
769
+ parser_class = _RESP3Parser
770
+ self.set_parser(parser_class)
771
+
323
772
  self._command_packer = self._construct_command_packer(command_packer)
773
+ self._should_reconnect = False
774
+
775
+ # Set up maintenance notifications
776
+ MaintNotificationsAbstractConnection.__init__(
777
+ self,
778
+ maint_notifications_config,
779
+ maint_notifications_pool_handler,
780
+ maintenance_state,
781
+ maintenance_notification_hash,
782
+ orig_host_address,
783
+ orig_socket_timeout,
784
+ orig_socket_connect_timeout,
785
+ self._parser,
786
+ )
324
787
 
325
788
  def __repr__(self):
326
789
  repr_args = ",".join([f"{k}={v}" for k, v in self.repr_pieces()])
@@ -375,6 +838,9 @@ class AbstractConnection(ConnectionInterface):
375
838
  """
376
839
  self._parser = parser_class(socket_read_size=self._socket_read_size)
377
840
 
841
+ def _get_parser(self) -> Union[_HiredisParser, _RESP3Parser, _RESP2Parser]:
842
+ return self._parser
843
+
378
844
  def connect(self):
379
845
  "Connects to the Redis server if not already connected"
380
846
  self.connect_check_health(check_health=True)
@@ -499,6 +965,11 @@ class AbstractConnection(ConnectionInterface):
499
965
  ):
500
966
  raise ConnectionError("Invalid RESP version")
501
967
 
968
+ # Activate maintenance notifications for this connection
969
+ # if enabled in the configuration
970
+ # This is a no-op if maintenance notifications are not enabled
971
+ self.activate_maint_notifications_handling_if_enabled(check_health=check_health)
972
+
502
973
  # if a client_name is given, set it
503
974
  if self.client_name:
504
975
  self.send_command(
@@ -549,6 +1020,8 @@ class AbstractConnection(ConnectionInterface):
549
1020
 
550
1021
  conn_sock = self._sock
551
1022
  self._sock = None
1023
+ # reset the reconnect flag
1024
+ self.reset_should_reconnect()
552
1025
  if conn_sock is None:
553
1026
  return
554
1027
 
@@ -563,6 +1036,15 @@ class AbstractConnection(ConnectionInterface):
563
1036
  except OSError:
564
1037
  pass
565
1038
 
1039
+ def mark_for_reconnect(self):
1040
+ self._should_reconnect = True
1041
+
1042
+ def should_reconnect(self):
1043
+ return self._should_reconnect
1044
+
1045
+ def reset_should_reconnect(self):
1046
+ self._should_reconnect = False
1047
+
566
1048
  def _send_ping(self):
567
1049
  """Send PING, expect PONG in return"""
568
1050
  self.send_command("PING", check_health=False)
@@ -626,6 +1108,7 @@ class AbstractConnection(ConnectionInterface):
626
1108
 
627
1109
  try:
628
1110
  return self._parser.can_read(timeout)
1111
+
629
1112
  except OSError as e:
630
1113
  self.disconnect()
631
1114
  raise ConnectionError(f"Error while reading from {host_error}: {e.args}")
@@ -732,6 +1215,25 @@ class AbstractConnection(ConnectionInterface):
732
1215
  self.read_response()
733
1216
  self._re_auth_token = None
734
1217
 
1218
+ def _get_socket(self) -> Optional[socket.socket]:
1219
+ return self._sock
1220
+
1221
+ @property
1222
+ def socket_timeout(self) -> Optional[Union[float, int]]:
1223
+ return self._socket_timeout
1224
+
1225
+ @socket_timeout.setter
1226
+ def socket_timeout(self, value: Optional[Union[float, int]]):
1227
+ self._socket_timeout = value
1228
+
1229
+ @property
1230
+ def socket_connect_timeout(self) -> Optional[Union[float, int]]:
1231
+ return self._socket_connect_timeout
1232
+
1233
+ @socket_connect_timeout.setter
1234
+ def socket_connect_timeout(self, value: Optional[Union[float, int]]):
1235
+ self._socket_connect_timeout = value
1236
+
735
1237
 
736
1238
  class Connection(AbstractConnection):
737
1239
  "Manages TCP communication to and from a Redis server"
@@ -745,7 +1247,7 @@ class Connection(AbstractConnection):
745
1247
  socket_type=0,
746
1248
  **kwargs,
747
1249
  ):
748
- self.host = host
1250
+ self._host = host
749
1251
  self.port = int(port)
750
1252
  self.socket_keepalive = socket_keepalive
751
1253
  self.socket_keepalive_options = socket_keepalive_options or {}
@@ -764,6 +1266,7 @@ class Connection(AbstractConnection):
764
1266
  # ipv4/ipv6, but we want to set options prior to calling
765
1267
  # socket.connect()
766
1268
  err = None
1269
+
767
1270
  for res in socket.getaddrinfo(
768
1271
  self.host, self.port, self.socket_type, socket.SOCK_STREAM
769
1272
  ):
@@ -806,8 +1309,16 @@ class Connection(AbstractConnection):
806
1309
  def _host_error(self):
807
1310
  return f"{self.host}:{self.port}"
808
1311
 
1312
+ @property
1313
+ def host(self) -> str:
1314
+ return self._host
1315
+
1316
+ @host.setter
1317
+ def host(self, value: str):
1318
+ self._host = value
1319
+
809
1320
 
810
- class CacheProxyConnection(ConnectionInterface):
1321
+ class CacheProxyConnection(MaintNotificationsAbstractConnection, ConnectionInterface):
811
1322
  DUMMY_CACHE_VALUE = b"foo"
812
1323
  MIN_ALLOWED_VERSION = "7.4.0"
813
1324
  DEFAULT_SERVER_NAME = "redis"
@@ -831,6 +1342,19 @@ class CacheProxyConnection(ConnectionInterface):
831
1342
  self._current_options = None
832
1343
  self.register_connect_callback(self._enable_tracking_callback)
833
1344
 
1345
+ if isinstance(self._conn, MaintNotificationsAbstractConnection):
1346
+ MaintNotificationsAbstractConnection.__init__(
1347
+ self,
1348
+ self._conn.maint_notifications_config,
1349
+ self._conn._maint_notifications_pool_handler,
1350
+ self._conn.maintenance_state,
1351
+ self._conn.maintenance_notification_hash,
1352
+ self._conn.host,
1353
+ self._conn.socket_timeout,
1354
+ self._conn.socket_connect_timeout,
1355
+ self._conn._get_parser(),
1356
+ )
1357
+
834
1358
  def repr_pieces(self):
835
1359
  return self._conn.repr_pieces()
836
1360
 
@@ -843,6 +1367,17 @@ class CacheProxyConnection(ConnectionInterface):
843
1367
  def set_parser(self, parser_class):
844
1368
  self._conn.set_parser(parser_class)
845
1369
 
1370
+ def set_maint_notifications_pool_handler_for_connection(
1371
+ self, maint_notifications_pool_handler
1372
+ ):
1373
+ if isinstance(self._conn, MaintNotificationsAbstractConnection):
1374
+ self._conn.set_maint_notifications_pool_handler_for_connection(
1375
+ maint_notifications_pool_handler
1376
+ )
1377
+
1378
+ def get_protocol(self):
1379
+ return self._conn.get_protocol()
1380
+
846
1381
  def connect(self):
847
1382
  self._conn.connect()
848
1383
 
@@ -988,6 +1523,109 @@ class CacheProxyConnection(ConnectionInterface):
988
1523
  def handshake_metadata(self) -> Union[Dict[bytes, bytes], Dict[str, str]]:
989
1524
  return self._conn.handshake_metadata
990
1525
 
1526
+ def set_re_auth_token(self, token: TokenInterface):
1527
+ self._conn.set_re_auth_token(token)
1528
+
1529
+ def re_auth(self):
1530
+ self._conn.re_auth()
1531
+
1532
+ def mark_for_reconnect(self):
1533
+ self._conn.mark_for_reconnect()
1534
+
1535
+ def should_reconnect(self):
1536
+ return self._conn.should_reconnect()
1537
+
1538
+ def reset_should_reconnect(self):
1539
+ self._conn.reset_should_reconnect()
1540
+
1541
+ @property
1542
+ def host(self) -> str:
1543
+ return self._conn.host
1544
+
1545
+ @host.setter
1546
+ def host(self, value: str):
1547
+ self._conn.host = value
1548
+
1549
+ @property
1550
+ def socket_timeout(self) -> Optional[Union[float, int]]:
1551
+ return self._conn.socket_timeout
1552
+
1553
+ @socket_timeout.setter
1554
+ def socket_timeout(self, value: Optional[Union[float, int]]):
1555
+ self._conn.socket_timeout = value
1556
+
1557
+ @property
1558
+ def socket_connect_timeout(self) -> Optional[Union[float, int]]:
1559
+ return self._conn.socket_connect_timeout
1560
+
1561
+ @socket_connect_timeout.setter
1562
+ def socket_connect_timeout(self, value: Optional[Union[float, int]]):
1563
+ self._conn.socket_connect_timeout = value
1564
+
1565
+ def _get_socket(self) -> Optional[socket.socket]:
1566
+ if isinstance(self._conn, MaintNotificationsAbstractConnection):
1567
+ return self._conn._get_socket()
1568
+ else:
1569
+ raise NotImplementedError(
1570
+ "Maintenance notifications are not supported by this connection type"
1571
+ )
1572
+
1573
+ def _get_maint_notifications_connection_instance(
1574
+ self, connection
1575
+ ) -> MaintNotificationsAbstractConnection:
1576
+ """
1577
+ Validate that connection instance supports maintenance notifications.
1578
+ With this helper method we ensure that we are working
1579
+ with the correct connection type.
1580
+ After twe validate that connection instance supports maintenance notifications
1581
+ we can safely return the connection instance
1582
+ as MaintNotificationsAbstractConnection.
1583
+ """
1584
+ if not isinstance(connection, MaintNotificationsAbstractConnection):
1585
+ raise NotImplementedError(
1586
+ "Maintenance notifications are not supported by this connection type"
1587
+ )
1588
+ else:
1589
+ return connection
1590
+
1591
+ @property
1592
+ def maintenance_state(self) -> MaintenanceState:
1593
+ con = self._get_maint_notifications_connection_instance(self._conn)
1594
+ return con.maintenance_state
1595
+
1596
+ @maintenance_state.setter
1597
+ def maintenance_state(self, state: MaintenanceState):
1598
+ con = self._get_maint_notifications_connection_instance(self._conn)
1599
+ con.maintenance_state = state
1600
+
1601
+ def getpeername(self):
1602
+ con = self._get_maint_notifications_connection_instance(self._conn)
1603
+ return con.getpeername()
1604
+
1605
+ def get_resolved_ip(self):
1606
+ con = self._get_maint_notifications_connection_instance(self._conn)
1607
+ return con.get_resolved_ip()
1608
+
1609
+ def update_current_socket_timeout(self, relaxed_timeout: Optional[float] = None):
1610
+ con = self._get_maint_notifications_connection_instance(self._conn)
1611
+ con.update_current_socket_timeout(relaxed_timeout)
1612
+
1613
+ def set_tmp_settings(
1614
+ self,
1615
+ tmp_host_address: Optional[str] = None,
1616
+ tmp_relaxed_timeout: Optional[float] = None,
1617
+ ):
1618
+ con = self._get_maint_notifications_connection_instance(self._conn)
1619
+ con.set_tmp_settings(tmp_host_address, tmp_relaxed_timeout)
1620
+
1621
+ def reset_tmp_settings(
1622
+ self,
1623
+ reset_host_address: bool = False,
1624
+ reset_relaxed_timeout: bool = False,
1625
+ ):
1626
+ con = self._get_maint_notifications_connection_instance(self._conn)
1627
+ con.reset_tmp_settings(reset_host_address, reset_relaxed_timeout)
1628
+
991
1629
  def _connect(self):
992
1630
  self._conn._connect()
993
1631
 
@@ -1011,15 +1649,6 @@ class CacheProxyConnection(ConnectionInterface):
1011
1649
  else:
1012
1650
  self._cache.delete_by_redis_keys(data[1])
1013
1651
 
1014
- def get_protocol(self):
1015
- return self._conn.get_protocol()
1016
-
1017
- def set_re_auth_token(self, token: TokenInterface):
1018
- self._conn.set_re_auth_token(token)
1019
-
1020
- def re_auth(self):
1021
- self._conn.re_auth()
1022
-
1023
1652
 
1024
1653
  class SSLConnection(Connection):
1025
1654
  """Manages SSL connections to and from the Redis server(s).
@@ -1032,6 +1661,8 @@ class SSLConnection(Connection):
1032
1661
  ssl_keyfile=None,
1033
1662
  ssl_certfile=None,
1034
1663
  ssl_cert_reqs="required",
1664
+ ssl_include_verify_flags: Optional[List["VerifyFlags"]] = None,
1665
+ ssl_exclude_verify_flags: Optional[List["VerifyFlags"]] = None,
1035
1666
  ssl_ca_certs=None,
1036
1667
  ssl_ca_data=None,
1037
1668
  ssl_check_hostname=True,
@@ -1050,10 +1681,13 @@ class SSLConnection(Connection):
1050
1681
  Args:
1051
1682
  ssl_keyfile: Path to an ssl private key. Defaults to None.
1052
1683
  ssl_certfile: Path to an ssl certificate. Defaults to None.
1053
- ssl_cert_reqs: The string value for the SSLContext.verify_mode (none, optional, required), or an ssl.VerifyMode. Defaults to "required".
1684
+ ssl_cert_reqs: The string value for the SSLContext.verify_mode (none, optional, required),
1685
+ or an ssl.VerifyMode. Defaults to "required".
1686
+ ssl_include_verify_flags: A list of flags to be included in the SSLContext.verify_flags. Defaults to None.
1687
+ ssl_exclude_verify_flags: A list of flags to be excluded from the SSLContext.verify_flags. Defaults to None.
1054
1688
  ssl_ca_certs: The path to a file of concatenated CA certificates in PEM format. Defaults to None.
1055
1689
  ssl_ca_data: Either an ASCII string of one or more PEM-encoded certificates or a bytes-like object of DER-encoded certificates.
1056
- ssl_check_hostname: If set, match the hostname during the SSL handshake. Defaults to False.
1690
+ ssl_check_hostname: If set, match the hostname during the SSL handshake. Defaults to True.
1057
1691
  ssl_ca_path: The path to a directory containing several CA certificates in PEM format. Defaults to None.
1058
1692
  ssl_password: Password for unlocking an encrypted private key. Defaults to None.
1059
1693
 
@@ -1086,6 +1720,8 @@ class SSLConnection(Connection):
1086
1720
  )
1087
1721
  ssl_cert_reqs = CERT_REQS[ssl_cert_reqs]
1088
1722
  self.cert_reqs = ssl_cert_reqs
1723
+ self.ssl_include_verify_flags = ssl_include_verify_flags
1724
+ self.ssl_exclude_verify_flags = ssl_exclude_verify_flags
1089
1725
  self.ca_certs = ssl_ca_certs
1090
1726
  self.ca_data = ssl_ca_data
1091
1727
  self.ca_path = ssl_ca_path
@@ -1125,6 +1761,12 @@ class SSLConnection(Connection):
1125
1761
  context = ssl.create_default_context()
1126
1762
  context.check_hostname = self.check_hostname
1127
1763
  context.verify_mode = self.cert_reqs
1764
+ if self.ssl_include_verify_flags:
1765
+ for flag in self.ssl_include_verify_flags:
1766
+ context.verify_flags |= flag
1767
+ if self.ssl_exclude_verify_flags:
1768
+ for flag in self.ssl_exclude_verify_flags:
1769
+ context.verify_flags &= ~flag
1128
1770
  if self.certfile or self.keyfile:
1129
1771
  context.load_cert_chain(
1130
1772
  certfile=self.certfile,
@@ -1238,6 +1880,20 @@ def to_bool(value):
1238
1880
  return bool(value)
1239
1881
 
1240
1882
 
1883
+ def parse_ssl_verify_flags(value):
1884
+ # flags are passed in as a string representation of a list,
1885
+ # e.g. VERIFY_X509_STRICT, VERIFY_X509_PARTIAL_CHAIN
1886
+ verify_flags_str = value.replace("[", "").replace("]", "")
1887
+
1888
+ verify_flags = []
1889
+ for flag in verify_flags_str.split(","):
1890
+ flag = flag.strip()
1891
+ if not hasattr(VerifyFlags, flag):
1892
+ raise ValueError(f"Invalid ssl verify flag: {flag}")
1893
+ verify_flags.append(getattr(VerifyFlags, flag))
1894
+ return verify_flags
1895
+
1896
+
1241
1897
  URL_QUERY_ARGUMENT_PARSERS = {
1242
1898
  "db": int,
1243
1899
  "socket_timeout": float,
@@ -1248,6 +1904,8 @@ URL_QUERY_ARGUMENT_PARSERS = {
1248
1904
  "max_connections": int,
1249
1905
  "health_check_interval": int,
1250
1906
  "ssl_check_hostname": to_bool,
1907
+ "ssl_include_verify_flags": parse_ssl_verify_flags,
1908
+ "ssl_exclude_verify_flags": parse_ssl_verify_flags,
1251
1909
  "timeout": float,
1252
1910
  }
1253
1911
 
@@ -1312,7 +1970,396 @@ def parse_url(url):
1312
1970
  _CP = TypeVar("_CP", bound="ConnectionPool")
1313
1971
 
1314
1972
 
1315
- class ConnectionPool:
1973
+ class ConnectionPoolInterface(ABC):
1974
+ @abstractmethod
1975
+ def get_protocol(self):
1976
+ pass
1977
+
1978
+ @abstractmethod
1979
+ def reset(self):
1980
+ pass
1981
+
1982
+ @abstractmethod
1983
+ @deprecated_args(
1984
+ args_to_warn=["*"],
1985
+ reason="Use get_connection() without args instead",
1986
+ version="5.3.0",
1987
+ )
1988
+ def get_connection(
1989
+ self, command_name: Optional[str], *keys, **options
1990
+ ) -> ConnectionInterface:
1991
+ pass
1992
+
1993
+ @abstractmethod
1994
+ def get_encoder(self):
1995
+ pass
1996
+
1997
+ @abstractmethod
1998
+ def release(self, connection: ConnectionInterface):
1999
+ pass
2000
+
2001
+ @abstractmethod
2002
+ def disconnect(self, inuse_connections: bool = True):
2003
+ pass
2004
+
2005
+ @abstractmethod
2006
+ def close(self):
2007
+ pass
2008
+
2009
+ @abstractmethod
2010
+ def set_retry(self, retry: Retry):
2011
+ pass
2012
+
2013
+ @abstractmethod
2014
+ def re_auth_callback(self, token: TokenInterface):
2015
+ pass
2016
+
2017
+
2018
+ class MaintNotificationsAbstractConnectionPool:
2019
+ """
2020
+ Abstract class for handling maintenance notifications logic.
2021
+ This class is mixed into the ConnectionPool classes.
2022
+
2023
+ This class is not intended to be used directly!
2024
+
2025
+ All logic related to maintenance notifications and
2026
+ connection pool handling is encapsulated in this class.
2027
+ """
2028
+
2029
+ def __init__(
2030
+ self,
2031
+ maint_notifications_config: Optional[MaintNotificationsConfig] = None,
2032
+ **kwargs,
2033
+ ):
2034
+ # Initialize maintenance notifications
2035
+ is_protocol_supported = kwargs.get("protocol") in [3, "3"]
2036
+ if maint_notifications_config is None and is_protocol_supported:
2037
+ maint_notifications_config = MaintNotificationsConfig()
2038
+
2039
+ if maint_notifications_config and maint_notifications_config.enabled:
2040
+ if not is_protocol_supported:
2041
+ raise RedisError(
2042
+ "Maintenance notifications handlers on connection are only supported with RESP version 3"
2043
+ )
2044
+
2045
+ self._maint_notifications_pool_handler = MaintNotificationsPoolHandler(
2046
+ self, maint_notifications_config
2047
+ )
2048
+
2049
+ self._update_connection_kwargs_for_maint_notifications(
2050
+ self._maint_notifications_pool_handler
2051
+ )
2052
+ else:
2053
+ self._maint_notifications_pool_handler = None
2054
+
2055
+ @property
2056
+ @abstractmethod
2057
+ def connection_kwargs(self) -> Dict[str, Any]:
2058
+ pass
2059
+
2060
+ @connection_kwargs.setter
2061
+ @abstractmethod
2062
+ def connection_kwargs(self, value: Dict[str, Any]):
2063
+ pass
2064
+
2065
+ @abstractmethod
2066
+ def _get_pool_lock(self) -> threading.RLock:
2067
+ pass
2068
+
2069
+ @abstractmethod
2070
+ def _get_free_connections(self) -> Iterable["MaintNotificationsAbstractConnection"]:
2071
+ pass
2072
+
2073
+ @abstractmethod
2074
+ def _get_in_use_connections(
2075
+ self,
2076
+ ) -> Iterable["MaintNotificationsAbstractConnection"]:
2077
+ pass
2078
+
2079
+ def maint_notifications_enabled(self):
2080
+ """
2081
+ Returns:
2082
+ True if the maintenance notifications are enabled, False otherwise.
2083
+ The maintenance notifications config is stored in the pool handler.
2084
+ If the pool handler is not set, the maintenance notifications are not enabled.
2085
+ """
2086
+ maint_notifications_config = (
2087
+ self._maint_notifications_pool_handler.config
2088
+ if self._maint_notifications_pool_handler
2089
+ else None
2090
+ )
2091
+
2092
+ return maint_notifications_config and maint_notifications_config.enabled
2093
+
2094
+ def update_maint_notifications_config(
2095
+ self, maint_notifications_config: MaintNotificationsConfig
2096
+ ):
2097
+ """
2098
+ Updates the maintenance notifications configuration.
2099
+ This method should be called only if the pool was created
2100
+ without enabling the maintenance notifications and
2101
+ in a later point in time maintenance notifications
2102
+ are requested to be enabled.
2103
+ """
2104
+ if (
2105
+ self.maint_notifications_enabled()
2106
+ and not maint_notifications_config.enabled
2107
+ ):
2108
+ raise ValueError(
2109
+ "Cannot disable maintenance notifications after enabling them"
2110
+ )
2111
+ # first update pool settings
2112
+ if not self._maint_notifications_pool_handler:
2113
+ self._maint_notifications_pool_handler = MaintNotificationsPoolHandler(
2114
+ self, maint_notifications_config
2115
+ )
2116
+ else:
2117
+ self._maint_notifications_pool_handler.config = maint_notifications_config
2118
+
2119
+ # then update connection kwargs and existing connections
2120
+ self._update_connection_kwargs_for_maint_notifications(
2121
+ self._maint_notifications_pool_handler
2122
+ )
2123
+ self._update_maint_notifications_configs_for_connections(
2124
+ self._maint_notifications_pool_handler
2125
+ )
2126
+
2127
+ def _update_connection_kwargs_for_maint_notifications(
2128
+ self, maint_notifications_pool_handler: MaintNotificationsPoolHandler
2129
+ ):
2130
+ """
2131
+ Update the connection kwargs for all future connections.
2132
+ """
2133
+ if not self.maint_notifications_enabled():
2134
+ return
2135
+
2136
+ self.connection_kwargs.update(
2137
+ {
2138
+ "maint_notifications_pool_handler": maint_notifications_pool_handler,
2139
+ "maint_notifications_config": maint_notifications_pool_handler.config,
2140
+ }
2141
+ )
2142
+
2143
+ # Store original connection parameters for maintenance notifications.
2144
+ if self.connection_kwargs.get("orig_host_address", None) is None:
2145
+ # If orig_host_address is None it means we haven't
2146
+ # configured the original values yet
2147
+ self.connection_kwargs.update(
2148
+ {
2149
+ "orig_host_address": self.connection_kwargs.get("host"),
2150
+ "orig_socket_timeout": self.connection_kwargs.get(
2151
+ "socket_timeout", None
2152
+ ),
2153
+ "orig_socket_connect_timeout": self.connection_kwargs.get(
2154
+ "socket_connect_timeout", None
2155
+ ),
2156
+ }
2157
+ )
2158
+
2159
+ def _update_maint_notifications_configs_for_connections(
2160
+ self, maint_notifications_pool_handler: MaintNotificationsPoolHandler
2161
+ ):
2162
+ """Update the maintenance notifications config for all connections in the pool."""
2163
+ with self._get_pool_lock():
2164
+ for conn in self._get_free_connections():
2165
+ conn.set_maint_notifications_pool_handler_for_connection(
2166
+ maint_notifications_pool_handler
2167
+ )
2168
+ conn.maint_notifications_config = (
2169
+ maint_notifications_pool_handler.config
2170
+ )
2171
+ conn.disconnect()
2172
+ for conn in self._get_in_use_connections():
2173
+ conn.set_maint_notifications_pool_handler_for_connection(
2174
+ maint_notifications_pool_handler
2175
+ )
2176
+ conn.maint_notifications_config = (
2177
+ maint_notifications_pool_handler.config
2178
+ )
2179
+ conn.mark_for_reconnect()
2180
+
2181
+ def _should_update_connection(
2182
+ self,
2183
+ conn: "MaintNotificationsAbstractConnection",
2184
+ matching_pattern: Literal[
2185
+ "connected_address", "configured_address", "notification_hash"
2186
+ ] = "connected_address",
2187
+ matching_address: Optional[str] = None,
2188
+ matching_notification_hash: Optional[int] = None,
2189
+ ) -> bool:
2190
+ """
2191
+ Check if the connection should be updated based on the matching criteria.
2192
+ """
2193
+ if matching_pattern == "connected_address":
2194
+ if matching_address and conn.getpeername() != matching_address:
2195
+ return False
2196
+ elif matching_pattern == "configured_address":
2197
+ if matching_address and conn.host != matching_address:
2198
+ return False
2199
+ elif matching_pattern == "notification_hash":
2200
+ if (
2201
+ matching_notification_hash
2202
+ and conn.maintenance_notification_hash != matching_notification_hash
2203
+ ):
2204
+ return False
2205
+ return True
2206
+
2207
+ def update_connection_settings(
2208
+ self,
2209
+ conn: "MaintNotificationsAbstractConnection",
2210
+ state: Optional["MaintenanceState"] = None,
2211
+ maintenance_notification_hash: Optional[int] = None,
2212
+ host_address: Optional[str] = None,
2213
+ relaxed_timeout: Optional[float] = None,
2214
+ update_notification_hash: bool = False,
2215
+ reset_host_address: bool = False,
2216
+ reset_relaxed_timeout: bool = False,
2217
+ ):
2218
+ """
2219
+ Update the settings for a single connection.
2220
+ """
2221
+ if state:
2222
+ conn.maintenance_state = state
2223
+
2224
+ if update_notification_hash:
2225
+ # update the notification hash only if requested
2226
+ conn.maintenance_notification_hash = maintenance_notification_hash
2227
+
2228
+ if host_address is not None:
2229
+ conn.set_tmp_settings(tmp_host_address=host_address)
2230
+
2231
+ if relaxed_timeout is not None:
2232
+ conn.set_tmp_settings(tmp_relaxed_timeout=relaxed_timeout)
2233
+
2234
+ if reset_relaxed_timeout or reset_host_address:
2235
+ conn.reset_tmp_settings(
2236
+ reset_host_address=reset_host_address,
2237
+ reset_relaxed_timeout=reset_relaxed_timeout,
2238
+ )
2239
+
2240
+ conn.update_current_socket_timeout(relaxed_timeout)
2241
+
2242
+ def update_connections_settings(
2243
+ self,
2244
+ state: Optional["MaintenanceState"] = None,
2245
+ maintenance_notification_hash: Optional[int] = None,
2246
+ host_address: Optional[str] = None,
2247
+ relaxed_timeout: Optional[float] = None,
2248
+ matching_address: Optional[str] = None,
2249
+ matching_notification_hash: Optional[int] = None,
2250
+ matching_pattern: Literal[
2251
+ "connected_address", "configured_address", "notification_hash"
2252
+ ] = "connected_address",
2253
+ update_notification_hash: bool = False,
2254
+ reset_host_address: bool = False,
2255
+ reset_relaxed_timeout: bool = False,
2256
+ include_free_connections: bool = True,
2257
+ ):
2258
+ """
2259
+ Update the settings for all matching connections in the pool.
2260
+
2261
+ This method does not create new connections.
2262
+ This method does not affect the connection kwargs.
2263
+
2264
+ :param state: The maintenance state to set for the connection.
2265
+ :param maintenance_notification_hash: The hash of the maintenance notification
2266
+ to set for the connection.
2267
+ :param host_address: The host address to set for the connection.
2268
+ :param relaxed_timeout: The relaxed timeout to set for the connection.
2269
+ :param matching_address: The address to match for the connection.
2270
+ :param matching_notification_hash: The notification hash to match for the connection.
2271
+ :param matching_pattern: The pattern to match for the connection.
2272
+ :param update_notification_hash: Whether to update the notification hash for the connection.
2273
+ :param reset_host_address: Whether to reset the host address to the original address.
2274
+ :param reset_relaxed_timeout: Whether to reset the relaxed timeout to the original timeout.
2275
+ :param include_free_connections: Whether to include free/available connections.
2276
+ """
2277
+ with self._get_pool_lock():
2278
+ for conn in self._get_in_use_connections():
2279
+ if self._should_update_connection(
2280
+ conn,
2281
+ matching_pattern,
2282
+ matching_address,
2283
+ matching_notification_hash,
2284
+ ):
2285
+ self.update_connection_settings(
2286
+ conn,
2287
+ state=state,
2288
+ maintenance_notification_hash=maintenance_notification_hash,
2289
+ host_address=host_address,
2290
+ relaxed_timeout=relaxed_timeout,
2291
+ update_notification_hash=update_notification_hash,
2292
+ reset_host_address=reset_host_address,
2293
+ reset_relaxed_timeout=reset_relaxed_timeout,
2294
+ )
2295
+
2296
+ if include_free_connections:
2297
+ for conn in self._get_free_connections():
2298
+ if self._should_update_connection(
2299
+ conn,
2300
+ matching_pattern,
2301
+ matching_address,
2302
+ matching_notification_hash,
2303
+ ):
2304
+ self.update_connection_settings(
2305
+ conn,
2306
+ state=state,
2307
+ maintenance_notification_hash=maintenance_notification_hash,
2308
+ host_address=host_address,
2309
+ relaxed_timeout=relaxed_timeout,
2310
+ update_notification_hash=update_notification_hash,
2311
+ reset_host_address=reset_host_address,
2312
+ reset_relaxed_timeout=reset_relaxed_timeout,
2313
+ )
2314
+
2315
+ def update_connection_kwargs(
2316
+ self,
2317
+ **kwargs,
2318
+ ):
2319
+ """
2320
+ Update the connection kwargs for all future connections.
2321
+
2322
+ This method updates the connection kwargs for all future connections created by the pool.
2323
+ Existing connections are not affected.
2324
+ """
2325
+ self.connection_kwargs.update(kwargs)
2326
+
2327
+ def update_active_connections_for_reconnect(
2328
+ self,
2329
+ moving_address_src: Optional[str] = None,
2330
+ ):
2331
+ """
2332
+ Mark all active connections for reconnect.
2333
+ This is used when a cluster node is migrated to a different address.
2334
+
2335
+ :param moving_address_src: The address of the node that is being moved.
2336
+ """
2337
+ with self._get_pool_lock():
2338
+ for conn in self._get_in_use_connections():
2339
+ if self._should_update_connection(
2340
+ conn, "connected_address", moving_address_src
2341
+ ):
2342
+ conn.mark_for_reconnect()
2343
+
2344
+ def disconnect_free_connections(
2345
+ self,
2346
+ moving_address_src: Optional[str] = None,
2347
+ ):
2348
+ """
2349
+ Disconnect all free/available connections.
2350
+ This is used when a cluster node is migrated to a different address.
2351
+
2352
+ :param moving_address_src: The address of the node that is being moved.
2353
+ """
2354
+ with self._get_pool_lock():
2355
+ for conn in self._get_free_connections():
2356
+ if self._should_update_connection(
2357
+ conn, "connected_address", moving_address_src
2358
+ ):
2359
+ conn.disconnect()
2360
+
2361
+
2362
+ class ConnectionPool(MaintNotificationsAbstractConnectionPool, ConnectionPoolInterface):
1316
2363
  """
1317
2364
  Create a connection pool. ``If max_connections`` is set, then this
1318
2365
  object raises :py:class:`~redis.exceptions.ConnectionError` when the pool's
@@ -1323,6 +2370,12 @@ class ConnectionPool:
1323
2370
  unix sockets.
1324
2371
  :py:class:`~redis.SSLConnection` can be used for SSL enabled connections.
1325
2372
 
2373
+ If ``maint_notifications_config`` is provided, the connection pool will support
2374
+ maintenance notifications.
2375
+ Maintenance notifications are supported only with RESP3.
2376
+ If the ``maint_notifications_config`` is not provided but the ``protocol`` is 3,
2377
+ the maintenance notifications will be enabled by default.
2378
+
1326
2379
  Any additional keyword arguments are passed to the constructor of
1327
2380
  ``connection_class``.
1328
2381
  """
@@ -1381,6 +2434,7 @@ class ConnectionPool:
1381
2434
  connection_class=Connection,
1382
2435
  max_connections: Optional[int] = None,
1383
2436
  cache_factory: Optional[CacheFactoryInterface] = None,
2437
+ maint_notifications_config: Optional[MaintNotificationsConfig] = None,
1384
2438
  **connection_kwargs,
1385
2439
  ):
1386
2440
  max_connections = max_connections or 2**31
@@ -1388,16 +2442,16 @@ class ConnectionPool:
1388
2442
  raise ValueError('"max_connections" must be a positive integer')
1389
2443
 
1390
2444
  self.connection_class = connection_class
1391
- self.connection_kwargs = connection_kwargs
2445
+ self._connection_kwargs = connection_kwargs
1392
2446
  self.max_connections = max_connections
1393
2447
  self.cache = None
1394
2448
  self._cache_factory = cache_factory
1395
2449
 
1396
2450
  if connection_kwargs.get("cache_config") or connection_kwargs.get("cache"):
1397
- if connection_kwargs.get("protocol") not in [3, "3"]:
2451
+ if self._connection_kwargs.get("protocol") not in [3, "3"]:
1398
2452
  raise RedisError("Client caching is only supported with RESP version 3")
1399
2453
 
1400
- cache = self.connection_kwargs.get("cache")
2454
+ cache = self._connection_kwargs.get("cache")
1401
2455
 
1402
2456
  if cache is not None:
1403
2457
  if not isinstance(cache, CacheInterface):
@@ -1409,13 +2463,13 @@ class ConnectionPool:
1409
2463
  self.cache = self._cache_factory.get_cache()
1410
2464
  else:
1411
2465
  self.cache = CacheFactory(
1412
- self.connection_kwargs.get("cache_config")
2466
+ self._connection_kwargs.get("cache_config")
1413
2467
  ).get_cache()
1414
2468
 
1415
2469
  connection_kwargs.pop("cache", None)
1416
2470
  connection_kwargs.pop("cache_config", None)
1417
2471
 
1418
- self._event_dispatcher = self.connection_kwargs.get("event_dispatcher", None)
2472
+ self._event_dispatcher = self._connection_kwargs.get("event_dispatcher", None)
1419
2473
  if self._event_dispatcher is None:
1420
2474
  self._event_dispatcher = EventDispatcher()
1421
2475
 
@@ -1431,6 +2485,12 @@ class ConnectionPool:
1431
2485
  self._fork_lock = threading.RLock()
1432
2486
  self._lock = threading.RLock()
1433
2487
 
2488
+ MaintNotificationsAbstractConnectionPool.__init__(
2489
+ self,
2490
+ maint_notifications_config=maint_notifications_config,
2491
+ **connection_kwargs,
2492
+ )
2493
+
1434
2494
  self.reset()
1435
2495
 
1436
2496
  def __repr__(self) -> str:
@@ -1441,6 +2501,14 @@ class ConnectionPool:
1441
2501
  f"({conn_kwargs})>)>"
1442
2502
  )
1443
2503
 
2504
+ @property
2505
+ def connection_kwargs(self) -> Dict[str, Any]:
2506
+ return self._connection_kwargs
2507
+
2508
+ @connection_kwargs.setter
2509
+ def connection_kwargs(self, value: Dict[str, Any]):
2510
+ self._connection_kwargs = value
2511
+
1444
2512
  def get_protocol(self):
1445
2513
  """
1446
2514
  Returns:
@@ -1536,7 +2604,11 @@ class ConnectionPool:
1536
2604
  # pool before all data has been read or the socket has been
1537
2605
  # closed. either way, reconnect and verify everything is good.
1538
2606
  try:
1539
- if connection.can_read() and self.cache is None:
2607
+ if (
2608
+ connection.can_read()
2609
+ and self.cache is None
2610
+ and not self.maint_notifications_enabled()
2611
+ ):
1540
2612
  raise ConnectionError("Connection has data")
1541
2613
  except (ConnectionError, TimeoutError, OSError):
1542
2614
  connection.disconnect()
@@ -1548,7 +2620,6 @@ class ConnectionPool:
1548
2620
  # leak it
1549
2621
  self.release(connection)
1550
2622
  raise
1551
-
1552
2623
  return connection
1553
2624
 
1554
2625
  def get_encoder(self) -> Encoder:
@@ -1566,12 +2637,13 @@ class ConnectionPool:
1566
2637
  raise MaxConnectionsError("Too many connections")
1567
2638
  self._created_connections += 1
1568
2639
 
2640
+ kwargs = dict(self.connection_kwargs)
2641
+
1569
2642
  if self.cache is not None:
1570
2643
  return CacheProxyConnection(
1571
- self.connection_class(**self.connection_kwargs), self.cache, self._lock
2644
+ self.connection_class(**kwargs), self.cache, self._lock
1572
2645
  )
1573
-
1574
- return self.connection_class(**self.connection_kwargs)
2646
+ return self.connection_class(**kwargs)
1575
2647
 
1576
2648
  def release(self, connection: "Connection") -> None:
1577
2649
  "Releases the connection back to the pool"
@@ -1585,6 +2657,8 @@ class ConnectionPool:
1585
2657
  return
1586
2658
 
1587
2659
  if self.owns_connection(connection):
2660
+ if connection.should_reconnect():
2661
+ connection.disconnect()
1588
2662
  self._available_connections.append(connection)
1589
2663
  self._event_dispatcher.dispatch(
1590
2664
  AfterConnectionReleasedEvent(connection)
@@ -1605,7 +2679,7 @@ class ConnectionPool:
1605
2679
  Disconnects connections in the pool
1606
2680
 
1607
2681
  If ``inuse_connections`` is True, disconnect connections that are
1608
- current in use, potentially by other threads. Otherwise only disconnect
2682
+ currently in use, potentially by other threads. Otherwise only disconnect
1609
2683
  connections that are idle in the pool.
1610
2684
  """
1611
2685
  self._checkpid()
@@ -1646,6 +2720,17 @@ class ConnectionPool:
1646
2720
  for conn in self._in_use_connections:
1647
2721
  conn.set_re_auth_token(token)
1648
2722
 
2723
+ def _get_pool_lock(self):
2724
+ return self._lock
2725
+
2726
+ def _get_free_connections(self):
2727
+ with self._lock:
2728
+ return self._available_connections
2729
+
2730
+ def _get_in_use_connections(self):
2731
+ with self._lock:
2732
+ return self._in_use_connections
2733
+
1649
2734
  async def _mock(self, error: RedisError):
1650
2735
  """
1651
2736
  Dummy functions, needs to be passed as error callback to retry object.
@@ -1699,6 +2784,8 @@ class BlockingConnectionPool(ConnectionPool):
1699
2784
  ):
1700
2785
  self.queue_class = queue_class
1701
2786
  self.timeout = timeout
2787
+ self._in_maintenance = False
2788
+ self._locked = False
1702
2789
  super().__init__(
1703
2790
  connection_class=connection_class,
1704
2791
  max_connections=max_connections,
@@ -1707,16 +2794,27 @@ class BlockingConnectionPool(ConnectionPool):
1707
2794
 
1708
2795
  def reset(self):
1709
2796
  # Create and fill up a thread safe queue with ``None`` values.
1710
- self.pool = self.queue_class(self.max_connections)
1711
- while True:
1712
- try:
1713
- self.pool.put_nowait(None)
1714
- except Full:
1715
- break
2797
+ try:
2798
+ if self._in_maintenance:
2799
+ self._lock.acquire()
2800
+ self._locked = True
2801
+ self.pool = self.queue_class(self.max_connections)
2802
+ while True:
2803
+ try:
2804
+ self.pool.put_nowait(None)
2805
+ except Full:
2806
+ break
1716
2807
 
1717
- # Keep a list of actual connection instances so that we can
1718
- # disconnect them later.
1719
- self._connections = []
2808
+ # Keep a list of actual connection instances so that we can
2809
+ # disconnect them later.
2810
+ self._connections = []
2811
+ finally:
2812
+ if self._locked:
2813
+ try:
2814
+ self._lock.release()
2815
+ except Exception:
2816
+ pass
2817
+ self._locked = False
1720
2818
 
1721
2819
  # this must be the last operation in this method. while reset() is
1722
2820
  # called when holding _fork_lock, other threads in this process
@@ -1731,14 +2829,28 @@ class BlockingConnectionPool(ConnectionPool):
1731
2829
 
1732
2830
  def make_connection(self):
1733
2831
  "Make a fresh connection."
1734
- if self.cache is not None:
1735
- connection = CacheProxyConnection(
1736
- self.connection_class(**self.connection_kwargs), self.cache, self._lock
1737
- )
1738
- else:
1739
- connection = self.connection_class(**self.connection_kwargs)
1740
- self._connections.append(connection)
1741
- return connection
2832
+ try:
2833
+ if self._in_maintenance:
2834
+ self._lock.acquire()
2835
+ self._locked = True
2836
+
2837
+ if self.cache is not None:
2838
+ connection = CacheProxyConnection(
2839
+ self.connection_class(**self.connection_kwargs),
2840
+ self.cache,
2841
+ self._lock,
2842
+ )
2843
+ else:
2844
+ connection = self.connection_class(**self.connection_kwargs)
2845
+ self._connections.append(connection)
2846
+ return connection
2847
+ finally:
2848
+ if self._locked:
2849
+ try:
2850
+ self._lock.release()
2851
+ except Exception:
2852
+ pass
2853
+ self._locked = False
1742
2854
 
1743
2855
  @deprecated_args(
1744
2856
  args_to_warn=["*"],
@@ -1764,16 +2876,27 @@ class BlockingConnectionPool(ConnectionPool):
1764
2876
  # self.timeout then raise a ``ConnectionError``.
1765
2877
  connection = None
1766
2878
  try:
1767
- connection = self.pool.get(block=True, timeout=self.timeout)
1768
- except Empty:
1769
- # Note that this is not caught by the redis client and will be
1770
- # raised unless handled by application code. If you want never to
1771
- raise ConnectionError("No connection available.")
1772
-
1773
- # If the ``connection`` is actually ``None`` then that's a cue to make
1774
- # a new connection to add to the pool.
1775
- if connection is None:
1776
- connection = self.make_connection()
2879
+ if self._in_maintenance:
2880
+ self._lock.acquire()
2881
+ self._locked = True
2882
+ try:
2883
+ connection = self.pool.get(block=True, timeout=self.timeout)
2884
+ except Empty:
2885
+ # Note that this is not caught by the redis client and will be
2886
+ # raised unless handled by application code. If you want never to
2887
+ raise ConnectionError("No connection available.")
2888
+
2889
+ # If the ``connection`` is actually ``None`` then that's a cue to make
2890
+ # a new connection to add to the pool.
2891
+ if connection is None:
2892
+ connection = self.make_connection()
2893
+ finally:
2894
+ if self._locked:
2895
+ try:
2896
+ self._lock.release()
2897
+ except Exception:
2898
+ pass
2899
+ self._locked = False
1777
2900
 
1778
2901
  try:
1779
2902
  # ensure this connection is connected to Redis
@@ -1801,25 +2924,76 @@ class BlockingConnectionPool(ConnectionPool):
1801
2924
  "Releases the connection back to the pool."
1802
2925
  # Make sure we haven't changed process.
1803
2926
  self._checkpid()
1804
- if not self.owns_connection(connection):
1805
- # pool doesn't own this connection. do not add it back
1806
- # to the pool. instead add a None value which is a placeholder
1807
- # that will cause the pool to recreate the connection if
1808
- # its needed.
1809
- connection.disconnect()
1810
- self.pool.put_nowait(None)
1811
- return
1812
2927
 
1813
- # Put the connection back into the pool.
1814
2928
  try:
1815
- self.pool.put_nowait(connection)
1816
- except Full:
1817
- # perhaps the pool has been reset() after a fork? regardless,
1818
- # we don't want this connection
1819
- pass
2929
+ if self._in_maintenance:
2930
+ self._lock.acquire()
2931
+ self._locked = True
2932
+ if not self.owns_connection(connection):
2933
+ # pool doesn't own this connection. do not add it back
2934
+ # to the pool. instead add a None value which is a placeholder
2935
+ # that will cause the pool to recreate the connection if
2936
+ # its needed.
2937
+ connection.disconnect()
2938
+ self.pool.put_nowait(None)
2939
+ return
2940
+ if connection.should_reconnect():
2941
+ connection.disconnect()
2942
+ # Put the connection back into the pool.
2943
+ try:
2944
+ self.pool.put_nowait(connection)
2945
+ except Full:
2946
+ # perhaps the pool has been reset() after a fork? regardless,
2947
+ # we don't want this connection
2948
+ pass
2949
+ finally:
2950
+ if self._locked:
2951
+ try:
2952
+ self._lock.release()
2953
+ except Exception:
2954
+ pass
2955
+ self._locked = False
1820
2956
 
1821
- def disconnect(self):
1822
- "Disconnects all connections in the pool."
2957
+ def disconnect(self, inuse_connections: bool = True):
2958
+ "Disconnects either all connections in the pool or just the free connections."
1823
2959
  self._checkpid()
1824
- for connection in self._connections:
1825
- connection.disconnect()
2960
+ try:
2961
+ if self._in_maintenance:
2962
+ self._lock.acquire()
2963
+ self._locked = True
2964
+ if inuse_connections:
2965
+ connections = self._connections
2966
+ else:
2967
+ connections = self._get_free_connections()
2968
+ for connection in connections:
2969
+ connection.disconnect()
2970
+ finally:
2971
+ if self._locked:
2972
+ try:
2973
+ self._lock.release()
2974
+ except Exception:
2975
+ pass
2976
+ self._locked = False
2977
+
2978
+ def _get_free_connections(self):
2979
+ with self._lock:
2980
+ return {conn for conn in self.pool.queue if conn}
2981
+
2982
+ def _get_in_use_connections(self):
2983
+ with self._lock:
2984
+ # free connections
2985
+ connections_in_queue = {conn for conn in self.pool.queue if conn}
2986
+ # in self._connections we keep all created connections
2987
+ # so the ones that are not in the queue are the in use ones
2988
+ return {
2989
+ conn for conn in self._connections if conn not in connections_in_queue
2990
+ }
2991
+
2992
+ def set_in_maintenance(self, in_maintenance: bool):
2993
+ """
2994
+ Sets a flag that this Blocking ConnectionPool is in maintenance mode.
2995
+
2996
+ This is used to prevent new connections from being created while we are in maintenance mode.
2997
+ The pool will be in maintenance mode only when we are processing a MOVING notification.
2998
+ """
2999
+ self._in_maintenance = in_maintenance