redis 7.0.0b2__py3-none-any.whl → 7.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. redis/__init__.py +1 -1
  2. redis/_parsers/base.py +6 -0
  3. redis/_parsers/helpers.py +64 -6
  4. redis/asyncio/client.py +14 -5
  5. redis/asyncio/cluster.py +5 -1
  6. redis/asyncio/connection.py +19 -1
  7. redis/asyncio/http/__init__.py +0 -0
  8. redis/asyncio/http/http_client.py +265 -0
  9. redis/asyncio/multidb/__init__.py +0 -0
  10. redis/asyncio/multidb/client.py +530 -0
  11. redis/asyncio/multidb/command_executor.py +339 -0
  12. redis/asyncio/multidb/config.py +210 -0
  13. redis/asyncio/multidb/database.py +69 -0
  14. redis/asyncio/multidb/event.py +84 -0
  15. redis/asyncio/multidb/failover.py +125 -0
  16. redis/asyncio/multidb/failure_detector.py +38 -0
  17. redis/asyncio/multidb/healthcheck.py +285 -0
  18. redis/background.py +204 -0
  19. redis/client.py +49 -27
  20. redis/cluster.py +9 -1
  21. redis/commands/core.py +64 -29
  22. redis/commands/json/commands.py +2 -2
  23. redis/commands/search/__init__.py +2 -2
  24. redis/commands/search/aggregation.py +24 -26
  25. redis/commands/search/commands.py +10 -10
  26. redis/commands/search/field.py +2 -2
  27. redis/commands/search/query.py +12 -12
  28. redis/connection.py +1613 -1263
  29. redis/data_structure.py +81 -0
  30. redis/event.py +84 -10
  31. redis/exceptions.py +8 -0
  32. redis/http/__init__.py +0 -0
  33. redis/http/http_client.py +425 -0
  34. redis/maint_notifications.py +18 -7
  35. redis/multidb/__init__.py +0 -0
  36. redis/multidb/circuit.py +144 -0
  37. redis/multidb/client.py +526 -0
  38. redis/multidb/command_executor.py +350 -0
  39. redis/multidb/config.py +207 -0
  40. redis/multidb/database.py +130 -0
  41. redis/multidb/event.py +89 -0
  42. redis/multidb/exception.py +17 -0
  43. redis/multidb/failover.py +125 -0
  44. redis/multidb/failure_detector.py +104 -0
  45. redis/multidb/healthcheck.py +282 -0
  46. redis/retry.py +14 -1
  47. redis/utils.py +34 -0
  48. {redis-7.0.0b2.dist-info → redis-7.0.1.dist-info}/METADATA +17 -4
  49. {redis-7.0.0b2.dist-info → redis-7.0.1.dist-info}/RECORD +51 -25
  50. {redis-7.0.0b2.dist-info → redis-7.0.1.dist-info}/WHEEL +0 -0
  51. {redis-7.0.0b2.dist-info → redis-7.0.1.dist-info}/licenses/LICENSE +0 -0
redis/multidb/event.py ADDED
@@ -0,0 +1,89 @@
1
+ from typing import List
2
+
3
+ from redis.client import Redis
4
+ from redis.event import EventListenerInterface, OnCommandsFailEvent
5
+ from redis.multidb.database import SyncDatabase
6
+ from redis.multidb.failure_detector import FailureDetector
7
+
8
+
9
class ActiveDatabaseChanged:
    """
    Event emitted when the active database has been changed.

    Holds the previously active database, the newly active one, the command
    executor associated with the switch, and any additional keyword arguments
    supplied by the emitter. These are exposed read-only through properties.
    """

    def __init__(
        self,
        old_database: SyncDatabase,
        new_database: SyncDatabase,
        command_executor,
        **kwargs,
    ):
        self._old_database, self._new_database = old_database, new_database
        self._command_executor = command_executor
        # Extra keyword arguments; listeners may forward them (for example to
        # ``client.pubsub(**kwargs)`` when re-creating a pub/sub).
        self._kwargs = kwargs

    @property
    def old_database(self) -> SyncDatabase:
        """The database that was active before the change."""
        return self._old_database

    @property
    def new_database(self) -> SyncDatabase:
        """The database that is active after the change."""
        return self._new_database

    @property
    def command_executor(self):
        """The command executor passed in by the emitter of this event."""
        return self._command_executor

    @property
    def kwargs(self):
        """The extra keyword arguments given to the constructor."""
        return self._kwargs
