redis 7.0.0b3__py3-none-any.whl → 7.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
redis/connection.py CHANGED
@@ -5,7 +5,7 @@ import sys
5
5
  import threading
6
6
  import time
7
7
  import weakref
8
- from abc import abstractmethod
8
+ from abc import ABC, abstractmethod
9
9
  from itertools import chain
10
10
  from queue import Empty, Full, LifoQueue
11
11
  from typing import (
@@ -178,10 +178,6 @@ class ConnectionInterface:
178
178
  def set_parser(self, parser_class):
179
179
  pass
180
180
 
181
- @abstractmethod
182
- def set_maint_notifications_pool_handler(self, maint_notifications_pool_handler):
183
- pass
184
-
185
181
  @abstractmethod
186
182
  def get_protocol(self):
187
183
  pass
@@ -245,29 +241,6 @@ class ConnectionInterface:
245
241
  def re_auth(self):
246
242
  pass
247
243
 
248
- @property
249
- @abstractmethod
250
- def maintenance_state(self) -> MaintenanceState:
251
- """
252
- Returns the current maintenance state of the connection.
253
- """
254
- pass
255
-
256
- @maintenance_state.setter
257
- @abstractmethod
258
- def maintenance_state(self, state: "MaintenanceState"):
259
- """
260
- Sets the current maintenance state of the connection.
261
- """
262
- pass
263
-
264
- @abstractmethod
265
- def getpeername(self):
266
- """
267
- Returns the peer name of the connection.
268
- """
269
- pass
270
-
271
244
  @abstractmethod
272
245
  def mark_for_reconnect(self):
273
246
  """
@@ -284,250 +257,189 @@ class ConnectionInterface:
284
257
  pass
285
258
 
286
259
  @abstractmethod
287
- def get_resolved_ip(self):
288
- """
289
- Get resolved ip address for the connection.
290
- """
291
- pass
292
-
293
- @abstractmethod
294
- def update_current_socket_timeout(self, relaxed_timeout: Optional[float] = None):
260
+ def reset_should_reconnect(self):
295
261
  """
296
- Update the timeout for the current socket.
262
+ Reset the internal flag to False.
297
263
  """
298
264
  pass
299
265
 
300
- @abstractmethod
301
- def set_tmp_settings(
302
- self,
303
- tmp_host_address: Optional[str] = None,
304
- tmp_relaxed_timeout: Optional[float] = None,
305
- ):
306
- """
307
- Updates temporary host address and timeout settings for the connection.
308
- """
309
- pass
310
266
 
311
- @abstractmethod
312
- def reset_tmp_settings(
313
- self,
314
- reset_host_address: bool = False,
315
- reset_relaxed_timeout: bool = False,
316
- ):
317
- """
318
- Resets temporary host address and timeout settings for the connection.
319
- """
320
- pass
267
+ class MaintNotificationsAbstractConnection:
268
+ """
269
+ Abstract class for handling maintenance notifications logic.
270
+ This class is expected to be used as base class together with ConnectionInterface.
321
271
 
272
+ This class is intended to be used with multiple inheritance!
322
273
 
323
- class AbstractConnection(ConnectionInterface):
324
- "Manages communication to and from a Redis server"
274
+ All logic related to maintenance notifications is encapsulated in this class.
275
+ """
325
276
 
326
277
  def __init__(
327
278
  self,
328
- db: int = 0,
329
- password: Optional[str] = None,
330
- socket_timeout: Optional[float] = None,
331
- socket_connect_timeout: Optional[float] = None,
332
- retry_on_timeout: bool = False,
333
- retry_on_error: Union[Iterable[Type[Exception]], object] = SENTINEL,
334
- encoding: str = "utf-8",
335
- encoding_errors: str = "strict",
336
- decode_responses: bool = False,
337
- parser_class=DefaultParser,
338
- socket_read_size: int = 65536,
339
- health_check_interval: int = 0,
340
- client_name: Optional[str] = None,
341
- lib_name: Optional[str] = "redis-py",
342
- lib_version: Optional[str] = get_lib_version(),
343
- username: Optional[str] = None,
344
- retry: Union[Any, None] = None,
345
- redis_connect_func: Optional[Callable[[], None]] = None,
346
- credential_provider: Optional[CredentialProvider] = None,
347
- protocol: Optional[int] = 2,
348
- command_packer: Optional[Callable[[], None]] = None,
349
- event_dispatcher: Optional[EventDispatcher] = None,
279
+ maint_notifications_config: Optional[MaintNotificationsConfig],
350
280
  maint_notifications_pool_handler: Optional[
351
281
  MaintNotificationsPoolHandler
352
282
  ] = None,
353
- maint_notifications_config: Optional[MaintNotificationsConfig] = None,
354
283
  maintenance_state: "MaintenanceState" = MaintenanceState.NONE,
355
284
  maintenance_notification_hash: Optional[int] = None,
356
285
  orig_host_address: Optional[str] = None,
357
286
  orig_socket_timeout: Optional[float] = None,
358
287
  orig_socket_connect_timeout: Optional[float] = None,
288
+ parser: Optional[Union[_HiredisParser, _RESP3Parser]] = None,
359
289
  ):
360
290
  """
361
- Initialize a new Connection.
362
- To specify a retry policy for specific errors, first set
363
- `retry_on_error` to a list of the error/s to retry on, then set
364
- `retry` to a valid `Retry` object.
365
- To retry on TimeoutError, `retry_on_timeout` can also be set to `True`.
366
- """
367
- if (username or password) and credential_provider is not None:
368
- raise DataError(
369
- "'username' and 'password' cannot be passed along with 'credential_"
370
- "provider'. Please provide only one of the following arguments: \n"
371
- "1. 'password' and (optional) 'username'\n"
372
- "2. 'credential_provider'"
373
- )
374
- if event_dispatcher is None:
375
- self._event_dispatcher = EventDispatcher()
376
- else:
377
- self._event_dispatcher = event_dispatcher
378
- self.pid = os.getpid()
379
- self.db = db
380
- self.client_name = client_name
381
- self.lib_name = lib_name
382
- self.lib_version = lib_version
383
- self.credential_provider = credential_provider
384
- self.password = password
385
- self.username = username
386
- self.socket_timeout = socket_timeout
387
- if socket_connect_timeout is None:
388
- socket_connect_timeout = socket_timeout
389
- self.socket_connect_timeout = socket_connect_timeout
390
- self.retry_on_timeout = retry_on_timeout
391
- if retry_on_error is SENTINEL:
392
- retry_on_errors_list = []
393
- else:
394
- retry_on_errors_list = list(retry_on_error)
395
- if retry_on_timeout:
396
- # Add TimeoutError to the errors list to retry on
397
- retry_on_errors_list.append(TimeoutError)
398
- self.retry_on_error = retry_on_errors_list
399
- if retry or self.retry_on_error:
400
- if retry is None:
401
- self.retry = Retry(NoBackoff(), 1)
402
- else:
403
- # deep-copy the Retry object as it is mutable
404
- self.retry = copy.deepcopy(retry)
405
- if self.retry_on_error:
406
- # Update the retry's supported errors with the specified errors
407
- self.retry.update_supported_errors(self.retry_on_error)
408
- else:
409
- self.retry = Retry(NoBackoff(), 0)
410
- self.health_check_interval = health_check_interval
411
- self.next_health_check = 0
412
- self.redis_connect_func = redis_connect_func
413
- self.encoder = Encoder(encoding, encoding_errors, decode_responses)
414
- self.handshake_metadata = None
415
- self._sock = None
416
- self._socket_read_size = socket_read_size
417
- self._connect_callbacks = []
418
- self._buffer_cutoff = 6000
419
- self._re_auth_token: Optional[TokenInterface] = None
420
- try:
421
- p = int(protocol)
422
- except TypeError:
423
- p = DEFAULT_RESP_VERSION
424
- except ValueError:
425
- raise ConnectionError("protocol must be an integer")
426
- finally:
427
- if p < 2 or p > 3:
428
- raise ConnectionError("protocol must be either 2 or 3")
429
- # p = DEFAULT_RESP_VERSION
430
- self.protocol = p
431
- if self.protocol == 3 and parser_class == DefaultParser:
432
- parser_class = _RESP3Parser
433
- self.set_parser(parser_class)
291
+ Initialize the maintenance notifications for the connection.
434
292
 
293
+ Args:
294
+ maint_notifications_config (MaintNotificationsConfig): The configuration for maintenance notifications.
295
+ maint_notifications_pool_handler (Optional[MaintNotificationsPoolHandler]): The pool handler for maintenance notifications.
296
+ maintenance_state (MaintenanceState): The current maintenance state of the connection.
297
+ maintenance_notification_hash (Optional[int]): The current maintenance notification hash of the connection.
298
+ orig_host_address (Optional[str]): The original host address of the connection.
299
+ orig_socket_timeout (Optional[float]): The original socket timeout of the connection.
300
+ orig_socket_connect_timeout (Optional[float]): The original socket connect timeout of the connection.
301
+ parser (Optional[Union[_HiredisParser, _RESP3Parser]]): The parser to use for maintenance notifications.
302
+ If not provided, the parser from the connection is used.
303
+ This is useful when the parser is created after this object.
304
+ """
435
305
  self.maint_notifications_config = maint_notifications_config
436
-
437
- # Set up maintenance notifications if enabled
306
+ self.maintenance_state = maintenance_state
307
+ self.maintenance_notification_hash = maintenance_notification_hash
438
308
  self._configure_maintenance_notifications(
439
309
  maint_notifications_pool_handler,
440
310
  orig_host_address,
441
311
  orig_socket_timeout,
442
312
  orig_socket_connect_timeout,
313
+ parser,
443
314
  )
444
315
 
445
- self._should_reconnect = False
446
- self.maintenance_state = maintenance_state
447
- self.maintenance_notification_hash = maintenance_notification_hash
316
+ @abstractmethod
317
+ def _get_parser(self) -> Union[_HiredisParser, _RESP3Parser]:
318
+ pass
448
319
 
449
- self._command_packer = self._construct_command_packer(command_packer)
320
+ @abstractmethod
321
+ def _get_socket(self) -> Optional[socket.socket]:
322
+ pass
450
323
 
451
- def __repr__(self):
452
- repr_args = ",".join([f"{k}={v}" for k, v in self.repr_pieces()])
453
- return f"<{self.__class__.__module__}.{self.__class__.__name__}({repr_args})>"
324
+ @abstractmethod
325
+ def get_protocol(self) -> Union[int, str]:
326
+ """
327
+ Returns:
328
+ The RESP protocol version, or ``None`` if the protocol is not specified,
329
+ in which case the server default will be used.
330
+ """
331
+ pass
454
332
 
333
+ @property
455
334
  @abstractmethod
456
- def repr_pieces(self):
335
+ def host(self) -> str:
457
336
  pass
458
337
 
459
- def __del__(self):
460
- try:
461
- self.disconnect()
462
- except Exception:
463
- pass
338
+ @host.setter
339
+ @abstractmethod
340
+ def host(self, value: str):
341
+ pass
464
342
 
465
- def _construct_command_packer(self, packer):
466
- if packer is not None:
467
- return packer
468
- elif HIREDIS_AVAILABLE:
469
- return HiredisRespSerializer()
470
- else:
471
- return PythonRespSerializer(self._buffer_cutoff, self.encoder.encode)
343
+ @property
344
+ @abstractmethod
345
+ def socket_timeout(self) -> Optional[Union[float, int]]:
346
+ pass
472
347
 
473
- def register_connect_callback(self, callback):
474
- """
475
- Register a callback to be called when the connection is established either
476
- initially or reconnected. This allows listeners to issue commands that
477
- are ephemeral to the connection, for example pub/sub subscription or
478
- key tracking. The callback must be a _method_ and will be kept as
479
- a weak reference.
480
- """
481
- wm = weakref.WeakMethod(callback)
482
- if wm not in self._connect_callbacks:
483
- self._connect_callbacks.append(wm)
348
+ @socket_timeout.setter
349
+ @abstractmethod
350
+ def socket_timeout(self, value: Optional[Union[float, int]]):
351
+ pass
484
352
 
485
- def deregister_connect_callback(self, callback):
486
- """
487
- De-register a previously registered callback. It will no-longer receive
488
- notifications on connection events. Calling this is not required when the
489
- listener goes away, since the callbacks are kept as weak methods.
490
- """
491
- try:
492
- self._connect_callbacks.remove(weakref.WeakMethod(callback))
493
- except ValueError:
494
- pass
353
+ @property
354
+ @abstractmethod
355
+ def socket_connect_timeout(self) -> Optional[Union[float, int]]:
356
+ pass
495
357
 
496
- def set_parser(self, parser_class):
497
- """
498
- Creates a new instance of parser_class with socket size:
499
- _socket_read_size and assigns it to the parser for the connection
500
- :param parser_class: The required parser class
501
- """
502
- self._parser = parser_class(socket_read_size=self._socket_read_size)
358
+ @socket_connect_timeout.setter
359
+ @abstractmethod
360
+ def socket_connect_timeout(self, value: Optional[Union[float, int]]):
361
+ pass
362
+
363
+ @abstractmethod
364
+ def send_command(self, *args, **kwargs):
365
+ pass
366
+
367
+ @abstractmethod
368
+ def read_response(
369
+ self,
370
+ disable_decoding=False,
371
+ *,
372
+ disconnect_on_error=True,
373
+ push_request=False,
374
+ ):
375
+ pass
376
+
377
+ @abstractmethod
378
+ def disconnect(self, *args):
379
+ pass
503
380
 
504
381
  def _configure_maintenance_notifications(
505
382
  self,
506
- maint_notifications_pool_handler=None,
383
+ maint_notifications_pool_handler: Optional[
384
+ MaintNotificationsPoolHandler
385
+ ] = None,
507
386
  orig_host_address=None,
508
387
  orig_socket_timeout=None,
509
388
  orig_socket_connect_timeout=None,
389
+ parser: Optional[Union[_HiredisParser, _RESP3Parser]] = None,
510
390
  ):
511
- """Enable maintenance notifications by setting up handlers and storing original connection parameters."""
391
+ """
392
+ Enable maintenance notifications by setting up
393
+ handlers and storing original connection parameters.
394
+
395
+ Should be used ONLY with parsers that support push notifications.
396
+ """
512
397
  if (
513
398
  not self.maint_notifications_config
514
399
  or not self.maint_notifications_config.enabled
515
400
  ):
401
+ self._maint_notifications_pool_handler = None
516
402
  self._maint_notifications_connection_handler = None
517
403
  return
518
404
 
519
- # Set up pool handler if available
520
- if maint_notifications_pool_handler:
521
- self._parser.set_node_moving_push_handler(
522
- maint_notifications_pool_handler.handle_notification
405
+ if not parser:
406
+ raise RedisError(
407
+ "To configure maintenance notifications, a parser must be provided!"
523
408
  )
524
409
 
525
- # Set up connection handler
526
- self._maint_notifications_connection_handler = (
527
- MaintNotificationsConnectionHandler(self, self.maint_notifications_config)
528
- )
529
- self._parser.set_maintenance_push_handler(
530
- self._maint_notifications_connection_handler.handle_notification
410
+ if not isinstance(parser, _HiredisParser) and not isinstance(
411
+ parser, _RESP3Parser
412
+ ):
413
+ raise RedisError(
414
+ "Maintenance notifications are only supported with hiredis and RESP3 parsers!"
415
+ )
416
+
417
+ if maint_notifications_pool_handler:
418
+ # Extract a reference to a new pool handler that copies all properties
419
+ # of the original one and has a different connection reference
420
+ # This is needed because when we attach the handler to the parser
421
+ # we need to make sure that the handler has a reference to the
422
+ # connection that the parser is attached to.
423
+ self._maint_notifications_pool_handler = (
424
+ maint_notifications_pool_handler.get_handler_for_connection()
425
+ )
426
+ self._maint_notifications_pool_handler.set_connection(self)
427
+ else:
428
+ self._maint_notifications_pool_handler = None
429
+
430
+ self._maint_notifications_connection_handler = (
431
+ MaintNotificationsConnectionHandler(self, self.maint_notifications_config)
432
+ )
433
+
434
+ # Set up pool handler if available
435
+ if self._maint_notifications_pool_handler:
436
+ parser.set_node_moving_push_handler(
437
+ self._maint_notifications_pool_handler.handle_notification
438
+ )
439
+
440
+ # Set up connection handler
441
+ parser.set_maintenance_push_handler(
442
+ self._maint_notifications_connection_handler.handle_notification
531
443
  )
532
444
 
533
445
  # Store original connection parameters
@@ -541,14 +453,24 @@ class AbstractConnection(ConnectionInterface):
541
453
  else self.socket_connect_timeout
542
454
  )
543
455
 
544
- def set_maint_notifications_pool_handler(
456
+ def set_maint_notifications_pool_handler_for_connection(
545
457
  self, maint_notifications_pool_handler: MaintNotificationsPoolHandler
546
458
  ):
547
- maint_notifications_pool_handler.set_connection(self)
548
- self._parser.set_node_moving_push_handler(
549
- maint_notifications_pool_handler.handle_notification
459
+ # Deep copy the pool handler to avoid sharing the same pool handler
460
+ # between multiple connections, because otherwise each connection will override
461
+ # the connection reference and the pool handler will only hold a reference
462
+ # to the last connection that was set.
463
+ maint_notifications_pool_handler_copy = (
464
+ maint_notifications_pool_handler.get_handler_for_connection()
465
+ )
466
+
467
+ maint_notifications_pool_handler_copy.set_connection(self)
468
+ self._get_parser().set_node_moving_push_handler(
469
+ maint_notifications_pool_handler_copy.handle_notification
550
470
  )
551
471
 
472
+ self._maint_notifications_pool_handler = maint_notifications_pool_handler_copy
473
+
552
474
  # Update maintenance notification connection handler if it doesn't exist
553
475
  if not self._maint_notifications_connection_handler:
554
476
  self._maint_notifications_connection_handler = (
@@ -556,7 +478,7 @@ class AbstractConnection(ConnectionInterface):
556
478
  self, maint_notifications_pool_handler.config
557
479
  )
558
480
  )
559
- self._parser.set_maintenance_push_handler(
481
+ self._get_parser().set_maintenance_push_handler(
560
482
  self._maint_notifications_connection_handler.handle_notification
561
483
  )
562
484
  else:
@@ -564,130 +486,7 @@ class AbstractConnection(ConnectionInterface):
564
486
  maint_notifications_pool_handler.config
565
487
  )
566
488
 
567
- def connect(self):
568
- "Connects to the Redis server if not already connected"
569
- self.connect_check_health(check_health=True)
570
-
571
- def connect_check_health(
572
- self, check_health: bool = True, retry_socket_connect: bool = True
573
- ):
574
- if self._sock:
575
- return
576
- try:
577
- if retry_socket_connect:
578
- sock = self.retry.call_with_retry(
579
- lambda: self._connect(), lambda error: self.disconnect(error)
580
- )
581
- else:
582
- sock = self._connect()
583
- except socket.timeout:
584
- raise TimeoutError("Timeout connecting to server")
585
- except OSError as e:
586
- raise ConnectionError(self._error_message(e))
587
-
588
- self._sock = sock
589
- try:
590
- if self.redis_connect_func is None:
591
- # Use the default on_connect function
592
- self.on_connect_check_health(check_health=check_health)
593
- else:
594
- # Use the passed function redis_connect_func
595
- self.redis_connect_func(self)
596
- except RedisError:
597
- # clean up after any error in on_connect
598
- self.disconnect()
599
- raise
600
-
601
- # run any user callbacks. right now the only internal callback
602
- # is for pubsub channel/pattern resubscription
603
- # first, remove any dead weakrefs
604
- self._connect_callbacks = [ref for ref in self._connect_callbacks if ref()]
605
- for ref in self._connect_callbacks:
606
- callback = ref()
607
- if callback:
608
- callback(self)
609
-
610
- @abstractmethod
611
- def _connect(self):
612
- pass
613
-
614
- @abstractmethod
615
- def _host_error(self):
616
- pass
617
-
618
- def _error_message(self, exception):
619
- return format_error_message(self._host_error(), exception)
620
-
621
- def on_connect(self):
622
- self.on_connect_check_health(check_health=True)
623
-
624
- def on_connect_check_health(self, check_health: bool = True):
625
- "Initialize the connection, authenticate and select a database"
626
- self._parser.on_connect(self)
627
- parser = self._parser
628
-
629
- auth_args = None
630
- # if credential provider or username and/or password are set, authenticate
631
- if self.credential_provider or (self.username or self.password):
632
- cred_provider = (
633
- self.credential_provider
634
- or UsernamePasswordCredentialProvider(self.username, self.password)
635
- )
636
- auth_args = cred_provider.get_credentials()
637
-
638
- # if resp version is specified and we have auth args,
639
- # we need to send them via HELLO
640
- if auth_args and self.protocol not in [2, "2"]:
641
- if isinstance(self._parser, _RESP2Parser):
642
- self.set_parser(_RESP3Parser)
643
- # update cluster exception classes
644
- self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
645
- self._parser.on_connect(self)
646
- if len(auth_args) == 1:
647
- auth_args = ["default", auth_args[0]]
648
- # avoid checking health here -- PING will fail if we try
649
- # to check the health prior to the AUTH
650
- self.send_command(
651
- "HELLO", self.protocol, "AUTH", *auth_args, check_health=False
652
- )
653
- self.handshake_metadata = self.read_response()
654
- # if response.get(b"proto") != self.protocol and response.get(
655
- # "proto"
656
- # ) != self.protocol:
657
- # raise ConnectionError("Invalid RESP version")
658
- elif auth_args:
659
- # avoid checking health here -- PING will fail if we try
660
- # to check the health prior to the AUTH
661
- self.send_command("AUTH", *auth_args, check_health=False)
662
-
663
- try:
664
- auth_response = self.read_response()
665
- except AuthenticationWrongNumberOfArgsError:
666
- # a username and password were specified but the Redis
667
- # server seems to be < 6.0.0 which expects a single password
668
- # arg. retry auth with just the password.
669
- # https://github.com/andymccurdy/redis-py/issues/1274
670
- self.send_command("AUTH", auth_args[-1], check_health=False)
671
- auth_response = self.read_response()
672
-
673
- if str_if_bytes(auth_response) != "OK":
674
- raise AuthenticationError("Invalid Username or Password")
675
-
676
- # if resp version is specified, switch to it
677
- elif self.protocol not in [2, "2"]:
678
- if isinstance(self._parser, _RESP2Parser):
679
- self.set_parser(_RESP3Parser)
680
- # update cluster exception classes
681
- self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
682
- self._parser.on_connect(self)
683
- self.send_command("HELLO", self.protocol, check_health=check_health)
684
- self.handshake_metadata = self.read_response()
685
- if (
686
- self.handshake_metadata.get(b"proto") != self.protocol
687
- and self.handshake_metadata.get("proto") != self.protocol
688
- ):
689
- raise ConnectionError("Invalid RESP version")
690
-
489
+ def activate_maint_notifications_handling_if_enabled(self, check_health=True):
691
490
  # Send maintenance notifications handshake if RESP3 is active
692
491
  # and maintenance notifications are enabled
693
492
  # and we have a host to determine the endpoint type from
@@ -695,16 +494,29 @@ class AbstractConnection(ConnectionInterface):
695
494
  # we just log a warning if the handshake fails
696
495
  # When the mode is enabled=True, we raise an exception in case of failure
697
496
  if (
698
- self.protocol not in [2, "2"]
497
+ self.get_protocol() not in [2, "2"]
699
498
  and self.maint_notifications_config
700
499
  and self.maint_notifications_config.enabled
701
500
  and self._maint_notifications_connection_handler
702
501
  and hasattr(self, "host")
703
502
  ):
704
- try:
705
- endpoint_type = self.maint_notifications_config.get_endpoint_type(
706
- self.host, self
503
+ self._enable_maintenance_notifications(
504
+ maint_notifications_config=self.maint_notifications_config,
505
+ check_health=check_health,
506
+ )
507
+
508
+ def _enable_maintenance_notifications(
509
+ self, maint_notifications_config: MaintNotificationsConfig, check_health=True
510
+ ):
511
+ try:
512
+ host = getattr(self, "host", None)
513
+ if host is None:
514
+ raise ValueError(
515
+ "Cannot enable maintenance notifications for connection"
516
+ " object that doesn't have a host attribute."
707
517
  )
518
+ else:
519
+ endpoint_type = maint_notifications_config.get_endpoint_type(host, self)
708
520
  self.send_command(
709
521
  "CLIENT",
710
522
  "MAINT_NOTIFICATIONS",
@@ -714,337 +526,105 @@ class AbstractConnection(ConnectionInterface):
714
526
  check_health=check_health,
715
527
  )
716
528
  response = self.read_response()
717
- if str_if_bytes(response) != "OK":
529
+ if not response or str_if_bytes(response) != "OK":
718
530
  raise ResponseError(
719
531
  "The server doesn't support maintenance notifications"
720
532
  )
721
- except Exception as e:
722
- if (
723
- isinstance(e, ResponseError)
724
- and self.maint_notifications_config.enabled == "auto"
725
- ):
726
- # Log warning but don't fail the connection
727
- import logging
533
+ except Exception as e:
534
+ if (
535
+ isinstance(e, ResponseError)
536
+ and maint_notifications_config.enabled == "auto"
537
+ ):
538
+ # Log warning but don't fail the connection
539
+ import logging
728
540
 
729
- logger = logging.getLogger(__name__)
730
- logger.warning(f"Failed to enable maintenance notifications: {e}")
731
- else:
732
- raise
541
+ logger = logging.getLogger(__name__)
542
+ logger.warning(f"Failed to enable maintenance notifications: {e}")
543
+ else:
544
+ raise
733
545
 
734
- # if a client_name is given, set it
735
- if self.client_name:
736
- self.send_command(
737
- "CLIENT",
738
- "SETNAME",
739
- self.client_name,
740
- check_health=check_health,
741
- )
742
- if str_if_bytes(self.read_response()) != "OK":
743
- raise ConnectionError("Error setting client name")
546
+ def get_resolved_ip(self) -> Optional[str]:
547
+ """
548
+ Extract the resolved IP address from an
549
+ established connection or resolve it from the host.
550
+
551
+ First tries to get the actual IP from the socket (most accurate),
552
+ then falls back to DNS resolution if needed.
553
+
554
+ Args:
555
+ connection: The connection object to extract the IP from
556
+
557
+ Returns:
558
+ str: The resolved IP address, or None if it cannot be determined
559
+ """
744
560
 
561
+ # Method 1: Try to get the actual IP from the established socket connection
562
+ # This is most accurate as it shows the exact IP being used
745
563
  try:
746
- # set the library name and version
747
- if self.lib_name:
748
- self.send_command(
749
- "CLIENT",
750
- "SETINFO",
751
- "LIB-NAME",
752
- self.lib_name,
753
- check_health=check_health,
754
- )
755
- self.read_response()
756
- except ResponseError:
564
+ conn_socket = self._get_socket()
565
+ if conn_socket is not None:
566
+ peer_addr = conn_socket.getpeername()
567
+ if peer_addr and len(peer_addr) >= 1:
568
+ # For TCP sockets, peer_addr is typically (host, port) tuple
569
+ # Return just the host part
570
+ return peer_addr[0]
571
+ except (AttributeError, OSError):
572
+ # Socket might not be connected or getpeername() might fail
757
573
  pass
758
574
 
575
+ # Method 2: Fallback to DNS resolution of the host
576
+ # This is less accurate but works when socket is not available
759
577
  try:
760
- if self.lib_version:
761
- self.send_command(
762
- "CLIENT",
763
- "SETINFO",
764
- "LIB-VER",
765
- self.lib_version,
766
- check_health=check_health,
578
+ host = getattr(self, "host", "localhost")
579
+ port = getattr(self, "port", 6379)
580
+ if host:
581
+ # Use getaddrinfo to resolve the hostname to IP
582
+ # This mimics what the connection would do during _connect()
583
+ addr_info = socket.getaddrinfo(
584
+ host, port, socket.AF_UNSPEC, socket.SOCK_STREAM
767
585
  )
768
- self.read_response()
769
- except ResponseError:
586
+ if addr_info:
587
+ # Return the IP from the first result
588
+ # addr_info[0] is (family, socktype, proto, canonname, sockaddr)
589
+ # sockaddr[0] is the IP address
590
+ return str(addr_info[0][4][0])
591
+ except (AttributeError, OSError, socket.gaierror):
592
+ # DNS resolution might fail
770
593
  pass
771
594
 
772
- # if a database is specified, switch to it
773
- if self.db:
774
- self.send_command("SELECT", self.db, check_health=check_health)
775
- if str_if_bytes(self.read_response()) != "OK":
776
- raise ConnectionError("Invalid Database")
595
+ return None
777
596
 
778
- def disconnect(self, *args):
779
- "Disconnects from the Redis server"
780
- self._parser.on_disconnect()
597
+ @property
598
+ def maintenance_state(self) -> MaintenanceState:
599
+ return self._maintenance_state
781
600
 
782
- conn_sock = self._sock
783
- self._sock = None
784
- # reset the reconnect flag
785
- self._should_reconnect = False
786
- if conn_sock is None:
787
- return
601
+ @maintenance_state.setter
602
+ def maintenance_state(self, state: "MaintenanceState"):
603
+ self._maintenance_state = state
788
604
 
789
- if os.getpid() == self.pid:
790
- try:
791
- conn_sock.shutdown(socket.SHUT_RDWR)
792
- except (OSError, TypeError):
793
- pass
794
-
795
- try:
796
- conn_sock.close()
797
- except OSError:
798
- pass
799
-
800
- def _send_ping(self):
801
- """Send PING, expect PONG in return"""
802
- self.send_command("PING", check_health=False)
803
- if str_if_bytes(self.read_response()) != "PONG":
804
- raise ConnectionError("Bad response from PING health check")
805
-
806
- def _ping_failed(self, error):
807
- """Function to call when PING fails"""
808
- self.disconnect()
809
-
810
- def check_health(self):
811
- """Check the health of the connection with a PING/PONG"""
812
- if self.health_check_interval and time.monotonic() > self.next_health_check:
813
- self.retry.call_with_retry(self._send_ping, self._ping_failed)
814
-
815
- def send_packed_command(self, command, check_health=True):
816
- """Send an already packed command to the Redis server"""
817
- if not self._sock:
818
- self.connect_check_health(check_health=False)
819
- # guard against health check recursion
820
- if check_health:
821
- self.check_health()
822
- try:
823
- if isinstance(command, str):
824
- command = [command]
825
- for item in command:
826
- self._sock.sendall(item)
827
- except socket.timeout:
828
- self.disconnect()
829
- raise TimeoutError("Timeout writing to socket")
830
- except OSError as e:
831
- self.disconnect()
832
- if len(e.args) == 1:
833
- errno, errmsg = "UNKNOWN", e.args[0]
834
- else:
835
- errno = e.args[0]
836
- errmsg = e.args[1]
837
- raise ConnectionError(f"Error {errno} while writing to socket. {errmsg}.")
838
- except BaseException:
839
- # BaseExceptions can be raised when a socket send operation is not
840
- # finished, e.g. due to a timeout. Ideally, a caller could then re-try
841
- # to send un-sent data. However, the send_packed_command() API
842
- # does not support it so there is no point in keeping the connection open.
843
- self.disconnect()
844
- raise
845
-
846
- def send_command(self, *args, **kwargs):
847
- """Pack and send a command to the Redis server"""
848
- self.send_packed_command(
849
- self._command_packer.pack(*args),
850
- check_health=kwargs.get("check_health", True),
851
- )
852
-
853
- def can_read(self, timeout=0):
854
- """Poll the socket to see if there's data that can be read."""
855
- sock = self._sock
856
- if not sock:
857
- self.connect()
858
-
859
- host_error = self._host_error()
860
-
861
- try:
862
- return self._parser.can_read(timeout)
863
-
864
- except OSError as e:
865
- self.disconnect()
866
- raise ConnectionError(f"Error while reading from {host_error}: {e.args}")
867
-
868
- def read_response(
869
- self,
870
- disable_decoding=False,
871
- *,
872
- disconnect_on_error=True,
873
- push_request=False,
874
- ):
875
- """Read the response from a previously sent command"""
876
-
877
- host_error = self._host_error()
878
-
879
- try:
880
- if self.protocol in ["3", 3]:
881
- response = self._parser.read_response(
882
- disable_decoding=disable_decoding, push_request=push_request
883
- )
884
- else:
885
- response = self._parser.read_response(disable_decoding=disable_decoding)
886
- except socket.timeout:
887
- if disconnect_on_error:
888
- self.disconnect()
889
- raise TimeoutError(f"Timeout reading from {host_error}")
890
- except OSError as e:
891
- if disconnect_on_error:
892
- self.disconnect()
893
- raise ConnectionError(f"Error while reading from {host_error} : {e.args}")
894
- except BaseException:
895
- # Also by default close in case of BaseException. A lot of code
896
- # relies on this behaviour when doing Command/Response pairs.
897
- # See #1128.
898
- if disconnect_on_error:
899
- self.disconnect()
900
- raise
901
-
902
- if self.health_check_interval:
903
- self.next_health_check = time.monotonic() + self.health_check_interval
904
-
905
- if isinstance(response, ResponseError):
906
- try:
907
- raise response
908
- finally:
909
- del response # avoid creating ref cycles
910
- return response
911
-
912
- def pack_command(self, *args):
913
- """Pack a series of arguments into the Redis protocol"""
914
- return self._command_packer.pack(*args)
915
-
916
- def pack_commands(self, commands):
917
- """Pack multiple commands into the Redis protocol"""
918
- output = []
919
- pieces = []
920
- buffer_length = 0
921
- buffer_cutoff = self._buffer_cutoff
922
-
923
- for cmd in commands:
924
- for chunk in self._command_packer.pack(*cmd):
925
- chunklen = len(chunk)
926
- if (
927
- buffer_length > buffer_cutoff
928
- or chunklen > buffer_cutoff
929
- or isinstance(chunk, memoryview)
930
- ):
931
- if pieces:
932
- output.append(SYM_EMPTY.join(pieces))
933
- buffer_length = 0
934
- pieces = []
935
-
936
- if chunklen > buffer_cutoff or isinstance(chunk, memoryview):
937
- output.append(chunk)
938
- else:
939
- pieces.append(chunk)
940
- buffer_length += chunklen
941
-
942
- if pieces:
943
- output.append(SYM_EMPTY.join(pieces))
944
- return output
945
-
946
- def get_protocol(self) -> Union[int, str]:
947
- return self.protocol
948
-
949
- @property
950
- def handshake_metadata(self) -> Union[Dict[bytes, bytes], Dict[str, str]]:
951
- return self._handshake_metadata
952
-
953
- @handshake_metadata.setter
954
- def handshake_metadata(self, value: Union[Dict[bytes, bytes], Dict[str, str]]):
955
- self._handshake_metadata = value
956
-
957
- def set_re_auth_token(self, token: TokenInterface):
958
- self._re_auth_token = token
959
-
960
- def re_auth(self):
961
- if self._re_auth_token is not None:
962
- self.send_command(
963
- "AUTH",
964
- self._re_auth_token.try_get("oid"),
965
- self._re_auth_token.get_value(),
966
- )
967
- self.read_response()
968
- self._re_auth_token = None
969
-
970
- def get_resolved_ip(self) -> Optional[str]:
971
- """
972
- Extract the resolved IP address from an
973
- established connection or resolve it from the host.
974
-
975
- First tries to get the actual IP from the socket (most accurate),
976
- then falls back to DNS resolution if needed.
977
-
978
- Args:
979
- connection: The connection object to extract the IP from
980
-
981
- Returns:
982
- str: The resolved IP address, or None if it cannot be determined
983
- """
984
-
985
- # Method 1: Try to get the actual IP from the established socket connection
986
- # This is most accurate as it shows the exact IP being used
987
- try:
988
- if self._sock is not None:
989
- peer_addr = self._sock.getpeername()
990
- if peer_addr and len(peer_addr) >= 1:
991
- # For TCP sockets, peer_addr is typically (host, port) tuple
992
- # Return just the host part
993
- return peer_addr[0]
994
- except (AttributeError, OSError):
995
- # Socket might not be connected or getpeername() might fail
996
- pass
997
-
998
- # Method 2: Fallback to DNS resolution of the host
999
- # This is less accurate but works when socket is not available
1000
- try:
1001
- host = getattr(self, "host", "localhost")
1002
- port = getattr(self, "port", 6379)
1003
- if host:
1004
- # Use getaddrinfo to resolve the hostname to IP
1005
- # This mimics what the connection would do during _connect()
1006
- addr_info = socket.getaddrinfo(
1007
- host, port, socket.AF_UNSPEC, socket.SOCK_STREAM
1008
- )
1009
- if addr_info:
1010
- # Return the IP from the first result
1011
- # addr_info[0] is (family, socktype, proto, canonname, sockaddr)
1012
- # sockaddr[0] is the IP address
1013
- return addr_info[0][4][0]
1014
- except (AttributeError, OSError, socket.gaierror):
1015
- # DNS resolution might fail
1016
- pass
1017
-
1018
- return None
1019
-
1020
- @property
1021
- def maintenance_state(self) -> MaintenanceState:
1022
- return self._maintenance_state
1023
-
1024
- @maintenance_state.setter
1025
- def maintenance_state(self, state: "MaintenanceState"):
1026
- self._maintenance_state = state
1027
-
1028
- def getpeername(self):
1029
- if not self._sock:
1030
- return None
1031
- return self._sock.getpeername()[0]
1032
-
1033
- def mark_for_reconnect(self):
1034
- self._should_reconnect = True
1035
-
1036
- def should_reconnect(self):
1037
- return self._should_reconnect
605
+ def getpeername(self):
606
+ """
607
+ Returns the peer name of the connection.
608
+ """
609
+ conn_socket = self._get_socket()
610
+ if conn_socket:
611
+ return conn_socket.getpeername()[0]
612
+ return None
1038
613
 
1039
614
  def update_current_socket_timeout(self, relaxed_timeout: Optional[float] = None):
1040
- if self._sock:
615
+ conn_socket = self._get_socket()
616
+ if conn_socket:
1041
617
  timeout = relaxed_timeout if relaxed_timeout != -1 else self.socket_timeout
1042
- self._sock.settimeout(timeout)
1043
- self.update_parser_buffer_timeout(timeout)
618
+ conn_socket.settimeout(timeout)
619
+ self.update_parser_timeout(timeout)
1044
620
 
1045
- def update_parser_buffer_timeout(self, timeout: Optional[float] = None):
1046
- if self._parser and self._parser._buffer:
1047
- self._parser._buffer.socket_timeout = timeout
621
+ def update_parser_timeout(self, timeout: Optional[float] = None):
622
+ parser = self._get_parser()
623
+ if parser and parser._buffer:
624
+ if isinstance(parser, _RESP3Parser) and timeout:
625
+ parser._buffer.socket_timeout = timeout
626
+ elif isinstance(parser, _HiredisParser):
627
+ parser._socket_timeout = timeout
1048
628
 
1049
629
  def set_tmp_settings(
1050
630
  self,
@@ -1054,8 +634,8 @@ class AbstractConnection(ConnectionInterface):
1054
634
  """
1055
635
  The value of SENTINEL is used to indicate that the property should not be updated.
1056
636
  """
1057
- if tmp_host_address is not SENTINEL:
1058
- self.host = tmp_host_address
637
+ if tmp_host_address and tmp_host_address != SENTINEL:
638
+ self.host = str(tmp_host_address)
1059
639
  if tmp_relaxed_timeout != -1:
1060
640
  self.socket_timeout = tmp_relaxed_timeout
1061
641
  self.socket_connect_timeout = tmp_relaxed_timeout
@@ -1072,26 +652,609 @@ class AbstractConnection(ConnectionInterface):
1072
652
  self.socket_connect_timeout = self.orig_socket_connect_timeout
1073
653
 
1074
654
 
1075
- class Connection(AbstractConnection):
1076
- "Manages TCP communication to and from a Redis server"
655
+ class AbstractConnection(MaintNotificationsAbstractConnection, ConnectionInterface):
656
+ "Manages communication to and from a Redis server"
1077
657
 
1078
658
  def __init__(
1079
659
  self,
1080
- host="localhost",
1081
- port=6379,
1082
- socket_keepalive=False,
1083
- socket_keepalive_options=None,
1084
- socket_type=0,
1085
- **kwargs,
1086
- ):
1087
- self.host = host
1088
- self.port = int(port)
1089
- self.socket_keepalive = socket_keepalive
1090
- self.socket_keepalive_options = socket_keepalive_options or {}
1091
- self.socket_type = socket_type
1092
- super().__init__(**kwargs)
1093
-
1094
- def repr_pieces(self):
660
+ db: int = 0,
661
+ password: Optional[str] = None,
662
+ socket_timeout: Optional[float] = None,
663
+ socket_connect_timeout: Optional[float] = None,
664
+ retry_on_timeout: bool = False,
665
+ retry_on_error: Union[Iterable[Type[Exception]], object] = SENTINEL,
666
+ encoding: str = "utf-8",
667
+ encoding_errors: str = "strict",
668
+ decode_responses: bool = False,
669
+ parser_class=DefaultParser,
670
+ socket_read_size: int = 65536,
671
+ health_check_interval: int = 0,
672
+ client_name: Optional[str] = None,
673
+ lib_name: Optional[str] = "redis-py",
674
+ lib_version: Optional[str] = get_lib_version(),
675
+ username: Optional[str] = None,
676
+ retry: Union[Any, None] = None,
677
+ redis_connect_func: Optional[Callable[[], None]] = None,
678
+ credential_provider: Optional[CredentialProvider] = None,
679
+ protocol: Optional[int] = 2,
680
+ command_packer: Optional[Callable[[], None]] = None,
681
+ event_dispatcher: Optional[EventDispatcher] = None,
682
+ maint_notifications_config: Optional[MaintNotificationsConfig] = None,
683
+ maint_notifications_pool_handler: Optional[
684
+ MaintNotificationsPoolHandler
685
+ ] = None,
686
+ maintenance_state: "MaintenanceState" = MaintenanceState.NONE,
687
+ maintenance_notification_hash: Optional[int] = None,
688
+ orig_host_address: Optional[str] = None,
689
+ orig_socket_timeout: Optional[float] = None,
690
+ orig_socket_connect_timeout: Optional[float] = None,
691
+ ):
692
+ """
693
+ Initialize a new Connection.
694
+ To specify a retry policy for specific errors, first set
695
+ `retry_on_error` to a list of the error/s to retry on, then set
696
+ `retry` to a valid `Retry` object.
697
+ To retry on TimeoutError, `retry_on_timeout` can also be set to `True`.
698
+ """
699
+ if (username or password) and credential_provider is not None:
700
+ raise DataError(
701
+ "'username' and 'password' cannot be passed along with 'credential_"
702
+ "provider'. Please provide only one of the following arguments: \n"
703
+ "1. 'password' and (optional) 'username'\n"
704
+ "2. 'credential_provider'"
705
+ )
706
+ if event_dispatcher is None:
707
+ self._event_dispatcher = EventDispatcher()
708
+ else:
709
+ self._event_dispatcher = event_dispatcher
710
+ self.pid = os.getpid()
711
+ self.db = db
712
+ self.client_name = client_name
713
+ self.lib_name = lib_name
714
+ self.lib_version = lib_version
715
+ self.credential_provider = credential_provider
716
+ self.password = password
717
+ self.username = username
718
+ self._socket_timeout = socket_timeout
719
+ if socket_connect_timeout is None:
720
+ socket_connect_timeout = socket_timeout
721
+ self._socket_connect_timeout = socket_connect_timeout
722
+ self.retry_on_timeout = retry_on_timeout
723
+ if retry_on_error is SENTINEL:
724
+ retry_on_errors_list = []
725
+ else:
726
+ retry_on_errors_list = list(retry_on_error)
727
+ if retry_on_timeout:
728
+ # Add TimeoutError to the errors list to retry on
729
+ retry_on_errors_list.append(TimeoutError)
730
+ self.retry_on_error = retry_on_errors_list
731
+ if retry or self.retry_on_error:
732
+ if retry is None:
733
+ self.retry = Retry(NoBackoff(), 1)
734
+ else:
735
+ # deep-copy the Retry object as it is mutable
736
+ self.retry = copy.deepcopy(retry)
737
+ if self.retry_on_error:
738
+ # Update the retry's supported errors with the specified errors
739
+ self.retry.update_supported_errors(self.retry_on_error)
740
+ else:
741
+ self.retry = Retry(NoBackoff(), 0)
742
+ self.health_check_interval = health_check_interval
743
+ self.next_health_check = 0
744
+ self.redis_connect_func = redis_connect_func
745
+ self.encoder = Encoder(encoding, encoding_errors, decode_responses)
746
+ self.handshake_metadata = None
747
+ self._sock = None
748
+ self._socket_read_size = socket_read_size
749
+ self._connect_callbacks = []
750
+ self._buffer_cutoff = 6000
751
+ self._re_auth_token: Optional[TokenInterface] = None
752
+ try:
753
+ p = int(protocol)
754
+ except TypeError:
755
+ p = DEFAULT_RESP_VERSION
756
+ except ValueError:
757
+ raise ConnectionError("protocol must be an integer")
758
+ finally:
759
+ if p < 2 or p > 3:
760
+ raise ConnectionError("protocol must be either 2 or 3")
761
+ # p = DEFAULT_RESP_VERSION
762
+ self.protocol = p
763
+ if self.protocol == 3 and parser_class == _RESP2Parser:
764
+ # If the protocol is 3 but the parser is RESP2, change it to RESP3
765
+ # This is needed because the parser might be set before the protocol
766
+ # or might be provided as a kwarg to the constructor
767
+ # We need to react on discrepancy only for RESP2 and RESP3
768
+ # as hiredis supports both
769
+ parser_class = _RESP3Parser
770
+ self.set_parser(parser_class)
771
+
772
+ self._command_packer = self._construct_command_packer(command_packer)
773
+ self._should_reconnect = False
774
+
775
+ # Set up maintenance notifications
776
+ MaintNotificationsAbstractConnection.__init__(
777
+ self,
778
+ maint_notifications_config,
779
+ maint_notifications_pool_handler,
780
+ maintenance_state,
781
+ maintenance_notification_hash,
782
+ orig_host_address,
783
+ orig_socket_timeout,
784
+ orig_socket_connect_timeout,
785
+ self._parser,
786
+ )
787
+
788
+ def __repr__(self):
789
+ repr_args = ",".join([f"{k}={v}" for k, v in self.repr_pieces()])
790
+ return f"<{self.__class__.__module__}.{self.__class__.__name__}({repr_args})>"
791
+
792
+ @abstractmethod
793
+ def repr_pieces(self):
794
+ pass
795
+
796
+ def __del__(self):
797
+ try:
798
+ self.disconnect()
799
+ except Exception:
800
+ pass
801
+
802
+ def _construct_command_packer(self, packer):
803
+ if packer is not None:
804
+ return packer
805
+ elif HIREDIS_AVAILABLE:
806
+ return HiredisRespSerializer()
807
+ else:
808
+ return PythonRespSerializer(self._buffer_cutoff, self.encoder.encode)
809
+
810
+ def register_connect_callback(self, callback):
811
+ """
812
+ Register a callback to be called when the connection is established either
813
+ initially or reconnected. This allows listeners to issue commands that
814
+ are ephemeral to the connection, for example pub/sub subscription or
815
+ key tracking. The callback must be a _method_ and will be kept as
816
+ a weak reference.
817
+ """
818
+ wm = weakref.WeakMethod(callback)
819
+ if wm not in self._connect_callbacks:
820
+ self._connect_callbacks.append(wm)
821
+
822
+ def deregister_connect_callback(self, callback):
823
+ """
824
+ De-register a previously registered callback. It will no-longer receive
825
+ notifications on connection events. Calling this is not required when the
826
+ listener goes away, since the callbacks are kept as weak methods.
827
+ """
828
+ try:
829
+ self._connect_callbacks.remove(weakref.WeakMethod(callback))
830
+ except ValueError:
831
+ pass
832
+
833
+ def set_parser(self, parser_class):
834
+ """
835
+ Creates a new instance of parser_class with socket size:
836
+ _socket_read_size and assigns it to the parser for the connection
837
+ :param parser_class: The required parser class
838
+ """
839
+ self._parser = parser_class(socket_read_size=self._socket_read_size)
840
+
841
+ def _get_parser(self) -> Union[_HiredisParser, _RESP3Parser, _RESP2Parser]:
842
+ return self._parser
843
+
844
+ def connect(self):
845
+ "Connects to the Redis server if not already connected"
846
+ self.connect_check_health(check_health=True)
847
+
848
+ def connect_check_health(
849
+ self, check_health: bool = True, retry_socket_connect: bool = True
850
+ ):
851
+ if self._sock:
852
+ return
853
+ try:
854
+ if retry_socket_connect:
855
+ sock = self.retry.call_with_retry(
856
+ lambda: self._connect(), lambda error: self.disconnect(error)
857
+ )
858
+ else:
859
+ sock = self._connect()
860
+ except socket.timeout:
861
+ raise TimeoutError("Timeout connecting to server")
862
+ except OSError as e:
863
+ raise ConnectionError(self._error_message(e))
864
+
865
+ self._sock = sock
866
+ try:
867
+ if self.redis_connect_func is None:
868
+ # Use the default on_connect function
869
+ self.on_connect_check_health(check_health=check_health)
870
+ else:
871
+ # Use the passed function redis_connect_func
872
+ self.redis_connect_func(self)
873
+ except RedisError:
874
+ # clean up after any error in on_connect
875
+ self.disconnect()
876
+ raise
877
+
878
+ # run any user callbacks. right now the only internal callback
879
+ # is for pubsub channel/pattern resubscription
880
+ # first, remove any dead weakrefs
881
+ self._connect_callbacks = [ref for ref in self._connect_callbacks if ref()]
882
+ for ref in self._connect_callbacks:
883
+ callback = ref()
884
+ if callback:
885
+ callback(self)
886
+
887
+ @abstractmethod
888
+ def _connect(self):
889
+ pass
890
+
891
+ @abstractmethod
892
+ def _host_error(self):
893
+ pass
894
+
895
+ def _error_message(self, exception):
896
+ return format_error_message(self._host_error(), exception)
897
+
898
+ def on_connect(self):
899
+ self.on_connect_check_health(check_health=True)
900
+
901
+ def on_connect_check_health(self, check_health: bool = True):
902
+ "Initialize the connection, authenticate and select a database"
903
+ self._parser.on_connect(self)
904
+ parser = self._parser
905
+
906
+ auth_args = None
907
+ # if credential provider or username and/or password are set, authenticate
908
+ if self.credential_provider or (self.username or self.password):
909
+ cred_provider = (
910
+ self.credential_provider
911
+ or UsernamePasswordCredentialProvider(self.username, self.password)
912
+ )
913
+ auth_args = cred_provider.get_credentials()
914
+
915
+ # if resp version is specified and we have auth args,
916
+ # we need to send them via HELLO
917
+ if auth_args and self.protocol not in [2, "2"]:
918
+ if isinstance(self._parser, _RESP2Parser):
919
+ self.set_parser(_RESP3Parser)
920
+ # update cluster exception classes
921
+ self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
922
+ self._parser.on_connect(self)
923
+ if len(auth_args) == 1:
924
+ auth_args = ["default", auth_args[0]]
925
+ # avoid checking health here -- PING will fail if we try
926
+ # to check the health prior to the AUTH
927
+ self.send_command(
928
+ "HELLO", self.protocol, "AUTH", *auth_args, check_health=False
929
+ )
930
+ self.handshake_metadata = self.read_response()
931
+ # if response.get(b"proto") != self.protocol and response.get(
932
+ # "proto"
933
+ # ) != self.protocol:
934
+ # raise ConnectionError("Invalid RESP version")
935
+ elif auth_args:
936
+ # avoid checking health here -- PING will fail if we try
937
+ # to check the health prior to the AUTH
938
+ self.send_command("AUTH", *auth_args, check_health=False)
939
+
940
+ try:
941
+ auth_response = self.read_response()
942
+ except AuthenticationWrongNumberOfArgsError:
943
+ # a username and password were specified but the Redis
944
+ # server seems to be < 6.0.0 which expects a single password
945
+ # arg. retry auth with just the password.
946
+ # https://github.com/andymccurdy/redis-py/issues/1274
947
+ self.send_command("AUTH", auth_args[-1], check_health=False)
948
+ auth_response = self.read_response()
949
+
950
+ if str_if_bytes(auth_response) != "OK":
951
+ raise AuthenticationError("Invalid Username or Password")
952
+
953
+ # if resp version is specified, switch to it
954
+ elif self.protocol not in [2, "2"]:
955
+ if isinstance(self._parser, _RESP2Parser):
956
+ self.set_parser(_RESP3Parser)
957
+ # update cluster exception classes
958
+ self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
959
+ self._parser.on_connect(self)
960
+ self.send_command("HELLO", self.protocol, check_health=check_health)
961
+ self.handshake_metadata = self.read_response()
962
+ if (
963
+ self.handshake_metadata.get(b"proto") != self.protocol
964
+ and self.handshake_metadata.get("proto") != self.protocol
965
+ ):
966
+ raise ConnectionError("Invalid RESP version")
967
+
968
+ # Activate maintenance notifications for this connection
969
+ # if enabled in the configuration
970
+ # This is a no-op if maintenance notifications are not enabled
971
+ self.activate_maint_notifications_handling_if_enabled(check_health=check_health)
972
+
973
+ # if a client_name is given, set it
974
+ if self.client_name:
975
+ self.send_command(
976
+ "CLIENT",
977
+ "SETNAME",
978
+ self.client_name,
979
+ check_health=check_health,
980
+ )
981
+ if str_if_bytes(self.read_response()) != "OK":
982
+ raise ConnectionError("Error setting client name")
983
+
984
+ try:
985
+ # set the library name and version
986
+ if self.lib_name:
987
+ self.send_command(
988
+ "CLIENT",
989
+ "SETINFO",
990
+ "LIB-NAME",
991
+ self.lib_name,
992
+ check_health=check_health,
993
+ )
994
+ self.read_response()
995
+ except ResponseError:
996
+ pass
997
+
998
+ try:
999
+ if self.lib_version:
1000
+ self.send_command(
1001
+ "CLIENT",
1002
+ "SETINFO",
1003
+ "LIB-VER",
1004
+ self.lib_version,
1005
+ check_health=check_health,
1006
+ )
1007
+ self.read_response()
1008
+ except ResponseError:
1009
+ pass
1010
+
1011
+ # if a database is specified, switch to it
1012
+ if self.db:
1013
+ self.send_command("SELECT", self.db, check_health=check_health)
1014
+ if str_if_bytes(self.read_response()) != "OK":
1015
+ raise ConnectionError("Invalid Database")
1016
+
1017
+ def disconnect(self, *args):
1018
+ "Disconnects from the Redis server"
1019
+ self._parser.on_disconnect()
1020
+
1021
+ conn_sock = self._sock
1022
+ self._sock = None
1023
+ # reset the reconnect flag
1024
+ self.reset_should_reconnect()
1025
+ if conn_sock is None:
1026
+ return
1027
+
1028
+ if os.getpid() == self.pid:
1029
+ try:
1030
+ conn_sock.shutdown(socket.SHUT_RDWR)
1031
+ except (OSError, TypeError):
1032
+ pass
1033
+
1034
+ try:
1035
+ conn_sock.close()
1036
+ except OSError:
1037
+ pass
1038
+
1039
+ def mark_for_reconnect(self):
1040
+ self._should_reconnect = True
1041
+
1042
+ def should_reconnect(self):
1043
+ return self._should_reconnect
1044
+
1045
+ def reset_should_reconnect(self):
1046
+ self._should_reconnect = False
1047
+
1048
+ def _send_ping(self):
1049
+ """Send PING, expect PONG in return"""
1050
+ self.send_command("PING", check_health=False)
1051
+ if str_if_bytes(self.read_response()) != "PONG":
1052
+ raise ConnectionError("Bad response from PING health check")
1053
+
1054
+ def _ping_failed(self, error):
1055
+ """Function to call when PING fails"""
1056
+ self.disconnect()
1057
+
1058
+ def check_health(self):
1059
+ """Check the health of the connection with a PING/PONG"""
1060
+ if self.health_check_interval and time.monotonic() > self.next_health_check:
1061
+ self.retry.call_with_retry(self._send_ping, self._ping_failed)
1062
+
1063
+ def send_packed_command(self, command, check_health=True):
1064
+ """Send an already packed command to the Redis server"""
1065
+ if not self._sock:
1066
+ self.connect_check_health(check_health=False)
1067
+ # guard against health check recursion
1068
+ if check_health:
1069
+ self.check_health()
1070
+ try:
1071
+ if isinstance(command, str):
1072
+ command = [command]
1073
+ for item in command:
1074
+ self._sock.sendall(item)
1075
+ except socket.timeout:
1076
+ self.disconnect()
1077
+ raise TimeoutError("Timeout writing to socket")
1078
+ except OSError as e:
1079
+ self.disconnect()
1080
+ if len(e.args) == 1:
1081
+ errno, errmsg = "UNKNOWN", e.args[0]
1082
+ else:
1083
+ errno = e.args[0]
1084
+ errmsg = e.args[1]
1085
+ raise ConnectionError(f"Error {errno} while writing to socket. {errmsg}.")
1086
+ except BaseException:
1087
+ # BaseExceptions can be raised when a socket send operation is not
1088
+ # finished, e.g. due to a timeout. Ideally, a caller could then re-try
1089
+ # to send un-sent data. However, the send_packed_command() API
1090
+ # does not support it so there is no point in keeping the connection open.
1091
+ self.disconnect()
1092
+ raise
1093
+
1094
+ def send_command(self, *args, **kwargs):
1095
+ """Pack and send a command to the Redis server"""
1096
+ self.send_packed_command(
1097
+ self._command_packer.pack(*args),
1098
+ check_health=kwargs.get("check_health", True),
1099
+ )
1100
+
1101
+ def can_read(self, timeout=0):
1102
+ """Poll the socket to see if there's data that can be read."""
1103
+ sock = self._sock
1104
+ if not sock:
1105
+ self.connect()
1106
+
1107
+ host_error = self._host_error()
1108
+
1109
+ try:
1110
+ return self._parser.can_read(timeout)
1111
+
1112
+ except OSError as e:
1113
+ self.disconnect()
1114
+ raise ConnectionError(f"Error while reading from {host_error}: {e.args}")
1115
+
1116
+ def read_response(
1117
+ self,
1118
+ disable_decoding=False,
1119
+ *,
1120
+ disconnect_on_error=True,
1121
+ push_request=False,
1122
+ ):
1123
+ """Read the response from a previously sent command"""
1124
+
1125
+ host_error = self._host_error()
1126
+
1127
+ try:
1128
+ if self.protocol in ["3", 3]:
1129
+ response = self._parser.read_response(
1130
+ disable_decoding=disable_decoding, push_request=push_request
1131
+ )
1132
+ else:
1133
+ response = self._parser.read_response(disable_decoding=disable_decoding)
1134
+ except socket.timeout:
1135
+ if disconnect_on_error:
1136
+ self.disconnect()
1137
+ raise TimeoutError(f"Timeout reading from {host_error}")
1138
+ except OSError as e:
1139
+ if disconnect_on_error:
1140
+ self.disconnect()
1141
+ raise ConnectionError(f"Error while reading from {host_error} : {e.args}")
1142
+ except BaseException:
1143
+ # Also by default close in case of BaseException. A lot of code
1144
+ # relies on this behaviour when doing Command/Response pairs.
1145
+ # See #1128.
1146
+ if disconnect_on_error:
1147
+ self.disconnect()
1148
+ raise
1149
+
1150
+ if self.health_check_interval:
1151
+ self.next_health_check = time.monotonic() + self.health_check_interval
1152
+
1153
+ if isinstance(response, ResponseError):
1154
+ try:
1155
+ raise response
1156
+ finally:
1157
+ del response # avoid creating ref cycles
1158
+ return response
1159
+
1160
+ def pack_command(self, *args):
1161
+ """Pack a series of arguments into the Redis protocol"""
1162
+ return self._command_packer.pack(*args)
1163
+
1164
+ def pack_commands(self, commands):
1165
+ """Pack multiple commands into the Redis protocol"""
1166
+ output = []
1167
+ pieces = []
1168
+ buffer_length = 0
1169
+ buffer_cutoff = self._buffer_cutoff
1170
+
1171
+ for cmd in commands:
1172
+ for chunk in self._command_packer.pack(*cmd):
1173
+ chunklen = len(chunk)
1174
+ if (
1175
+ buffer_length > buffer_cutoff
1176
+ or chunklen > buffer_cutoff
1177
+ or isinstance(chunk, memoryview)
1178
+ ):
1179
+ if pieces:
1180
+ output.append(SYM_EMPTY.join(pieces))
1181
+ buffer_length = 0
1182
+ pieces = []
1183
+
1184
+ if chunklen > buffer_cutoff or isinstance(chunk, memoryview):
1185
+ output.append(chunk)
1186
+ else:
1187
+ pieces.append(chunk)
1188
+ buffer_length += chunklen
1189
+
1190
+ if pieces:
1191
+ output.append(SYM_EMPTY.join(pieces))
1192
+ return output
1193
+
1194
+ def get_protocol(self) -> Union[int, str]:
1195
+ return self.protocol
1196
+
1197
+ @property
1198
+ def handshake_metadata(self) -> Union[Dict[bytes, bytes], Dict[str, str]]:
1199
+ return self._handshake_metadata
1200
+
1201
+ @handshake_metadata.setter
1202
+ def handshake_metadata(self, value: Union[Dict[bytes, bytes], Dict[str, str]]):
1203
+ self._handshake_metadata = value
1204
+
1205
+ def set_re_auth_token(self, token: TokenInterface):
1206
+ self._re_auth_token = token
1207
+
1208
+ def re_auth(self):
1209
+ if self._re_auth_token is not None:
1210
+ self.send_command(
1211
+ "AUTH",
1212
+ self._re_auth_token.try_get("oid"),
1213
+ self._re_auth_token.get_value(),
1214
+ )
1215
+ self.read_response()
1216
+ self._re_auth_token = None
1217
+
1218
+ def _get_socket(self) -> Optional[socket.socket]:
1219
+ return self._sock
1220
+
1221
+ @property
1222
+ def socket_timeout(self) -> Optional[Union[float, int]]:
1223
+ return self._socket_timeout
1224
+
1225
+ @socket_timeout.setter
1226
+ def socket_timeout(self, value: Optional[Union[float, int]]):
1227
+ self._socket_timeout = value
1228
+
1229
+ @property
1230
+ def socket_connect_timeout(self) -> Optional[Union[float, int]]:
1231
+ return self._socket_connect_timeout
1232
+
1233
+ @socket_connect_timeout.setter
1234
+ def socket_connect_timeout(self, value: Optional[Union[float, int]]):
1235
+ self._socket_connect_timeout = value
1236
+
1237
+
1238
+ class Connection(AbstractConnection):
1239
+ "Manages TCP communication to and from a Redis server"
1240
+
1241
+ def __init__(
1242
+ self,
1243
+ host="localhost",
1244
+ port=6379,
1245
+ socket_keepalive=False,
1246
+ socket_keepalive_options=None,
1247
+ socket_type=0,
1248
+ **kwargs,
1249
+ ):
1250
+ self._host = host
1251
+ self.port = int(port)
1252
+ self.socket_keepalive = socket_keepalive
1253
+ self.socket_keepalive_options = socket_keepalive_options or {}
1254
+ self.socket_type = socket_type
1255
+ super().__init__(**kwargs)
1256
+
1257
+ def repr_pieces(self):
1095
1258
  pieces = [("host", self.host), ("port", self.port), ("db", self.db)]
1096
1259
  if self.client_name:
1097
1260
  pieces.append(("client_name", self.client_name))
@@ -1146,8 +1309,16 @@ class Connection(AbstractConnection):
1146
1309
  def _host_error(self):
1147
1310
  return f"{self.host}:{self.port}"
1148
1311
 
1312
+ @property
1313
+ def host(self) -> str:
1314
+ return self._host
1315
+
1316
+ @host.setter
1317
+ def host(self, value: str):
1318
+ self._host = value
1319
+
1149
1320
 
1150
- class CacheProxyConnection(ConnectionInterface):
1321
+ class CacheProxyConnection(MaintNotificationsAbstractConnection, ConnectionInterface):
1151
1322
  DUMMY_CACHE_VALUE = b"foo"
1152
1323
  MIN_ALLOWED_VERSION = "7.4.0"
1153
1324
  DEFAULT_SERVER_NAME = "redis"
@@ -1171,6 +1342,19 @@ class CacheProxyConnection(ConnectionInterface):
1171
1342
  self._current_options = None
1172
1343
  self.register_connect_callback(self._enable_tracking_callback)
1173
1344
 
1345
+ if isinstance(self._conn, MaintNotificationsAbstractConnection):
1346
+ MaintNotificationsAbstractConnection.__init__(
1347
+ self,
1348
+ self._conn.maint_notifications_config,
1349
+ self._conn._maint_notifications_pool_handler,
1350
+ self._conn.maintenance_state,
1351
+ self._conn.maintenance_notification_hash,
1352
+ self._conn.host,
1353
+ self._conn.socket_timeout,
1354
+ self._conn.socket_connect_timeout,
1355
+ self._conn._get_parser(),
1356
+ )
1357
+
1174
1358
  def repr_pieces(self):
1175
1359
  return self._conn.repr_pieces()
1176
1360
 
@@ -1183,6 +1367,17 @@ class CacheProxyConnection(ConnectionInterface):
1183
1367
  def set_parser(self, parser_class):
1184
1368
  self._conn.set_parser(parser_class)
1185
1369
 
1370
+ def set_maint_notifications_pool_handler_for_connection(
1371
+ self, maint_notifications_pool_handler
1372
+ ):
1373
+ if isinstance(self._conn, MaintNotificationsAbstractConnection):
1374
+ self._conn.set_maint_notifications_pool_handler_for_connection(
1375
+ maint_notifications_pool_handler
1376
+ )
1377
+
1378
+ def get_protocol(self):
1379
+ return self._conn.get_protocol()
1380
+
1186
1381
  def connect(self):
1187
1382
  self._conn.connect()
1188
1383
 
@@ -1328,6 +1523,109 @@ class CacheProxyConnection(ConnectionInterface):
1328
1523
  def handshake_metadata(self) -> Union[Dict[bytes, bytes], Dict[str, str]]:
1329
1524
  return self._conn.handshake_metadata
1330
1525
 
1526
+ def set_re_auth_token(self, token: TokenInterface):
1527
+ self._conn.set_re_auth_token(token)
1528
+
1529
+ def re_auth(self):
1530
+ self._conn.re_auth()
1531
+
1532
+ def mark_for_reconnect(self):
1533
+ self._conn.mark_for_reconnect()
1534
+
1535
+ def should_reconnect(self):
1536
+ return self._conn.should_reconnect()
1537
+
1538
+ def reset_should_reconnect(self):
1539
+ self._conn.reset_should_reconnect()
1540
+
1541
+ @property
1542
+ def host(self) -> str:
1543
+ return self._conn.host
1544
+
1545
+ @host.setter
1546
+ def host(self, value: str):
1547
+ self._conn.host = value
1548
+
1549
+ @property
1550
+ def socket_timeout(self) -> Optional[Union[float, int]]:
1551
+ return self._conn.socket_timeout
1552
+
1553
+ @socket_timeout.setter
1554
+ def socket_timeout(self, value: Optional[Union[float, int]]):
1555
+ self._conn.socket_timeout = value
1556
+
1557
+ @property
1558
+ def socket_connect_timeout(self) -> Optional[Union[float, int]]:
1559
+ return self._conn.socket_connect_timeout
1560
+
1561
+ @socket_connect_timeout.setter
1562
+ def socket_connect_timeout(self, value: Optional[Union[float, int]]):
1563
+ self._conn.socket_connect_timeout = value
1564
+
1565
+ def _get_socket(self) -> Optional[socket.socket]:
1566
+ if isinstance(self._conn, MaintNotificationsAbstractConnection):
1567
+ return self._conn._get_socket()
1568
+ else:
1569
+ raise NotImplementedError(
1570
+ "Maintenance notifications are not supported by this connection type"
1571
+ )
1572
+
1573
+ def _get_maint_notifications_connection_instance(
1574
+ self, connection
1575
+ ) -> MaintNotificationsAbstractConnection:
1576
+ """
1577
+ Validate that connection instance supports maintenance notifications.
1578
+ With this helper method we ensure that we are working
1579
+ with the correct connection type.
1580
+ After twe validate that connection instance supports maintenance notifications
1581
+ we can safely return the connection instance
1582
+ as MaintNotificationsAbstractConnection.
1583
+ """
1584
+ if not isinstance(connection, MaintNotificationsAbstractConnection):
1585
+ raise NotImplementedError(
1586
+ "Maintenance notifications are not supported by this connection type"
1587
+ )
1588
+ else:
1589
+ return connection
1590
+
1591
+ @property
1592
+ def maintenance_state(self) -> MaintenanceState:
1593
+ con = self._get_maint_notifications_connection_instance(self._conn)
1594
+ return con.maintenance_state
1595
+
1596
+ @maintenance_state.setter
1597
+ def maintenance_state(self, state: MaintenanceState):
1598
+ con = self._get_maint_notifications_connection_instance(self._conn)
1599
+ con.maintenance_state = state
1600
+
1601
+ def getpeername(self):
1602
+ con = self._get_maint_notifications_connection_instance(self._conn)
1603
+ return con.getpeername()
1604
+
1605
+ def get_resolved_ip(self):
1606
+ con = self._get_maint_notifications_connection_instance(self._conn)
1607
+ return con.get_resolved_ip()
1608
+
1609
+ def update_current_socket_timeout(self, relaxed_timeout: Optional[float] = None):
1610
+ con = self._get_maint_notifications_connection_instance(self._conn)
1611
+ con.update_current_socket_timeout(relaxed_timeout)
1612
+
1613
+ def set_tmp_settings(
1614
+ self,
1615
+ tmp_host_address: Optional[str] = None,
1616
+ tmp_relaxed_timeout: Optional[float] = None,
1617
+ ):
1618
+ con = self._get_maint_notifications_connection_instance(self._conn)
1619
+ con.set_tmp_settings(tmp_host_address, tmp_relaxed_timeout)
1620
+
1621
+ def reset_tmp_settings(
1622
+ self,
1623
+ reset_host_address: bool = False,
1624
+ reset_relaxed_timeout: bool = False,
1625
+ ):
1626
+ con = self._get_maint_notifications_connection_instance(self._conn)
1627
+ con.reset_tmp_settings(reset_host_address, reset_relaxed_timeout)
1628
+
1331
1629
  def _connect(self):
1332
1630
  self._conn._connect()
1333
1631
 
@@ -1351,15 +1649,6 @@ class CacheProxyConnection(ConnectionInterface):
1351
1649
  else:
1352
1650
  self._cache.delete_by_redis_keys(data[1])
1353
1651
 
1354
- def get_protocol(self):
1355
- return self._conn.get_protocol()
1356
-
1357
- def set_re_auth_token(self, token: TokenInterface):
1358
- self._conn.set_re_auth_token(token)
1359
-
1360
- def re_auth(self):
1361
- self._conn.re_auth()
1362
-
1363
1652
 
1364
1653
  class SSLConnection(Connection):
1365
1654
  """Manages SSL connections to and from the Redis server(s).
@@ -1448,240 +1737,629 @@ class SSLConnection(Connection):
1448
1737
  self.ssl_ciphers = ssl_ciphers
1449
1738
  super().__init__(**kwargs)
1450
1739
 
1451
- def _connect(self):
1452
- """
1453
- Wrap the socket with SSL support, handling potential errors.
1454
- """
1455
- sock = super()._connect()
1456
- try:
1457
- return self._wrap_socket_with_ssl(sock)
1458
- except (OSError, RedisError):
1459
- sock.close()
1460
- raise
1740
+ def _connect(self):
1741
+ """
1742
+ Wrap the socket with SSL support, handling potential errors.
1743
+ """
1744
+ sock = super()._connect()
1745
+ try:
1746
+ return self._wrap_socket_with_ssl(sock)
1747
+ except (OSError, RedisError):
1748
+ sock.close()
1749
+ raise
1750
+
1751
+ def _wrap_socket_with_ssl(self, sock):
1752
+ """
1753
+ Wraps the socket with SSL support.
1754
+
1755
+ Args:
1756
+ sock: The plain socket to wrap with SSL.
1757
+
1758
+ Returns:
1759
+ An SSL wrapped socket.
1760
+ """
1761
+ context = ssl.create_default_context()
1762
+ context.check_hostname = self.check_hostname
1763
+ context.verify_mode = self.cert_reqs
1764
+ if self.ssl_include_verify_flags:
1765
+ for flag in self.ssl_include_verify_flags:
1766
+ context.verify_flags |= flag
1767
+ if self.ssl_exclude_verify_flags:
1768
+ for flag in self.ssl_exclude_verify_flags:
1769
+ context.verify_flags &= ~flag
1770
+ if self.certfile or self.keyfile:
1771
+ context.load_cert_chain(
1772
+ certfile=self.certfile,
1773
+ keyfile=self.keyfile,
1774
+ password=self.certificate_password,
1775
+ )
1776
+ if (
1777
+ self.ca_certs is not None
1778
+ or self.ca_path is not None
1779
+ or self.ca_data is not None
1780
+ ):
1781
+ context.load_verify_locations(
1782
+ cafile=self.ca_certs, capath=self.ca_path, cadata=self.ca_data
1783
+ )
1784
+ if self.ssl_min_version is not None:
1785
+ context.minimum_version = self.ssl_min_version
1786
+ if self.ssl_ciphers:
1787
+ context.set_ciphers(self.ssl_ciphers)
1788
+ if self.ssl_validate_ocsp is True and CRYPTOGRAPHY_AVAILABLE is False:
1789
+ raise RedisError("cryptography is not installed.")
1790
+
1791
+ if self.ssl_validate_ocsp_stapled and self.ssl_validate_ocsp:
1792
+ raise RedisError(
1793
+ "Either an OCSP staple or pure OCSP connection must be validated "
1794
+ "- not both."
1795
+ )
1796
+
1797
+ sslsock = context.wrap_socket(sock, server_hostname=self.host)
1798
+
1799
+ # validation for the stapled case
1800
+ if self.ssl_validate_ocsp_stapled:
1801
+ import OpenSSL
1802
+
1803
+ from .ocsp import ocsp_staple_verifier
1804
+
1805
+ # if a context is provided use it - otherwise, a basic context
1806
+ if self.ssl_ocsp_context is None:
1807
+ staple_ctx = OpenSSL.SSL.Context(OpenSSL.SSL.SSLv23_METHOD)
1808
+ staple_ctx.use_certificate_file(self.certfile)
1809
+ staple_ctx.use_privatekey_file(self.keyfile)
1810
+ else:
1811
+ staple_ctx = self.ssl_ocsp_context
1812
+
1813
+ staple_ctx.set_ocsp_client_callback(
1814
+ ocsp_staple_verifier, self.ssl_ocsp_expected_cert
1815
+ )
1816
+
1817
+ # need another socket
1818
+ con = OpenSSL.SSL.Connection(staple_ctx, socket.socket())
1819
+ con.request_ocsp()
1820
+ con.connect((self.host, self.port))
1821
+ con.do_handshake()
1822
+ con.shutdown()
1823
+ return sslsock
1824
+
1825
+ # pure ocsp validation
1826
+ if self.ssl_validate_ocsp is True and CRYPTOGRAPHY_AVAILABLE:
1827
+ from .ocsp import OCSPVerifier
1828
+
1829
+ o = OCSPVerifier(sslsock, self.host, self.port, self.ca_certs)
1830
+ if o.is_valid():
1831
+ return sslsock
1832
+ else:
1833
+ raise ConnectionError("ocsp validation error")
1834
+ return sslsock
1835
+
1836
+
1837
+ class UnixDomainSocketConnection(AbstractConnection):
1838
+ "Manages UDS communication to and from a Redis server"
1839
+
1840
+ def __init__(self, path="", socket_timeout=None, **kwargs):
1841
+ super().__init__(**kwargs)
1842
+ self.path = path
1843
+ self.socket_timeout = socket_timeout
1844
+
1845
+ def repr_pieces(self):
1846
+ pieces = [("path", self.path), ("db", self.db)]
1847
+ if self.client_name:
1848
+ pieces.append(("client_name", self.client_name))
1849
+ return pieces
1850
+
1851
+ def _connect(self):
1852
+ "Create a Unix domain socket connection"
1853
+ sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
1854
+ sock.settimeout(self.socket_connect_timeout)
1855
+ try:
1856
+ sock.connect(self.path)
1857
+ except OSError:
1858
+ # Prevent ResourceWarnings for unclosed sockets.
1859
+ try:
1860
+ sock.shutdown(socket.SHUT_RDWR) # ensure a clean close
1861
+ except OSError:
1862
+ pass
1863
+ sock.close()
1864
+ raise
1865
+ sock.settimeout(self.socket_timeout)
1866
+ return sock
1867
+
1868
+ def _host_error(self):
1869
+ return self.path
1870
+
1871
+
1872
+ FALSE_STRINGS = ("0", "F", "FALSE", "N", "NO")
1873
+
1874
+
1875
+ def to_bool(value):
1876
+ if value is None or value == "":
1877
+ return None
1878
+ if isinstance(value, str) and value.upper() in FALSE_STRINGS:
1879
+ return False
1880
+ return bool(value)
1881
+
1882
+
1883
+ def parse_ssl_verify_flags(value):
1884
+ # flags are passed in as a string representation of a list,
1885
+ # e.g. VERIFY_X509_STRICT, VERIFY_X509_PARTIAL_CHAIN
1886
+ verify_flags_str = value.replace("[", "").replace("]", "")
1887
+
1888
+ verify_flags = []
1889
+ for flag in verify_flags_str.split(","):
1890
+ flag = flag.strip()
1891
+ if not hasattr(VerifyFlags, flag):
1892
+ raise ValueError(f"Invalid ssl verify flag: {flag}")
1893
+ verify_flags.append(getattr(VerifyFlags, flag))
1894
+ return verify_flags
1895
+
1896
+
1897
+ URL_QUERY_ARGUMENT_PARSERS = {
1898
+ "db": int,
1899
+ "socket_timeout": float,
1900
+ "socket_connect_timeout": float,
1901
+ "socket_keepalive": to_bool,
1902
+ "retry_on_timeout": to_bool,
1903
+ "retry_on_error": list,
1904
+ "max_connections": int,
1905
+ "health_check_interval": int,
1906
+ "ssl_check_hostname": to_bool,
1907
+ "ssl_include_verify_flags": parse_ssl_verify_flags,
1908
+ "ssl_exclude_verify_flags": parse_ssl_verify_flags,
1909
+ "timeout": float,
1910
+ }
1911
+
1912
+
1913
+ def parse_url(url):
1914
+ if not (
1915
+ url.startswith("redis://")
1916
+ or url.startswith("rediss://")
1917
+ or url.startswith("unix://")
1918
+ ):
1919
+ raise ValueError(
1920
+ "Redis URL must specify one of the following "
1921
+ "schemes (redis://, rediss://, unix://)"
1922
+ )
1923
+
1924
+ url = urlparse(url)
1925
+ kwargs = {}
1926
+
1927
+ for name, value in parse_qs(url.query).items():
1928
+ if value and len(value) > 0:
1929
+ value = unquote(value[0])
1930
+ parser = URL_QUERY_ARGUMENT_PARSERS.get(name)
1931
+ if parser:
1932
+ try:
1933
+ kwargs[name] = parser(value)
1934
+ except (TypeError, ValueError):
1935
+ raise ValueError(f"Invalid value for '{name}' in connection URL.")
1936
+ else:
1937
+ kwargs[name] = value
1938
+
1939
+ if url.username:
1940
+ kwargs["username"] = unquote(url.username)
1941
+ if url.password:
1942
+ kwargs["password"] = unquote(url.password)
1943
+
1944
+ # We only support redis://, rediss:// and unix:// schemes.
1945
+ if url.scheme == "unix":
1946
+ if url.path:
1947
+ kwargs["path"] = unquote(url.path)
1948
+ kwargs["connection_class"] = UnixDomainSocketConnection
1949
+
1950
+ else: # implied: url.scheme in ("redis", "rediss"):
1951
+ if url.hostname:
1952
+ kwargs["host"] = unquote(url.hostname)
1953
+ if url.port:
1954
+ kwargs["port"] = int(url.port)
1955
+
1956
+ # If there's a path argument, use it as the db argument if a
1957
+ # querystring value wasn't specified
1958
+ if url.path and "db" not in kwargs:
1959
+ try:
1960
+ kwargs["db"] = int(unquote(url.path).replace("/", ""))
1961
+ except (AttributeError, ValueError):
1962
+ pass
1963
+
1964
+ if url.scheme == "rediss":
1965
+ kwargs["connection_class"] = SSLConnection
1966
+
1967
+ return kwargs
1968
+
1969
+
1970
+ _CP = TypeVar("_CP", bound="ConnectionPool")
1971
+
1972
+
1973
+ class ConnectionPoolInterface(ABC):
1974
+ @abstractmethod
1975
+ def get_protocol(self):
1976
+ pass
1977
+
1978
+ @abstractmethod
1979
+ def reset(self):
1980
+ pass
1981
+
1982
+ @abstractmethod
1983
+ @deprecated_args(
1984
+ args_to_warn=["*"],
1985
+ reason="Use get_connection() without args instead",
1986
+ version="5.3.0",
1987
+ )
1988
+ def get_connection(
1989
+ self, command_name: Optional[str], *keys, **options
1990
+ ) -> ConnectionInterface:
1991
+ pass
1992
+
1993
+ @abstractmethod
1994
+ def get_encoder(self):
1995
+ pass
1461
1996
 
1462
- def _wrap_socket_with_ssl(self, sock):
1463
- """
1464
- Wraps the socket with SSL support.
1997
+ @abstractmethod
1998
+ def release(self, connection: ConnectionInterface):
1999
+ pass
1465
2000
 
1466
- Args:
1467
- sock: The plain socket to wrap with SSL.
2001
+ @abstractmethod
2002
+ def disconnect(self, inuse_connections: bool = True):
2003
+ pass
1468
2004
 
1469
- Returns:
1470
- An SSL wrapped socket.
1471
- """
1472
- context = ssl.create_default_context()
1473
- context.check_hostname = self.check_hostname
1474
- context.verify_mode = self.cert_reqs
1475
- if self.ssl_include_verify_flags:
1476
- for flag in self.ssl_include_verify_flags:
1477
- context.verify_flags |= flag
1478
- if self.ssl_exclude_verify_flags:
1479
- for flag in self.ssl_exclude_verify_flags:
1480
- context.verify_flags &= ~flag
1481
- if self.certfile or self.keyfile:
1482
- context.load_cert_chain(
1483
- certfile=self.certfile,
1484
- keyfile=self.keyfile,
1485
- password=self.certificate_password,
1486
- )
1487
- if (
1488
- self.ca_certs is not None
1489
- or self.ca_path is not None
1490
- or self.ca_data is not None
1491
- ):
1492
- context.load_verify_locations(
1493
- cafile=self.ca_certs, capath=self.ca_path, cadata=self.ca_data
1494
- )
1495
- if self.ssl_min_version is not None:
1496
- context.minimum_version = self.ssl_min_version
1497
- if self.ssl_ciphers:
1498
- context.set_ciphers(self.ssl_ciphers)
1499
- if self.ssl_validate_ocsp is True and CRYPTOGRAPHY_AVAILABLE is False:
1500
- raise RedisError("cryptography is not installed.")
2005
+ @abstractmethod
2006
+ def close(self):
2007
+ pass
1501
2008
 
1502
- if self.ssl_validate_ocsp_stapled and self.ssl_validate_ocsp:
1503
- raise RedisError(
1504
- "Either an OCSP staple or pure OCSP connection must be validated "
1505
- "- not both."
1506
- )
2009
+ @abstractmethod
2010
+ def set_retry(self, retry: Retry):
2011
+ pass
1507
2012
 
1508
- sslsock = context.wrap_socket(sock, server_hostname=self.host)
2013
+ @abstractmethod
2014
+ def re_auth_callback(self, token: TokenInterface):
2015
+ pass
1509
2016
 
1510
- # validation for the stapled case
1511
- if self.ssl_validate_ocsp_stapled:
1512
- import OpenSSL
1513
2017
 
1514
- from .ocsp import ocsp_staple_verifier
2018
+ class MaintNotificationsAbstractConnectionPool:
2019
+ """
2020
+ Abstract class for handling maintenance notifications logic.
2021
+ This class is mixed into the ConnectionPool classes.
1515
2022
 
1516
- # if a context is provided use it - otherwise, a basic context
1517
- if self.ssl_ocsp_context is None:
1518
- staple_ctx = OpenSSL.SSL.Context(OpenSSL.SSL.SSLv23_METHOD)
1519
- staple_ctx.use_certificate_file(self.certfile)
1520
- staple_ctx.use_privatekey_file(self.keyfile)
1521
- else:
1522
- staple_ctx = self.ssl_ocsp_context
2023
+ This class is not intended to be used directly!
1523
2024
 
1524
- staple_ctx.set_ocsp_client_callback(
1525
- ocsp_staple_verifier, self.ssl_ocsp_expected_cert
2025
+ All logic related to maintenance notifications and
2026
+ connection pool handling is encapsulated in this class.
2027
+ """
2028
+
2029
+ def __init__(
2030
+ self,
2031
+ maint_notifications_config: Optional[MaintNotificationsConfig] = None,
2032
+ **kwargs,
2033
+ ):
2034
+ # Initialize maintenance notifications
2035
+ is_protocol_supported = kwargs.get("protocol") in [3, "3"]
2036
+ if maint_notifications_config is None and is_protocol_supported:
2037
+ maint_notifications_config = MaintNotificationsConfig()
2038
+
2039
+ if maint_notifications_config and maint_notifications_config.enabled:
2040
+ if not is_protocol_supported:
2041
+ raise RedisError(
2042
+ "Maintenance notifications handlers on connection are only supported with RESP version 3"
2043
+ )
2044
+
2045
+ self._maint_notifications_pool_handler = MaintNotificationsPoolHandler(
2046
+ self, maint_notifications_config
1526
2047
  )
1527
2048
 
1528
- # need another socket
1529
- con = OpenSSL.SSL.Connection(staple_ctx, socket.socket())
1530
- con.request_ocsp()
1531
- con.connect((self.host, self.port))
1532
- con.do_handshake()
1533
- con.shutdown()
1534
- return sslsock
2049
+ self._update_connection_kwargs_for_maint_notifications(
2050
+ self._maint_notifications_pool_handler
2051
+ )
2052
+ else:
2053
+ self._maint_notifications_pool_handler = None
1535
2054
 
1536
- # pure ocsp validation
1537
- if self.ssl_validate_ocsp is True and CRYPTOGRAPHY_AVAILABLE:
1538
- from .ocsp import OCSPVerifier
2055
+ @property
2056
+ @abstractmethod
2057
+ def connection_kwargs(self) -> Dict[str, Any]:
2058
+ pass
1539
2059
 
1540
- o = OCSPVerifier(sslsock, self.host, self.port, self.ca_certs)
1541
- if o.is_valid():
1542
- return sslsock
1543
- else:
1544
- raise ConnectionError("ocsp validation error")
1545
- return sslsock
2060
+ @connection_kwargs.setter
2061
+ @abstractmethod
2062
+ def connection_kwargs(self, value: Dict[str, Any]):
2063
+ pass
1546
2064
 
2065
+ @abstractmethod
2066
+ def _get_pool_lock(self) -> threading.RLock:
2067
+ pass
1547
2068
 
1548
- class UnixDomainSocketConnection(AbstractConnection):
1549
- "Manages UDS communication to and from a Redis server"
2069
+ @abstractmethod
2070
+ def _get_free_connections(self) -> Iterable["MaintNotificationsAbstractConnection"]:
2071
+ pass
1550
2072
 
1551
- def __init__(self, path="", socket_timeout=None, **kwargs):
1552
- super().__init__(**kwargs)
1553
- self.path = path
1554
- self.socket_timeout = socket_timeout
2073
+ @abstractmethod
2074
+ def _get_in_use_connections(
2075
+ self,
2076
+ ) -> Iterable["MaintNotificationsAbstractConnection"]:
2077
+ pass
1555
2078
 
1556
- def repr_pieces(self):
1557
- pieces = [("path", self.path), ("db", self.db)]
1558
- if self.client_name:
1559
- pieces.append(("client_name", self.client_name))
1560
- return pieces
2079
+ def maint_notifications_enabled(self):
2080
+ """
2081
+ Returns:
2082
+ True if the maintenance notifications are enabled, False otherwise.
2083
+ The maintenance notifications config is stored in the pool handler.
2084
+ If the pool handler is not set, the maintenance notifications are not enabled.
2085
+ """
2086
+ maint_notifications_config = (
2087
+ self._maint_notifications_pool_handler.config
2088
+ if self._maint_notifications_pool_handler
2089
+ else None
2090
+ )
1561
2091
 
1562
- def _connect(self):
1563
- "Create a Unix domain socket connection"
1564
- sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
1565
- sock.settimeout(self.socket_connect_timeout)
1566
- try:
1567
- sock.connect(self.path)
1568
- except OSError:
1569
- # Prevent ResourceWarnings for unclosed sockets.
1570
- try:
1571
- sock.shutdown(socket.SHUT_RDWR) # ensure a clean close
1572
- except OSError:
1573
- pass
1574
- sock.close()
1575
- raise
1576
- sock.settimeout(self.socket_timeout)
1577
- return sock
2092
+ return maint_notifications_config and maint_notifications_config.enabled
1578
2093
 
1579
- def _host_error(self):
1580
- return self.path
2094
+ def update_maint_notifications_config(
2095
+ self, maint_notifications_config: MaintNotificationsConfig
2096
+ ):
2097
+ """
2098
+ Updates the maintenance notifications configuration.
2099
+ This method should be called only if the pool was created
2100
+ without enabling the maintenance notifications and
2101
+ in a later point in time maintenance notifications
2102
+ are requested to be enabled.
2103
+ """
2104
+ if (
2105
+ self.maint_notifications_enabled()
2106
+ and not maint_notifications_config.enabled
2107
+ ):
2108
+ raise ValueError(
2109
+ "Cannot disable maintenance notifications after enabling them"
2110
+ )
2111
+ # first update pool settings
2112
+ if not self._maint_notifications_pool_handler:
2113
+ self._maint_notifications_pool_handler = MaintNotificationsPoolHandler(
2114
+ self, maint_notifications_config
2115
+ )
2116
+ else:
2117
+ self._maint_notifications_pool_handler.config = maint_notifications_config
1581
2118
 
2119
+ # then update connection kwargs and existing connections
2120
+ self._update_connection_kwargs_for_maint_notifications(
2121
+ self._maint_notifications_pool_handler
2122
+ )
2123
+ self._update_maint_notifications_configs_for_connections(
2124
+ self._maint_notifications_pool_handler
2125
+ )
1582
2126
 
1583
- FALSE_STRINGS = ("0", "F", "FALSE", "N", "NO")
2127
+ def _update_connection_kwargs_for_maint_notifications(
2128
+ self, maint_notifications_pool_handler: MaintNotificationsPoolHandler
2129
+ ):
2130
+ """
2131
+ Update the connection kwargs for all future connections.
2132
+ """
2133
+ if not self.maint_notifications_enabled():
2134
+ return
1584
2135
 
2136
+ self.connection_kwargs.update(
2137
+ {
2138
+ "maint_notifications_pool_handler": maint_notifications_pool_handler,
2139
+ "maint_notifications_config": maint_notifications_pool_handler.config,
2140
+ }
2141
+ )
1585
2142
 
1586
- def to_bool(value):
1587
- if value is None or value == "":
1588
- return None
1589
- if isinstance(value, str) and value.upper() in FALSE_STRINGS:
1590
- return False
1591
- return bool(value)
2143
+ # Store original connection parameters for maintenance notifications.
2144
+ if self.connection_kwargs.get("orig_host_address", None) is None:
2145
+ # If orig_host_address is None it means we haven't
2146
+ # configured the original values yet
2147
+ self.connection_kwargs.update(
2148
+ {
2149
+ "orig_host_address": self.connection_kwargs.get("host"),
2150
+ "orig_socket_timeout": self.connection_kwargs.get(
2151
+ "socket_timeout", None
2152
+ ),
2153
+ "orig_socket_connect_timeout": self.connection_kwargs.get(
2154
+ "socket_connect_timeout", None
2155
+ ),
2156
+ }
2157
+ )
2158
+
2159
+ def _update_maint_notifications_configs_for_connections(
2160
+ self, maint_notifications_pool_handler: MaintNotificationsPoolHandler
2161
+ ):
2162
+ """Update the maintenance notifications config for all connections in the pool."""
2163
+ with self._get_pool_lock():
2164
+ for conn in self._get_free_connections():
2165
+ conn.set_maint_notifications_pool_handler_for_connection(
2166
+ maint_notifications_pool_handler
2167
+ )
2168
+ conn.maint_notifications_config = (
2169
+ maint_notifications_pool_handler.config
2170
+ )
2171
+ conn.disconnect()
2172
+ for conn in self._get_in_use_connections():
2173
+ conn.set_maint_notifications_pool_handler_for_connection(
2174
+ maint_notifications_pool_handler
2175
+ )
2176
+ conn.maint_notifications_config = (
2177
+ maint_notifications_pool_handler.config
2178
+ )
2179
+ conn.mark_for_reconnect()
2180
+
2181
+ def _should_update_connection(
2182
+ self,
2183
+ conn: "MaintNotificationsAbstractConnection",
2184
+ matching_pattern: Literal[
2185
+ "connected_address", "configured_address", "notification_hash"
2186
+ ] = "connected_address",
2187
+ matching_address: Optional[str] = None,
2188
+ matching_notification_hash: Optional[int] = None,
2189
+ ) -> bool:
2190
+ """
2191
+ Check if the connection should be updated based on the matching criteria.
2192
+ """
2193
+ if matching_pattern == "connected_address":
2194
+ if matching_address and conn.getpeername() != matching_address:
2195
+ return False
2196
+ elif matching_pattern == "configured_address":
2197
+ if matching_address and conn.host != matching_address:
2198
+ return False
2199
+ elif matching_pattern == "notification_hash":
2200
+ if (
2201
+ matching_notification_hash
2202
+ and conn.maintenance_notification_hash != matching_notification_hash
2203
+ ):
2204
+ return False
2205
+ return True
1592
2206
 
2207
+ def update_connection_settings(
2208
+ self,
2209
+ conn: "MaintNotificationsAbstractConnection",
2210
+ state: Optional["MaintenanceState"] = None,
2211
+ maintenance_notification_hash: Optional[int] = None,
2212
+ host_address: Optional[str] = None,
2213
+ relaxed_timeout: Optional[float] = None,
2214
+ update_notification_hash: bool = False,
2215
+ reset_host_address: bool = False,
2216
+ reset_relaxed_timeout: bool = False,
2217
+ ):
2218
+ """
2219
+ Update the settings for a single connection.
2220
+ """
2221
+ if state:
2222
+ conn.maintenance_state = state
1593
2223
 
1594
- def parse_ssl_verify_flags(value):
1595
- # flags are passed in as a string representation of a list,
1596
- # e.g. VERIFY_X509_STRICT, VERIFY_X509_PARTIAL_CHAIN
1597
- verify_flags_str = value.replace("[", "").replace("]", "")
2224
+ if update_notification_hash:
2225
+ # update the notification hash only if requested
2226
+ conn.maintenance_notification_hash = maintenance_notification_hash
1598
2227
 
1599
- verify_flags = []
1600
- for flag in verify_flags_str.split(","):
1601
- flag = flag.strip()
1602
- if not hasattr(VerifyFlags, flag):
1603
- raise ValueError(f"Invalid ssl verify flag: {flag}")
1604
- verify_flags.append(getattr(VerifyFlags, flag))
1605
- return verify_flags
2228
+ if host_address is not None:
2229
+ conn.set_tmp_settings(tmp_host_address=host_address)
1606
2230
 
2231
+ if relaxed_timeout is not None:
2232
+ conn.set_tmp_settings(tmp_relaxed_timeout=relaxed_timeout)
1607
2233
 
1608
- URL_QUERY_ARGUMENT_PARSERS = {
1609
- "db": int,
1610
- "socket_timeout": float,
1611
- "socket_connect_timeout": float,
1612
- "socket_keepalive": to_bool,
1613
- "retry_on_timeout": to_bool,
1614
- "retry_on_error": list,
1615
- "max_connections": int,
1616
- "health_check_interval": int,
1617
- "ssl_check_hostname": to_bool,
1618
- "ssl_include_verify_flags": parse_ssl_verify_flags,
1619
- "ssl_exclude_verify_flags": parse_ssl_verify_flags,
1620
- "timeout": float,
1621
- }
2234
+ if reset_relaxed_timeout or reset_host_address:
2235
+ conn.reset_tmp_settings(
2236
+ reset_host_address=reset_host_address,
2237
+ reset_relaxed_timeout=reset_relaxed_timeout,
2238
+ )
1622
2239
 
2240
+ conn.update_current_socket_timeout(relaxed_timeout)
1623
2241
 
1624
- def parse_url(url):
1625
- if not (
1626
- url.startswith("redis://")
1627
- or url.startswith("rediss://")
1628
- or url.startswith("unix://")
2242
+ def update_connections_settings(
2243
+ self,
2244
+ state: Optional["MaintenanceState"] = None,
2245
+ maintenance_notification_hash: Optional[int] = None,
2246
+ host_address: Optional[str] = None,
2247
+ relaxed_timeout: Optional[float] = None,
2248
+ matching_address: Optional[str] = None,
2249
+ matching_notification_hash: Optional[int] = None,
2250
+ matching_pattern: Literal[
2251
+ "connected_address", "configured_address", "notification_hash"
2252
+ ] = "connected_address",
2253
+ update_notification_hash: bool = False,
2254
+ reset_host_address: bool = False,
2255
+ reset_relaxed_timeout: bool = False,
2256
+ include_free_connections: bool = True,
1629
2257
  ):
1630
- raise ValueError(
1631
- "Redis URL must specify one of the following "
1632
- "schemes (redis://, rediss://, unix://)"
1633
- )
1634
-
1635
- url = urlparse(url)
1636
- kwargs = {}
2258
+ """
2259
+ Update the settings for all matching connections in the pool.
1637
2260
 
1638
- for name, value in parse_qs(url.query).items():
1639
- if value and len(value) > 0:
1640
- value = unquote(value[0])
1641
- parser = URL_QUERY_ARGUMENT_PARSERS.get(name)
1642
- if parser:
1643
- try:
1644
- kwargs[name] = parser(value)
1645
- except (TypeError, ValueError):
1646
- raise ValueError(f"Invalid value for '{name}' in connection URL.")
1647
- else:
1648
- kwargs[name] = value
2261
+ This method does not create new connections.
2262
+ This method does not affect the connection kwargs.
1649
2263
 
1650
- if url.username:
1651
- kwargs["username"] = unquote(url.username)
1652
- if url.password:
1653
- kwargs["password"] = unquote(url.password)
2264
+ :param state: The maintenance state to set for the connection.
2265
+ :param maintenance_notification_hash: The hash of the maintenance notification
2266
+ to set for the connection.
2267
+ :param host_address: The host address to set for the connection.
2268
+ :param relaxed_timeout: The relaxed timeout to set for the connection.
2269
+ :param matching_address: The address to match for the connection.
2270
+ :param matching_notification_hash: The notification hash to match for the connection.
2271
+ :param matching_pattern: The pattern to match for the connection.
2272
+ :param update_notification_hash: Whether to update the notification hash for the connection.
2273
+ :param reset_host_address: Whether to reset the host address to the original address.
2274
+ :param reset_relaxed_timeout: Whether to reset the relaxed timeout to the original timeout.
2275
+ :param include_free_connections: Whether to include free/available connections.
2276
+ """
2277
+ with self._get_pool_lock():
2278
+ for conn in self._get_in_use_connections():
2279
+ if self._should_update_connection(
2280
+ conn,
2281
+ matching_pattern,
2282
+ matching_address,
2283
+ matching_notification_hash,
2284
+ ):
2285
+ self.update_connection_settings(
2286
+ conn,
2287
+ state=state,
2288
+ maintenance_notification_hash=maintenance_notification_hash,
2289
+ host_address=host_address,
2290
+ relaxed_timeout=relaxed_timeout,
2291
+ update_notification_hash=update_notification_hash,
2292
+ reset_host_address=reset_host_address,
2293
+ reset_relaxed_timeout=reset_relaxed_timeout,
2294
+ )
1654
2295
 
1655
- # We only support redis://, rediss:// and unix:// schemes.
1656
- if url.scheme == "unix":
1657
- if url.path:
1658
- kwargs["path"] = unquote(url.path)
1659
- kwargs["connection_class"] = UnixDomainSocketConnection
2296
+ if include_free_connections:
2297
+ for conn in self._get_free_connections():
2298
+ if self._should_update_connection(
2299
+ conn,
2300
+ matching_pattern,
2301
+ matching_address,
2302
+ matching_notification_hash,
2303
+ ):
2304
+ self.update_connection_settings(
2305
+ conn,
2306
+ state=state,
2307
+ maintenance_notification_hash=maintenance_notification_hash,
2308
+ host_address=host_address,
2309
+ relaxed_timeout=relaxed_timeout,
2310
+ update_notification_hash=update_notification_hash,
2311
+ reset_host_address=reset_host_address,
2312
+ reset_relaxed_timeout=reset_relaxed_timeout,
2313
+ )
1660
2314
 
1661
- else: # implied: url.scheme in ("redis", "rediss"):
1662
- if url.hostname:
1663
- kwargs["host"] = unquote(url.hostname)
1664
- if url.port:
1665
- kwargs["port"] = int(url.port)
2315
+ def update_connection_kwargs(
2316
+ self,
2317
+ **kwargs,
2318
+ ):
2319
+ """
2320
+ Update the connection kwargs for all future connections.
1666
2321
 
1667
- # If there's a path argument, use it as the db argument if a
1668
- # querystring value wasn't specified
1669
- if url.path and "db" not in kwargs:
1670
- try:
1671
- kwargs["db"] = int(unquote(url.path).replace("/", ""))
1672
- except (AttributeError, ValueError):
1673
- pass
2322
+ This method updates the connection kwargs for all future connections created by the pool.
2323
+ Existing connections are not affected.
2324
+ """
2325
+ self.connection_kwargs.update(kwargs)
1674
2326
 
1675
- if url.scheme == "rediss":
1676
- kwargs["connection_class"] = SSLConnection
2327
+ def update_active_connections_for_reconnect(
2328
+ self,
2329
+ moving_address_src: Optional[str] = None,
2330
+ ):
2331
+ """
2332
+ Mark all active connections for reconnect.
2333
+ This is used when a cluster node is migrated to a different address.
1677
2334
 
1678
- return kwargs
2335
+ :param moving_address_src: The address of the node that is being moved.
2336
+ """
2337
+ with self._get_pool_lock():
2338
+ for conn in self._get_in_use_connections():
2339
+ if self._should_update_connection(
2340
+ conn, "connected_address", moving_address_src
2341
+ ):
2342
+ conn.mark_for_reconnect()
1679
2343
 
2344
+ def disconnect_free_connections(
2345
+ self,
2346
+ moving_address_src: Optional[str] = None,
2347
+ ):
2348
+ """
2349
+ Disconnect all free/available connections.
2350
+ This is used when a cluster node is migrated to a different address.
1680
2351
 
1681
- _CP = TypeVar("_CP", bound="ConnectionPool")
2352
+ :param moving_address_src: The address of the node that is being moved.
2353
+ """
2354
+ with self._get_pool_lock():
2355
+ for conn in self._get_free_connections():
2356
+ if self._should_update_connection(
2357
+ conn, "connected_address", moving_address_src
2358
+ ):
2359
+ conn.disconnect()
1682
2360
 
1683
2361
 
1684
- class ConnectionPool:
2362
+ class ConnectionPool(MaintNotificationsAbstractConnectionPool, ConnectionPoolInterface):
1685
2363
  """
1686
2364
  Create a connection pool. ``If max_connections`` is set, then this
1687
2365
  object raises :py:class:`~redis.exceptions.ConnectionError` when the pool's
@@ -1692,6 +2370,12 @@ class ConnectionPool:
1692
2370
  unix sockets.
1693
2371
  :py:class:`~redis.SSLConnection` can be used for SSL enabled connections.
1694
2372
 
2373
+ If ``maint_notifications_config`` is provided, the connection pool will support
2374
+ maintenance notifications.
2375
+ Maintenance notifications are supported only with RESP3.
2376
+ If the ``maint_notifications_config`` is not provided but the ``protocol`` is 3,
2377
+ the maintenance notifications will be enabled by default.
2378
+
1695
2379
  Any additional keyword arguments are passed to the constructor of
1696
2380
  ``connection_class``.
1697
2381
  """
@@ -1750,6 +2434,7 @@ class ConnectionPool:
1750
2434
  connection_class=Connection,
1751
2435
  max_connections: Optional[int] = None,
1752
2436
  cache_factory: Optional[CacheFactoryInterface] = None,
2437
+ maint_notifications_config: Optional[MaintNotificationsConfig] = None,
1753
2438
  **connection_kwargs,
1754
2439
  ):
1755
2440
  max_connections = max_connections or 2**31
@@ -1757,16 +2442,16 @@ class ConnectionPool:
1757
2442
  raise ValueError('"max_connections" must be a positive integer')
1758
2443
 
1759
2444
  self.connection_class = connection_class
1760
- self.connection_kwargs = connection_kwargs
2445
+ self._connection_kwargs = connection_kwargs
1761
2446
  self.max_connections = max_connections
1762
2447
  self.cache = None
1763
2448
  self._cache_factory = cache_factory
1764
2449
 
1765
2450
  if connection_kwargs.get("cache_config") or connection_kwargs.get("cache"):
1766
- if self.connection_kwargs.get("protocol") not in [3, "3"]:
2451
+ if self._connection_kwargs.get("protocol") not in [3, "3"]:
1767
2452
  raise RedisError("Client caching is only supported with RESP version 3")
1768
2453
 
1769
- cache = self.connection_kwargs.get("cache")
2454
+ cache = self._connection_kwargs.get("cache")
1770
2455
 
1771
2456
  if cache is not None:
1772
2457
  if not isinstance(cache, CacheInterface):
@@ -1778,29 +2463,13 @@ class ConnectionPool:
1778
2463
  self.cache = self._cache_factory.get_cache()
1779
2464
  else:
1780
2465
  self.cache = CacheFactory(
1781
- self.connection_kwargs.get("cache_config")
2466
+ self._connection_kwargs.get("cache_config")
1782
2467
  ).get_cache()
1783
2468
 
1784
2469
  connection_kwargs.pop("cache", None)
1785
2470
  connection_kwargs.pop("cache_config", None)
1786
2471
 
1787
- if self.connection_kwargs.get(
1788
- "maint_notifications_pool_handler"
1789
- ) or self.connection_kwargs.get("maint_notifications_config"):
1790
- if self.connection_kwargs.get("protocol") not in [3, "3"]:
1791
- raise RedisError(
1792
- "Push handlers on connection are only supported with RESP version 3"
1793
- )
1794
- config = self.connection_kwargs.get("maint_notifications_config", None) or (
1795
- self.connection_kwargs.get("maint_notifications_pool_handler").config
1796
- if self.connection_kwargs.get("maint_notifications_pool_handler")
1797
- else None
1798
- )
1799
-
1800
- if config and config.enabled:
1801
- self._update_connection_kwargs_for_maint_notifications()
1802
-
1803
- self._event_dispatcher = self.connection_kwargs.get("event_dispatcher", None)
2472
+ self._event_dispatcher = self._connection_kwargs.get("event_dispatcher", None)
1804
2473
  if self._event_dispatcher is None:
1805
2474
  self._event_dispatcher = EventDispatcher()
1806
2475
 
@@ -1816,6 +2485,12 @@ class ConnectionPool:
1816
2485
  self._fork_lock = threading.RLock()
1817
2486
  self._lock = threading.RLock()
1818
2487
 
2488
+ MaintNotificationsAbstractConnectionPool.__init__(
2489
+ self,
2490
+ maint_notifications_config=maint_notifications_config,
2491
+ **connection_kwargs,
2492
+ )
2493
+
1819
2494
  self.reset()
1820
2495
 
1821
2496
  def __repr__(self) -> str:
@@ -1826,76 +2501,21 @@ class ConnectionPool:
1826
2501
  f"({conn_kwargs})>)>"
1827
2502
  )
1828
2503
 
2504
+ @property
2505
+ def connection_kwargs(self) -> Dict[str, Any]:
2506
+ return self._connection_kwargs
2507
+
2508
+ @connection_kwargs.setter
2509
+ def connection_kwargs(self, value: Dict[str, Any]):
2510
+ self._connection_kwargs = value
2511
+
1829
2512
  def get_protocol(self):
1830
2513
  """
1831
2514
  Returns:
1832
2515
  The RESP protocol version, or ``None`` if the protocol is not specified,
1833
2516
  in which case the server default will be used.
1834
- """
1835
- return self.connection_kwargs.get("protocol", None)
1836
-
1837
- def maint_notifications_pool_handler_enabled(self):
1838
- """
1839
- Returns:
1840
- True if the maintenance notifications pool handler is enabled, False otherwise.
1841
- """
1842
- maint_notifications_config = self.connection_kwargs.get(
1843
- "maint_notifications_config", None
1844
- )
1845
-
1846
- return maint_notifications_config and maint_notifications_config.enabled
1847
-
1848
- def set_maint_notifications_pool_handler(
1849
- self, maint_notifications_pool_handler: MaintNotificationsPoolHandler
1850
- ):
1851
- self.connection_kwargs.update(
1852
- {
1853
- "maint_notifications_pool_handler": maint_notifications_pool_handler,
1854
- "maint_notifications_config": maint_notifications_pool_handler.config,
1855
- }
1856
- )
1857
- self._update_connection_kwargs_for_maint_notifications()
1858
-
1859
- self._update_maint_notifications_configs_for_connections(
1860
- maint_notifications_pool_handler
1861
- )
1862
-
1863
- def _update_maint_notifications_configs_for_connections(
1864
- self, maint_notifications_pool_handler
1865
- ):
1866
- """Update the maintenance notifications config for all connections in the pool."""
1867
- with self._lock:
1868
- for conn in self._available_connections:
1869
- conn.set_maint_notifications_pool_handler(
1870
- maint_notifications_pool_handler
1871
- )
1872
- conn.maint_notifications_config = (
1873
- maint_notifications_pool_handler.config
1874
- )
1875
- for conn in self._in_use_connections:
1876
- conn.set_maint_notifications_pool_handler(
1877
- maint_notifications_pool_handler
1878
- )
1879
- conn.maint_notifications_config = (
1880
- maint_notifications_pool_handler.config
1881
- )
1882
-
1883
- def _update_connection_kwargs_for_maint_notifications(self):
1884
- """Store original connection parameters for maintenance notifications."""
1885
- if self.connection_kwargs.get("orig_host_address", None) is None:
1886
- # If orig_host_address is None it means we haven't
1887
- # configured the original values yet
1888
- self.connection_kwargs.update(
1889
- {
1890
- "orig_host_address": self.connection_kwargs.get("host"),
1891
- "orig_socket_timeout": self.connection_kwargs.get(
1892
- "socket_timeout", None
1893
- ),
1894
- "orig_socket_connect_timeout": self.connection_kwargs.get(
1895
- "socket_connect_timeout", None
1896
- ),
1897
- }
1898
- )
2517
+ """
2518
+ return self.connection_kwargs.get("protocol", None)
1899
2519
 
1900
2520
  def reset(self) -> None:
1901
2521
  self._created_connections = 0
@@ -1987,7 +2607,7 @@ class ConnectionPool:
1987
2607
  if (
1988
2608
  connection.can_read()
1989
2609
  and self.cache is None
1990
- and not self.maint_notifications_pool_handler_enabled()
2610
+ and not self.maint_notifications_enabled()
1991
2611
  ):
1992
2612
  raise ConnectionError("Connection has data")
1993
2613
  except (ConnectionError, TimeoutError, OSError):
@@ -2059,7 +2679,7 @@ class ConnectionPool:
2059
2679
  Disconnects connections in the pool
2060
2680
 
2061
2681
  If ``inuse_connections`` is True, disconnect connections that are
2062
- current in use, potentially by other threads. Otherwise only disconnect
2682
+ currently in use, potentially by other threads. Otherwise only disconnect
2063
2683
  connections that are idle in the pool.
2064
2684
  """
2065
2685
  self._checkpid()
@@ -2100,185 +2720,16 @@ class ConnectionPool:
2100
2720
  for conn in self._in_use_connections:
2101
2721
  conn.set_re_auth_token(token)
2102
2722
 
2103
- def _should_update_connection(
2104
- self,
2105
- conn: "Connection",
2106
- matching_pattern: Literal[
2107
- "connected_address", "configured_address", "notification_hash"
2108
- ] = "connected_address",
2109
- matching_address: Optional[str] = None,
2110
- matching_notification_hash: Optional[int] = None,
2111
- ) -> bool:
2112
- """
2113
- Check if the connection should be updated based on the matching criteria.
2114
- """
2115
- if matching_pattern == "connected_address":
2116
- if matching_address and conn.getpeername() != matching_address:
2117
- return False
2118
- elif matching_pattern == "configured_address":
2119
- if matching_address and conn.host != matching_address:
2120
- return False
2121
- elif matching_pattern == "notification_hash":
2122
- if (
2123
- matching_notification_hash
2124
- and conn.maintenance_notification_hash != matching_notification_hash
2125
- ):
2126
- return False
2127
- return True
2128
-
2129
- def update_connection_settings(
2130
- self,
2131
- conn: "Connection",
2132
- state: Optional["MaintenanceState"] = None,
2133
- maintenance_notification_hash: Optional[int] = None,
2134
- host_address: Optional[str] = None,
2135
- relaxed_timeout: Optional[float] = None,
2136
- update_notification_hash: bool = False,
2137
- reset_host_address: bool = False,
2138
- reset_relaxed_timeout: bool = False,
2139
- ):
2140
- """
2141
- Update the settings for a single connection.
2142
- """
2143
- if state:
2144
- conn.maintenance_state = state
2145
-
2146
- if update_notification_hash:
2147
- # update the notification hash only if requested
2148
- conn.maintenance_notification_hash = maintenance_notification_hash
2149
-
2150
- if host_address is not None:
2151
- conn.set_tmp_settings(tmp_host_address=host_address)
2152
-
2153
- if relaxed_timeout is not None:
2154
- conn.set_tmp_settings(tmp_relaxed_timeout=relaxed_timeout)
2155
-
2156
- if reset_relaxed_timeout or reset_host_address:
2157
- conn.reset_tmp_settings(
2158
- reset_host_address=reset_host_address,
2159
- reset_relaxed_timeout=reset_relaxed_timeout,
2160
- )
2161
-
2162
- conn.update_current_socket_timeout(relaxed_timeout)
2163
-
2164
- def update_connections_settings(
2165
- self,
2166
- state: Optional["MaintenanceState"] = None,
2167
- maintenance_notification_hash: Optional[int] = None,
2168
- host_address: Optional[str] = None,
2169
- relaxed_timeout: Optional[float] = None,
2170
- matching_address: Optional[str] = None,
2171
- matching_notification_hash: Optional[int] = None,
2172
- matching_pattern: Literal[
2173
- "connected_address", "configured_address", "notification_hash"
2174
- ] = "connected_address",
2175
- update_notification_hash: bool = False,
2176
- reset_host_address: bool = False,
2177
- reset_relaxed_timeout: bool = False,
2178
- include_free_connections: bool = True,
2179
- ):
2180
- """
2181
- Update the settings for all matching connections in the pool.
2182
-
2183
- This method does not create new connections.
2184
- This method does not affect the connection kwargs.
2185
-
2186
- :param state: The maintenance state to set for the connection.
2187
- :param maintenance_notification_hash: The hash of the maintenance notification
2188
- to set for the connection.
2189
- :param host_address: The host address to set for the connection.
2190
- :param relaxed_timeout: The relaxed timeout to set for the connection.
2191
- :param matching_address: The address to match for the connection.
2192
- :param matching_notification_hash: The notification hash to match for the connection.
2193
- :param matching_pattern: The pattern to match for the connection.
2194
- :param update_notification_hash: Whether to update the notification hash for the connection.
2195
- :param reset_host_address: Whether to reset the host address to the original address.
2196
- :param reset_relaxed_timeout: Whether to reset the relaxed timeout to the original timeout.
2197
- :param include_free_connections: Whether to include free/available connections.
2198
- """
2199
- with self._lock:
2200
- for conn in self._in_use_connections:
2201
- if self._should_update_connection(
2202
- conn,
2203
- matching_pattern,
2204
- matching_address,
2205
- matching_notification_hash,
2206
- ):
2207
- self.update_connection_settings(
2208
- conn,
2209
- state=state,
2210
- maintenance_notification_hash=maintenance_notification_hash,
2211
- host_address=host_address,
2212
- relaxed_timeout=relaxed_timeout,
2213
- update_notification_hash=update_notification_hash,
2214
- reset_host_address=reset_host_address,
2215
- reset_relaxed_timeout=reset_relaxed_timeout,
2216
- )
2217
-
2218
- if include_free_connections:
2219
- for conn in self._available_connections:
2220
- if self._should_update_connection(
2221
- conn,
2222
- matching_pattern,
2223
- matching_address,
2224
- matching_notification_hash,
2225
- ):
2226
- self.update_connection_settings(
2227
- conn,
2228
- state=state,
2229
- maintenance_notification_hash=maintenance_notification_hash,
2230
- host_address=host_address,
2231
- relaxed_timeout=relaxed_timeout,
2232
- update_notification_hash=update_notification_hash,
2233
- reset_host_address=reset_host_address,
2234
- reset_relaxed_timeout=reset_relaxed_timeout,
2235
- )
2236
-
2237
- def update_connection_kwargs(
2238
- self,
2239
- **kwargs,
2240
- ):
2241
- """
2242
- Update the connection kwargs for all future connections.
2243
-
2244
- This method updates the connection kwargs for all future connections created by the pool.
2245
- Existing connections are not affected.
2246
- """
2247
- self.connection_kwargs.update(kwargs)
2248
-
2249
- def update_active_connections_for_reconnect(
2250
- self,
2251
- moving_address_src: Optional[str] = None,
2252
- ):
2253
- """
2254
- Mark all active connections for reconnect.
2255
- This is used when a cluster node is migrated to a different address.
2723
+ def _get_pool_lock(self):
2724
+ return self._lock
2256
2725
 
2257
- :param moving_address_src: The address of the node that is being moved.
2258
- """
2726
+ def _get_free_connections(self):
2259
2727
  with self._lock:
2260
- for conn in self._in_use_connections:
2261
- if self._should_update_connection(
2262
- conn, "connected_address", moving_address_src
2263
- ):
2264
- conn.mark_for_reconnect()
2265
-
2266
- def disconnect_free_connections(
2267
- self,
2268
- moving_address_src: Optional[str] = None,
2269
- ):
2270
- """
2271
- Disconnect all free/available connections.
2272
- This is used when a cluster node is migrated to a different address.
2728
+ return self._available_connections
2273
2729
 
2274
- :param moving_address_src: The address of the node that is being moved.
2275
- """
2730
+ def _get_in_use_connections(self):
2276
2731
  with self._lock:
2277
- for conn in self._available_connections:
2278
- if self._should_update_connection(
2279
- conn, "connected_address", moving_address_src
2280
- ):
2281
- conn.disconnect()
2732
+ return self._in_use_connections
2282
2733
 
2283
2734
  async def _mock(self, error: RedisError):
2284
2735
  """
@@ -2391,7 +2842,7 @@ class BlockingConnectionPool(ConnectionPool):
2391
2842
  )
2392
2843
  else:
2393
2844
  connection = self.connection_class(**self.connection_kwargs)
2394
- self._connections.append(connection)
2845
+ self._connections.append(connection)
2395
2846
  return connection
2396
2847
  finally:
2397
2848
  if self._locked:
@@ -2503,14 +2954,18 @@ class BlockingConnectionPool(ConnectionPool):
2503
2954
  pass
2504
2955
  self._locked = False
2505
2956
 
2506
- def disconnect(self):
2507
- "Disconnects all connections in the pool."
2957
+ def disconnect(self, inuse_connections: bool = True):
2958
+ "Disconnects either all connections in the pool or just the free connections."
2508
2959
  self._checkpid()
2509
2960
  try:
2510
2961
  if self._in_maintenance:
2511
2962
  self._lock.acquire()
2512
2963
  self._locked = True
2513
- for connection in self._connections:
2964
+ if inuse_connections:
2965
+ connections = self._connections
2966
+ else:
2967
+ connections = self._get_free_connections()
2968
+ for connection in connections:
2514
2969
  connection.disconnect()
2515
2970
  finally:
2516
2971
  if self._locked:
@@ -2520,124 +2975,19 @@ class BlockingConnectionPool(ConnectionPool):
2520
2975
  pass
2521
2976
  self._locked = False
2522
2977
 
2523
- def update_connections_settings(
2524
- self,
2525
- state: Optional["MaintenanceState"] = None,
2526
- maintenance_notification_hash: Optional[int] = None,
2527
- relaxed_timeout: Optional[float] = None,
2528
- host_address: Optional[str] = None,
2529
- matching_address: Optional[str] = None,
2530
- matching_notification_hash: Optional[int] = None,
2531
- matching_pattern: Literal[
2532
- "connected_address", "configured_address", "notification_hash"
2533
- ] = "connected_address",
2534
- update_notification_hash: bool = False,
2535
- reset_host_address: bool = False,
2536
- reset_relaxed_timeout: bool = False,
2537
- include_free_connections: bool = True,
2538
- ):
2539
- """
2540
- Override base class method to work with BlockingConnectionPool's structure.
2541
- """
2978
+ def _get_free_connections(self):
2542
2979
  with self._lock:
2543
- if include_free_connections:
2544
- for conn in tuple(self._connections):
2545
- if self._should_update_connection(
2546
- conn,
2547
- matching_pattern,
2548
- matching_address,
2549
- matching_notification_hash,
2550
- ):
2551
- self.update_connection_settings(
2552
- conn,
2553
- state=state,
2554
- maintenance_notification_hash=maintenance_notification_hash,
2555
- host_address=host_address,
2556
- relaxed_timeout=relaxed_timeout,
2557
- update_notification_hash=update_notification_hash,
2558
- reset_host_address=reset_host_address,
2559
- reset_relaxed_timeout=reset_relaxed_timeout,
2560
- )
2561
- else:
2562
- connections_in_queue = {conn for conn in self.pool.queue if conn}
2563
- for conn in self._connections:
2564
- if conn not in connections_in_queue:
2565
- if self._should_update_connection(
2566
- conn,
2567
- matching_pattern,
2568
- matching_address,
2569
- matching_notification_hash,
2570
- ):
2571
- self.update_connection_settings(
2572
- conn,
2573
- state=state,
2574
- maintenance_notification_hash=maintenance_notification_hash,
2575
- host_address=host_address,
2576
- relaxed_timeout=relaxed_timeout,
2577
- update_notification_hash=update_notification_hash,
2578
- reset_host_address=reset_host_address,
2579
- reset_relaxed_timeout=reset_relaxed_timeout,
2580
- )
2581
-
2582
- def update_active_connections_for_reconnect(
2583
- self,
2584
- moving_address_src: Optional[str] = None,
2585
- ):
2586
- """
2587
- Mark all active connections for reconnect.
2588
- This is used when a cluster node is migrated to a different address.
2980
+ return {conn for conn in self.pool.queue if conn}
2589
2981
 
2590
- :param moving_address_src: The address of the node that is being moved.
2591
- """
2982
+ def _get_in_use_connections(self):
2592
2983
  with self._lock:
2984
+ # free connections
2593
2985
  connections_in_queue = {conn for conn in self.pool.queue if conn}
2594
- for conn in self._connections:
2595
- if conn not in connections_in_queue:
2596
- if self._should_update_connection(
2597
- conn,
2598
- matching_pattern="connected_address",
2599
- matching_address=moving_address_src,
2600
- ):
2601
- conn.mark_for_reconnect()
2602
-
2603
- def disconnect_free_connections(
2604
- self,
2605
- moving_address_src: Optional[str] = None,
2606
- ):
2607
- """
2608
- Disconnect all free/available connections.
2609
- This is used when a cluster node is migrated to a different address.
2610
-
2611
- :param moving_address_src: The address of the node that is being moved.
2612
- """
2613
- with self._lock:
2614
- existing_connections = self.pool.queue
2615
-
2616
- for conn in existing_connections:
2617
- if conn:
2618
- if self._should_update_connection(
2619
- conn, "connected_address", moving_address_src
2620
- ):
2621
- conn.disconnect()
2622
-
2623
- def _update_maint_notifications_config_for_connections(
2624
- self, maint_notifications_config
2625
- ):
2626
- for conn in tuple(self._connections):
2627
- conn.maint_notifications_config = maint_notifications_config
2628
-
2629
- def _update_maint_notifications_configs_for_connections(
2630
- self, maint_notifications_pool_handler
2631
- ):
2632
- """Update the maintenance notifications config for all connections in the pool."""
2633
- with self._lock:
2634
- for conn in tuple(self._connections):
2635
- conn.set_maint_notifications_pool_handler(
2636
- maint_notifications_pool_handler
2637
- )
2638
- conn.maint_notifications_config = (
2639
- maint_notifications_pool_handler.config
2640
- )
2986
+ # in self._connections we keep all created connections
2987
+ # so the ones that are not in the queue are the in use ones
2988
+ return {
2989
+ conn for conn in self._connections if conn not in connections_in_queue
2990
+ }
2641
2991
 
2642
2992
  def set_in_maintenance(self, in_maintenance: bool):
2643
2993
  """