41
+
42
+
43
class ResubscribeOnActiveDatabaseChanged(EventListenerInterface):
    """
    Move the executor's currently active pub/sub onto the new active database.

    If the command executor has an active pub/sub, a new pub/sub is created
    from the new database's client (with the event's kwargs). It inherits the
    old channels, patterns and shard channels, and its ``on_connect`` hook is
    invoked so those subscriptions are re-issued. The new pub/sub then replaces
    the old one on the executor, and the old one is closed.
    """

    def listen(self, event: ActiveDatabaseChanged):
        previous = event.command_executor.active_pubsub
        if previous is None:
            # No pub/sub in use: nothing to migrate.
            return

        replacement = event.new_database.client.pubsub(**event.kwargs)
        # Carry over the subscription state so on_connect() re-subscribes to it.
        replacement.channels = previous.channels
        replacement.patterns = previous.patterns
        replacement.shard_channels = previous.shard_channels
        replacement.on_connect(None)

        event.command_executor.active_pubsub = replacement
        previous.close()
60
+
61
+
62
class CloseConnectionOnActiveDatabaseChanged(EventListenerInterface):
    """
    Close the client of the previously active database and drop its connections.

    The old client is closed first. Then every connection pool it owns is
    marked for reconnect and disconnected: the single pool of a standalone
    ``Redis`` client, or the pool of each node in the nodes cache otherwise
    (i.e. a cluster client).
    """

    def listen(self, event: ActiveDatabaseChanged):
        old_client = event.old_database.client
        old_client.close()

        if isinstance(old_client, Redis):
            pools = (old_client.connection_pool,)
        else:
            # Lazily walk the cluster's known nodes, one pool per node.
            pools = (
                node.redis_connection.connection_pool
                for node in old_client.nodes_manager.nodes_cache.values()
            )

        for pool in pools:
            pool.update_active_connections_for_reconnect()
            pool.disconnect()
77
+
78
+
79
class RegisterCommandFailure(EventListenerInterface):
    """
    Forward every failed-commands event to each configured failure detector.
    """

    def __init__(self, failure_detectors: List[FailureDetector]):
        # Detectors are notified in list order.
        self._failure_detectors = failure_detectors

    def listen(self, event: OnCommandsFailEvent) -> None:
        detectors = self._failure_detectors
        for detector in detectors:
            detector.register_failure(event.exception, event.commands)
redis/multidb/exception.py ADDED
@@ -0,0 +1,17 @@
1
class NoValidDatabaseException(Exception):
    """Raised when no database is currently valid for communication."""
3
+
4
+
5
class UnhealthyDatabaseException(Exception):
    """
    Exception raised when a database is unhealthy due to an underlying exception.

    Attributes:
        database: The database whose health check failed.
        original_exception: The exception that caused the failure.
    """

    def __init__(self, message, database, original_exception):
        super().__init__(message)
        self.database, self.original_exception = database, original_exception
12
+
13
+
14
class TemporaryUnavailableException(Exception):
    """
    Exception raised when all databases in setup are temporary unavailable.

    The condition is meant to be transient; the failing operation may be
    retried later.
    """
redis/multidb/failover.py ADDED
@@ -0,0 +1,125 @@
1
+ import time
2
+ from abc import ABC, abstractmethod
3
+
4
+ from redis.data_structure import WeightedList
5
+ from redis.multidb.circuit import State as CBState
6
+ from redis.multidb.database import Databases, SyncDatabase
7
+ from redis.multidb.exception import (
8
+ NoValidDatabaseException,
9
+ TemporaryUnavailableException,
10
+ )
11
+
12
# Default number of failover attempts DefaultFailoverStrategyExecutor counts
# before re-raising NoValidDatabaseException.
DEFAULT_FAILOVER_ATTEMPTS = 10
# Default delay between failover attempts, in seconds (added to time.time()).
DEFAULT_FAILOVER_DELAY = 12


class FailoverStrategy(ABC):
    """Interface for strategies that choose the database to communicate with."""

    @abstractmethod
    def database(self) -> SyncDatabase:
        """Select the database according to the strategy."""

    @abstractmethod
    def set_databases(self, databases: Databases) -> None:
        """Set the databases the strategy operates on."""
26
+
27
+
28
class FailoverStrategyExecutor(ABC):
    """
    Interface for executors that run a FailoverStrategy, configured by a number
    of failover attempts and a delay between them.
    """

    @property
    @abstractmethod
    def failover_attempts(self) -> int:
        """The number of failover attempts."""

    @property
    @abstractmethod
    def failover_delay(self) -> float:
        """The delay between failover attempts."""

    @property
    @abstractmethod
    def strategy(self) -> FailoverStrategy:
        """The strategy to execute."""

    @abstractmethod
    def execute(self) -> SyncDatabase:
        """Execute the failover strategy."""
51
+
52
+
53
class WeightBasedFailoverStrategy(FailoverStrategy):
    """
    Failover strategy based on database weights.

    Returns the first database, in the iteration order of the configured
    weighted list, whose circuit breaker is CLOSED. (Presumably WeightedList
    iterates from the highest weight down; confirm in redis.data_structure.)
    """

    def __init__(self) -> None:
        self._databases = WeightedList()

    def database(self) -> SyncDatabase:
        """
        Return the first database with a closed circuit.

        Raises:
            NoValidDatabaseException: If every database's circuit is not CLOSED
                (or there are no databases).
        """
        for candidate, _weight in self._databases:
            if candidate.circuit.state == CBState.CLOSED:
                return candidate
        raise NoValidDatabaseException("No valid database available for communication")

    def set_databases(self, databases: Databases) -> None:
        """Replace the databases the strategy chooses from."""
        self._databases = databases
70
+
71
+
72
class DefaultFailoverStrategyExecutor(FailoverStrategyExecutor):
    """
    Executes given failover strategy.

    ``execute`` asks the strategy for a database. While the strategy raises
    NoValidDatabaseException, ``execute`` raises TemporaryUnavailableException
    instead and counts failover attempts on a fixed schedule anchored at the
    first failure:

    * the first failure counts as one attempt and schedules the next one
      ``failover_delay`` seconds later;
    * later failures count only once the scheduled time has been reached, and
      the schedule then advances by exactly ``failover_delay``.

    When more than ``failover_attempts`` attempts have been counted, the
    original NoValidDatabaseException is re-raised and the counters are reset.
    A successful selection also resets the counters.
    """

    def __init__(
        self,
        strategy: FailoverStrategy,
        failover_attempts: int = DEFAULT_FAILOVER_ATTEMPTS,
        failover_delay: float = DEFAULT_FAILOVER_DELAY,
    ):
        self._strategy = strategy
        self._failover_attempts = failover_attempts
        self._failover_delay = failover_delay
        # time.time() value at which the next failed attempt is counted;
        # 0 means no failover is in progress. This is a float timestamp.
        self._next_attempt_ts: float = 0
        # Number of attempts counted in the current failover window.
        self._failover_counter: int = 0

    @property
    def failover_attempts(self) -> int:
        return self._failover_attempts

    @property
    def failover_delay(self) -> float:
        return self._failover_delay

    @property
    def strategy(self) -> FailoverStrategy:
        return self._strategy

    def execute(self) -> SyncDatabase:
        """
        Return the database selected by the strategy.

        Raises:
            TemporaryUnavailableException: While the strategy has no valid
                database and the attempt budget is not yet exhausted.
            NoValidDatabaseException: The strategy's own exception, re-raised
                once more than ``failover_attempts`` attempts were counted.
        """
        try:
            database = self._strategy.database()
        except NoValidDatabaseException:
            self._register_failed_attempt()

            if self._failover_counter > self._failover_attempts:
                self._reset()
                raise

            raise TemporaryUnavailableException(
                "No database connections currently available. "
                "This is a temporary condition - please retry the operation."
            )

        self._reset()
        return database

    def _register_failed_attempt(self) -> None:
        """Count a failed selection if it falls on the attempt schedule."""
        now = time.time()
        if self._next_attempt_ts == 0:
            # First failure of a new failover window.
            self._next_attempt_ts = now + self._failover_delay
            self._failover_counter += 1
        elif now >= self._next_attempt_ts:
            # Advance by a fixed step so the schedule stays anchored to the
            # first failure rather than to the time of this call.
            self._next_attempt_ts += self._failover_delay
            self._failover_counter += 1

    def _reset(self) -> None:
        """Forget any failover in progress."""
        self._next_attempt_ts = 0
        self._failover_counter = 0
redis/multidb/failure_detector.py ADDED
@@ -0,0 +1,104 @@
1
+ import math
2
+ import threading
3
+ from abc import ABC, abstractmethod
4
+ from datetime import datetime, timedelta
5
+ from typing import List, Type
6
+
7
+ from typing_extensions import Optional
8
+
9
+ from redis.multidb.circuit import State as CBState
10
+
11
# Default minimum number of failures within a window before the circuit opens.
DEFAULT_MIN_NUM_FAILURES = 1000
# Default minimum fraction (0.1 == 10%) of executed commands that must fail.
DEFAULT_FAILURE_RATE_THRESHOLD = 0.1
# Default length of the failure detection window, in seconds.
DEFAULT_FAILURES_DETECTION_WINDOW = 2


class FailureDetector(ABC):
    """Interface for detectors that decide from command outcomes when a database failed."""

    @abstractmethod
    def register_failure(self, exception: Exception, cmd: tuple) -> None:
        """Register a failure that occurred during command execution."""
        pass

    @abstractmethod
    def register_command_execution(self, cmd: tuple) -> None:
        """Register a command execution."""
        pass

    @abstractmethod
    def set_command_executor(self, command_executor) -> None:
        """Set the command executor this detector acts upon."""
        pass


class CommandFailureDetector(FailureDetector):
    """
    Detects a failure based on a threshold of failed commands during a specific period of time.

    Executed commands and failures are counted within a fixed time window. When
    a failure is registered and both thresholds are met, the circuit of the
    command executor's active database is set to OPEN and the counters are
    reset. ``set_command_executor`` must therefore be called before the
    thresholds can be reached.
    """

    def __init__(
        self,
        min_num_failures: int = DEFAULT_MIN_NUM_FAILURES,
        failure_rate_threshold: float = DEFAULT_FAILURE_RATE_THRESHOLD,
        failure_detection_window: float = DEFAULT_FAILURES_DETECTION_WINDOW,
        error_types: Optional[List[Type[Exception]]] = None,
    ) -> None:
        """
        Initialize a new CommandFailureDetector instance.

        Args:
            min_num_failures: Minimum number of failures within the window
                required to open the circuit.
            failure_rate_threshold: Minimum fraction (e.g. 0.1 for 10%) of the
                commands executed within the window that must have failed.
            failure_detection_window: Length of the counting window, in seconds.
            error_types: Optional list of exception types to count. Instances of
                these types or of their subclasses are counted. If None or
                empty, every exception is counted.

        The window restarts whenever a command is registered outside of it, and
        after the circuit has been opened.
        """
        self._command_executor = None
        self._min_num_failures = min_num_failures
        self._failure_rate_threshold = failure_rate_threshold
        self._failure_detection_window = failure_detection_window
        self._error_types = error_types
        self._commands_executed: int = 0
        self._start_time: datetime = datetime.now()
        self._end_time: datetime = self._start_time + timedelta(
            seconds=self._failure_detection_window
        )
        self._failures_count: int = 0
        self._lock = threading.RLock()

    def register_failure(self, exception: Exception, cmd: tuple) -> None:
        """Count ``exception`` if it is a tracked error type, then check the thresholds."""
        with self._lock:
            if self._is_tracked(exception):
                self._failures_count += 1

            self._check_threshold()

    def set_command_executor(self, command_executor) -> None:
        self._command_executor = command_executor

    def register_command_execution(self, cmd: tuple) -> None:
        """Count an executed command, restarting the window if it has expired."""
        with self._lock:
            if not self._start_time < datetime.now() < self._end_time:
                self._reset()

            self._commands_executed += 1

    def _is_tracked(self, exception: Exception) -> bool:
        """Whether ``exception`` counts towards the failure thresholds."""
        if not self._error_types:
            return True
        # isinstance() (rather than an exact type match) so that subclasses of
        # a configured type, e.g. ConnectionRefusedError for OSError, count too.
        return isinstance(exception, tuple(self._error_types))

    def _check_threshold(self):
        # Called from register_failure with self._lock held.
        if self._failures_count >= self._min_num_failures and self._failures_count >= (
            math.ceil(self._commands_executed * self._failure_rate_threshold)
        ):
            self._command_executor.active_database.circuit.state = CBState.OPEN
            self._reset()

    def _reset(self) -> None:
        """Start a new counting window beginning now."""
        with self._lock:
            self._start_time = datetime.now()
            self._end_time = self._start_time + timedelta(
                seconds=self._failure_detection_window
            )
            self._failures_count = 0
            self._commands_executed = 0
redis/multidb/healthcheck.py ADDED
@@ -0,0 +1,282 @@
1
+ import logging
2
+ from abc import ABC, abstractmethod
3
+ from enum import Enum
4
+ from time import sleep
5
+ from typing import List, Optional, Tuple, Union
6
+
7
+ from redis import Redis
8
+ from redis.backoff import NoBackoff
9
+ from redis.http.http_client import DEFAULT_TIMEOUT, HttpClient
10
+ from redis.multidb.exception import UnhealthyDatabaseException
11
+ from redis.retry import Retry
12
+
13
# Default number of probes per health check. Not referenced elsewhere in this
# module; presumably consumed by the multi-db configuration (TODO confirm).
DEFAULT_HEALTH_CHECK_PROBES = 3
# Default health check interval. Not referenced in this module; presumably
# seconds (confirm against the configuration that reads it).
DEFAULT_HEALTH_CHECK_INTERVAL = 5
# Default delay between probes. Policies pass their delay to time.sleep(),
# so seconds when used there.
DEFAULT_HEALTH_CHECK_DELAY = 0.5
# Default lag tolerance for LagAwareHealthCheck, in milliseconds
# (sent as availability_lag_tolerance_ms).
DEFAULT_LAG_AWARE_TOLERANCE = 5000

logger = logging.getLogger(__name__)


class HealthCheck(ABC):
    """A single check that reports whether a database is healthy."""

    @abstractmethod
    def check_health(self, database) -> bool:
        """Function to determine the health status."""
26
+
27
+
28
class HealthCheckPolicy(ABC):
    """
    Health checks execution policy.

    A policy decides how many probes of each health check are run, how long to
    wait between them, and how their outcomes combine into one verdict.
    """

    @property
    @abstractmethod
    def health_check_probes(self) -> int:
        """Number of probes to execute health checks."""

    @property
    @abstractmethod
    def health_check_delay(self) -> float:
        """Delay between health check probes."""

    @abstractmethod
    def execute(self, health_checks: List[HealthCheck], database) -> bool:
        """Execute health checks and return database health status."""
49
+
50
+
51
class AbstractHealthCheckPolicy(HealthCheckPolicy):
    """
    Base for the built-in policies: validates and stores the probe count and
    the delay between probes, leaving ``execute`` to subclasses.
    """

    def __init__(self, health_check_probes: int, health_check_delay: float):
        """
        Args:
            health_check_probes: How many times each health check is probed;
                must be at least 1.
            health_check_delay: Delay between consecutive probes (the built-in
                policies pass it to ``time.sleep``, i.e. seconds).

        Raises:
            ValueError: If ``health_check_probes`` is less than 1.
        """
        if health_check_probes < 1:
            raise ValueError("health_check_probes must be greater than 0")
        self._health_check_probes, self._health_check_delay = (
            health_check_probes,
            health_check_delay,
        )

    @property
    def health_check_probes(self) -> int:
        """Number of probes per health check."""
        return self._health_check_probes

    @property
    def health_check_delay(self) -> float:
        """Delay between consecutive probes."""
        return self._health_check_delay

    @abstractmethod
    def execute(self, health_checks: List[HealthCheck], database) -> bool:
        """Run the health checks and return the database health status."""
69
+
70
+
71
class HealthyAllPolicy(AbstractHealthCheckPolicy):
    """
    Policy that returns True if all health check probes are successful.

    Every probe of every health check must succeed. Evaluation stops at the
    first failing probe (False is returned). An exception raised by a probe is
    wrapped in UnhealthyDatabaseException and raised.
    """

    def __init__(self, health_check_probes: int, health_check_delay: float):
        super().__init__(health_check_probes, health_check_delay)

    def execute(self, health_checks: List[HealthCheck], database) -> bool:
        for check in health_checks:
            for probe in range(self.health_check_probes):
                try:
                    failed = not check.check_health(database)
                except Exception as e:
                    raise UnhealthyDatabaseException("Unhealthy database", database, e)

                if failed:
                    return False

                # Sleep only between probes, not after the last one.
                if probe < self.health_check_probes - 1:
                    sleep(self._health_check_delay)

        return True
91
+
92
+
93
class HealthyMajorityPolicy(AbstractHealthCheckPolicy):
    """
    Policy that returns True if a majority of health check probes are successful.

    Evaluated per health check: a check passes when strictly more than half of
    its probes succeed. Probing of a check stops as soon as its outcome is
    decided. As soon as enough probes have failed that a majority is no longer
    possible, False is returned. If that deciding probe raised, the error is
    instead wrapped in UnhealthyDatabaseException and raised.
    """

    def __init__(self, health_check_probes: int, health_check_delay: float):
        super().__init__(health_check_probes, health_check_delay)

    def execute(self, health_checks: List[HealthCheck], database) -> bool:
        probes = self.health_check_probes
        # ceil(probes / 2) failures make a strict majority of successes impossible
        # (probes=3 -> 2, probes=4 -> 2). Integer arithmetic, no floats.
        failures_to_fail = (probes + 1) // 2
        # floor(probes / 2) + 1 successes already form a strict majority.
        successes_to_pass = probes // 2 + 1

        for health_check in health_checks:
            failures = 0
            successes = 0

            for attempt in range(probes):
                try:
                    if health_check.check_health(database):
                        successes += 1
                    else:
                        failures += 1
                        if failures >= failures_to_fail:
                            return False
                except Exception as e:
                    failures += 1
                    if failures >= failures_to_fail:
                        raise UnhealthyDatabaseException(
                            "Unhealthy database", database, e
                        )

                if successes >= successes_to_pass:
                    # Majority reached: the remaining probes cannot push the
                    # failures to the threshold, so skip them (and their sleeps).
                    break

                if attempt < probes - 1:
                    sleep(self._health_check_delay)

        return True
124
+
125
+
126
class HealthyAnyPolicy(AbstractHealthCheckPolicy):
    """
    Policy that returns True if at least one health check probe is successful.

    The condition applies to each health check separately: every check needs at
    least one successful probe, and probing of a check stops at its first
    success. If a check has no successful probe, one of two things happens:

    * if at least one of its probes raised, an UnhealthyDatabaseException
      wrapping the last such error is raised;
    * otherwise False is returned.

    With no health checks at all, False is returned.
    """

    def __init__(self, health_check_probes: int, health_check_delay: float):
        super().__init__(health_check_probes, health_check_delay)

    def execute(self, health_checks: List[HealthCheck], database) -> bool:
        is_healthy = False

        for health_check in health_checks:
            # Reset per check: a success of an earlier check must not mask a
            # later check whose probes all raised.
            is_healthy = False
            exception = None

            for attempt in range(self.health_check_probes):
                try:
                    if health_check.check_health(database):
                        is_healthy = True
                        break
                except Exception as e:
                    exception = UnhealthyDatabaseException(
                        "Unhealthy database", database, e
                    )

                if attempt < self.health_check_probes - 1:
                    sleep(self._health_check_delay)

            if not is_healthy:
                if exception is not None:
                    raise exception from exception.original_exception
                return False

        return is_healthy
161
+
162
+
163
class HealthCheckPolicies(Enum):
    """Built-in health check policies; each member's value is the policy class."""

    # Every probe of every check must succeed.
    HEALTHY_ALL = HealthyAllPolicy
    # A strict majority of each check's probes must succeed.
    HEALTHY_MAJORITY = HealthyMajorityPolicy
    # At least one probe of each check must succeed.
    HEALTHY_ANY = HealthyAnyPolicy


# Default policy member (presumably used when no policy is configured;
# confirm against the multi-db config).
DEFAULT_HEALTH_CHECK_POLICY: HealthCheckPolicies = HealthCheckPolicies.HEALTHY_ALL
170
+
171
+
172
class PingHealthCheck(HealthCheck):
    """
    Health check based on PING command.

    A standalone ``Redis`` client is pinged directly and the reply is returned.
    Any other client is treated as a cluster: every node from ``get_nodes()``
    is pinged in turn, stopping at the first falsy reply.
    """

    def check_health(self, database) -> bool:
        if not isinstance(database.client, Redis):
            return all(
                node.redis_connection.execute_command("PING")
                for node in database.client.get_nodes()
            )
        return database.client.execute_command("PING")
188
+
189
+
190
class LagAwareHealthCheck(HealthCheck):
    """
    Health check available for Redis Enterprise deployments.
    Verify via REST API that the database is healthy based on different lags.
    """

    def __init__(
        self,
        rest_api_port: int = 9443,
        lag_aware_tolerance: int = DEFAULT_LAG_AWARE_TOLERANCE,
        timeout: float = DEFAULT_TIMEOUT,
        auth_basic: Optional[Tuple[str, str]] = None,
        verify_tls: bool = True,
        # TLS verification (server) options
        ca_file: Optional[str] = None,
        ca_path: Optional[str] = None,
        ca_data: Optional[Union[str, bytes]] = None,
        # Mutual TLS (client cert) options
        client_cert_file: Optional[str] = None,
        client_key_file: Optional[str] = None,
        client_key_password: Optional[str] = None,
    ):
        """
        Initialize LagAwareHealthCheck with the specified parameters.

        Args:
            rest_api_port: Port number for Redis Enterprise REST API (default: 9443)
            lag_aware_tolerance: Tolerance in lag between databases in MS
                (default: DEFAULT_LAG_AWARE_TOLERANCE, i.e. 5000)
            timeout: Request timeout in seconds (default: DEFAULT_TIMEOUT)
            auth_basic: Tuple of (username, password) for basic authentication
            verify_tls: Whether to verify TLS certificates (default: True)
            ca_file: Path to CA certificate file for TLS verification
            ca_path: Path to CA certificates directory for TLS verification
            ca_data: CA certificate data as string or bytes
            client_cert_file: Path to client certificate file for mutual TLS
            client_key_file: Path to client private key file for mutual TLS
            client_key_password: Password for encrypted client private key
        """
        self._http_client = HttpClient(
            timeout=timeout,
            auth_basic=auth_basic,
            # No retries: a failed request should fail this probe; retrying is
            # left to the health check policy's probes.
            retry=Retry(NoBackoff(), retries=0),
            verify_tls=verify_tls,
            ca_file=ca_file,
            ca_path=ca_path,
            ca_data=ca_data,
            client_cert_file=client_cert_file,
            client_key_file=client_key_file,
            client_key_password=client_key_password,
        )
        self._rest_api_port = rest_api_port
        self._lag_aware_tolerance = lag_aware_tolerance

    def check_health(self, database) -> bool:
        """
        Return True if the REST API reports the database as available within
        the configured lag tolerance.

        Raises:
            ValueError: If the database has no health check URL, or no bdb has
                an endpoint matching the database host.
            Whatever the HTTP client raises for an unsuccessful response
            (an HttpError, per the note below).
        """
        if database.health_check_url is None:
            raise ValueError(
                "Database health check url is not set. Please check DatabaseConfig for the current database."
            )

        if isinstance(database.client, Redis):
            db_host = database.client.get_connection_kwargs()["host"]
        else:
            db_host = database.client.startup_nodes[0].host

        base_url = f"{database.health_check_url}:{self._rest_api_port}"
        self._http_client.base_url = base_url

        matching_bdb = self._find_matching_bdb(
            self._http_client.get("/v1/bdbs"), db_host
        )
        if matching_bdb is None:
            logger.warning("LagAwareHealthCheck failed: Couldn't find a matching bdb")
            raise ValueError("Could not find a matching bdb")

        url = (
            f"/v1/bdbs/{matching_bdb['uid']}/availability"
            f"?extend_check=lag&availability_lag_tolerance_ms={self._lag_aware_tolerance}"
        )
        self._http_client.get(url, expect_json=False)

        # Status checked in an http client, otherwise HttpError will be raised
        return True

    @staticmethod
    def _find_matching_bdb(bdbs, db_host):
        """
        Return the first bdb that has an endpoint matching ``db_host``, or None.

        An endpoint matches when its ``dns_name`` equals the host or, in case
        the host was set as an IP (e.g. a public IP), when the host is one of
        the endpoint's ``addr`` entries.
        """
        for bdb in bdbs:
            for endpoint in bdb["endpoints"]:
                if endpoint["dns_name"] == db_host or db_host in endpoint["addr"]:
                    return bdb
        return None
redis/retry.py CHANGED
@@ -1,7 +1,17 @@
1
1
  import abc
2
2
  import socket
3
3
  from time import sleep
4
- from typing import TYPE_CHECKING, Any, Callable, Generic, Iterable, Tuple, Type, TypeVar
4
+ from typing import (
5
+ TYPE_CHECKING,
6
+ Any,
7
+ Callable,
8
+ Generic,
9
+ Iterable,
10
+ Optional,
11
+ Tuple,
12
+ Type,
13
+ TypeVar,
14
+ )
5
15
 
6
16
  from redis.exceptions import ConnectionError, TimeoutError
7
17
 
@@ -91,6 +101,7 @@ class Retry(AbstractRetry[Exception]):
91
101
  self,
92
102
  do: Callable[[], T],
93
103
  fail: Callable[[Exception], Any],
104
+ is_retryable: Optional[Callable[[Exception], bool]] = None,
94
105
  ) -> T:
95
106
  """
96
107
  Execute an operation that might fail and returns its result, or
@@ -104,6 +115,8 @@ class Retry(AbstractRetry[Exception]):
104
115
  try:
105
116
  return do()
106
117
  except self._supported_errors as error:
118
+ if is_retryable and not is_retryable(error):
119
+ raise
107
120
  failures += 1
108
121
  fail(error)
109
122
  if self._retries >= 0 and failures > self._retries:
redis/utils.py CHANGED
@@ -1,6 +1,7 @@
1
1
  import datetime
2
2
  import logging
3
3
  import textwrap
4
+ import warnings
4
5
  from collections.abc import Callable
5
6
  from contextlib import contextmanager
6
7
  from functools import wraps
@@ -312,3 +313,36 @@ def truncate_text(txt, max_length=100):
312
313
  return textwrap.shorten(
313
314
  text=txt, width=max_length, placeholder="...", break_long_words=True
314
315
  )
316
+
317
+
318
def dummy_fail(*args, **kwargs):
    """
    Fake function for a Retry object if you don't need to handle each failure.

    Accepts and ignores any arguments. It can therefore be passed directly as
    the ``fail`` callback of ``Retry.call_with_retry`` (which calls
    ``fail(error)``), and can still be called with no arguments.
    """
    return None
323
+
324
+
325
async def dummy_fail_async(*args, **kwargs):
    """
    Async fake function for a Retry object if you don't need to handle each failure.

    Accepts and ignores any arguments, so it can be used directly as a failure
    callback that receives the error, and can still be called with no arguments.
    """
    return None
330
+
331
+
332
def experimental(cls):
    """
    Decorator to mark a class as experimental.

    Replaces the class's ``__init__`` with a wrapper that emits a
    ``UserWarning`` naming the class on every call and then delegates to the
    original ``__init__``. ``stacklevel=2`` attributes the warning to the
    caller of ``__init__`` (normally the code instantiating the class). The
    class is modified in place and returned.
    """
    original_init = cls.__init__

    @wraps(original_init)
    def new_init(self, *args, **kwargs):
        warnings.warn(
            f"{cls.__name__} is an experimental feature and may change or be "
            "removed in future versions.",
            category=UserWarning,
            stacklevel=2,
        )
        original_init(self, *args, **kwargs)

    cls.__init__ = new_init
    return cls
+ return cls