redis 6.4.0__tar.gz → 7.0.0b2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {redis-6.4.0 → redis-7.0.0b2}/PKG-INFO +1 -1
- {redis-6.4.0 → redis-7.0.0b2}/redis/__init__.py +1 -1
- {redis-6.4.0 → redis-7.0.0b2}/redis/_parsers/base.py +187 -8
- {redis-6.4.0 → redis-7.0.0b2}/redis/_parsers/hiredis.py +16 -10
- {redis-6.4.0 → redis-7.0.0b2}/redis/_parsers/resp3.py +11 -5
- {redis-6.4.0 → redis-7.0.0b2}/redis/asyncio/client.py +51 -3
- {redis-6.4.0 → redis-7.0.0b2}/redis/asyncio/cluster.py +52 -4
- {redis-6.4.0 → redis-7.0.0b2}/redis/asyncio/connection.py +43 -1
- {redis-6.4.0 → redis-7.0.0b2}/redis/cache.py +1 -0
- {redis-6.4.0 → redis-7.0.0b2}/redis/client.py +72 -13
- {redis-6.4.0 → redis-7.0.0b2}/redis/cluster.py +5 -2
- {redis-6.4.0 → redis-7.0.0b2}/redis/commands/core.py +285 -285
- {redis-6.4.0 → redis-7.0.0b2}/redis/commands/helpers.py +0 -20
- {redis-6.4.0 → redis-7.0.0b2}/redis/commands/search/query.py +12 -12
- {redis-6.4.0 → redis-7.0.0b2}/redis/commands/vectorset/commands.py +43 -25
- {redis-6.4.0 → redis-7.0.0b2}/redis/commands/vectorset/utils.py +40 -4
- {redis-6.4.0 → redis-7.0.0b2}/redis/connection.py +884 -60
- redis-7.0.0b2/redis/maint_notifications.py +799 -0
- redis-7.0.0b2/tests/test_asyncio/test_ssl.py +143 -0
- redis-7.0.0b2/tests/test_asyncio/test_usage_counter.py +16 -0
- {redis-6.4.0 → redis-7.0.0b2}/tests/test_asyncio/test_vsets.py +113 -1
- {redis-6.4.0 → redis-7.0.0b2}/tests/test_cluster.py +16 -0
- {redis-6.4.0 → redis-7.0.0b2}/tests/test_connection_pool.py +71 -3
- redis-7.0.0b2/tests/test_maint_notifications.py +893 -0
- redis-7.0.0b2/tests/test_maint_notifications_handling.py +2226 -0
- redis-7.0.0b2/tests/test_scenario/__init__.py +0 -0
- redis-7.0.0b2/tests/test_scenario/conftest.py +125 -0
- redis-7.0.0b2/tests/test_scenario/fault_injector_client.py +150 -0
- redis-7.0.0b2/tests/test_scenario/maint_notifications_helpers.py +326 -0
- redis-7.0.0b2/tests/test_scenario/test_maint_notifications.py +1111 -0
- {redis-6.4.0 → redis-7.0.0b2}/tests/test_ssl.py +97 -0
- {redis-6.4.0 → redis-7.0.0b2}/tests/test_vsets.py +111 -1
- redis-6.4.0/tests/test_asyncio/test_ssl.py +0 -56
- {redis-6.4.0 → redis-7.0.0b2}/.gitignore +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/LICENSE +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/README.md +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/dev_requirements.txt +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/pyproject.toml +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/redis/_parsers/__init__.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/redis/_parsers/commands.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/redis/_parsers/encoders.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/redis/_parsers/helpers.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/redis/_parsers/resp2.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/redis/_parsers/socket.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/redis/asyncio/__init__.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/redis/asyncio/lock.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/redis/asyncio/retry.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/redis/asyncio/sentinel.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/redis/asyncio/utils.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/redis/auth/__init__.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/redis/auth/err.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/redis/auth/idp.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/redis/auth/token.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/redis/auth/token_manager.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/redis/backoff.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/redis/commands/__init__.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/redis/commands/bf/__init__.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/redis/commands/bf/commands.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/redis/commands/bf/info.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/redis/commands/cluster.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/redis/commands/json/__init__.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/redis/commands/json/_util.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/redis/commands/json/commands.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/redis/commands/json/decoders.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/redis/commands/json/path.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/redis/commands/redismodules.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/redis/commands/search/__init__.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/redis/commands/search/_util.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/redis/commands/search/aggregation.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/redis/commands/search/commands.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/redis/commands/search/dialect.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/redis/commands/search/document.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/redis/commands/search/field.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/redis/commands/search/index_definition.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/redis/commands/search/profile_information.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/redis/commands/search/querystring.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/redis/commands/search/reducers.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/redis/commands/search/result.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/redis/commands/search/suggestion.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/redis/commands/sentinel.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/redis/commands/timeseries/__init__.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/redis/commands/timeseries/commands.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/redis/commands/timeseries/info.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/redis/commands/timeseries/utils.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/redis/commands/vectorset/__init__.py +1 -1
- {redis-6.4.0 → redis-7.0.0b2}/redis/crc.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/redis/credentials.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/redis/event.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/redis/exceptions.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/redis/lock.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/redis/ocsp.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/redis/py.typed +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/redis/retry.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/redis/sentinel.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/redis/typing.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/redis/utils.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/tests/__init__.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/tests/conftest.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/tests/entraid_utils.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/tests/mocks.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/tests/ssl_utils.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/tests/test_asyncio/__init__.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/tests/test_asyncio/compat.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/tests/test_asyncio/conftest.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/tests/test_asyncio/mocks.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/tests/test_asyncio/test_bloom.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/tests/test_asyncio/test_cluster.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/tests/test_asyncio/test_cluster_transaction.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/tests/test_asyncio/test_commands.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/tests/test_asyncio/test_connect.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/tests/test_asyncio/test_connection.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/tests/test_asyncio/test_connection_pool.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/tests/test_asyncio/test_credentials.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/tests/test_asyncio/test_cwe_404.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/tests/test_asyncio/test_encoding.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/tests/test_asyncio/test_hash.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/tests/test_asyncio/test_json.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/tests/test_asyncio/test_lock.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/tests/test_asyncio/test_monitor.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/tests/test_asyncio/test_pipeline.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/tests/test_asyncio/test_pubsub.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/tests/test_asyncio/test_retry.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/tests/test_asyncio/test_scripting.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/tests/test_asyncio/test_search.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/tests/test_asyncio/test_sentinel.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/tests/test_asyncio/test_sentinel_managed_connection.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/tests/test_asyncio/test_timeseries.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/tests/test_asyncio/test_utils.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/tests/test_asyncio/testdata/jsontestdata.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/tests/test_asyncio/testdata/titles.csv +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/tests/test_asyncio/testdata/will_play_text.csv.bz2 +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/tests/test_auth/__init__.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/tests/test_auth/test_token.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/tests/test_auth/test_token_manager.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/tests/test_backoff.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/tests/test_bloom.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/tests/test_cache.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/tests/test_cluster_transaction.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/tests/test_command_parser.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/tests/test_commands.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/tests/test_connect.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/tests/test_connection.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/tests/test_credentials.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/tests/test_encoding.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/tests/test_function.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/tests/test_hash.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/tests/test_helpers.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/tests/test_json.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/tests/test_lock.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/tests/test_max_connections_error.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/tests/test_monitor.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/tests/test_multiprocessing.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/tests/test_parsers/test_helpers.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/tests/test_pipeline.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/tests/test_pubsub.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/tests/test_retry.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/tests/test_scripting.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/tests/test_search.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/tests/test_sentinel.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/tests/test_sentinel_managed_connection.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/tests/test_timeseries.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/tests/test_utils.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/tests/testdata/jsontestdata.py +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/tests/testdata/titles.csv +0 -0
- {redis-6.4.0 → redis-7.0.0b2}/tests/testdata/will_play_text.csv.bz2 +0 -0
{redis-6.4.0 → redis-7.0.0b2}/redis/_parsers/base.py

@@ -1,7 +1,17 @@
+import logging
 import sys
 from abc import ABC
 from asyncio import IncompleteReadError, StreamReader, TimeoutError
-from typing import Callable, List, Optional, Protocol, Union
+from typing import Awaitable, Callable, List, Optional, Protocol, Union
+
+from redis.maint_notifications import (
+    MaintenanceNotification,
+    NodeFailedOverNotification,
+    NodeFailingOverNotification,
+    NodeMigratedNotification,
+    NodeMigratingNotification,
+    NodeMovingNotification,
+)

 if sys.version_info.major >= 3 and sys.version_info.minor >= 11:
     from asyncio import timeout as async_timeout
@@ -50,6 +60,8 @@ NO_AUTH_SET_ERROR = {
     "Client sent AUTH, but no password is set": AuthenticationError,
 }

+logger = logging.getLogger(__name__)
+

 class BaseParser(ABC):
     EXCEPTION_CLASSES = {
@@ -158,7 +170,77 @@ class AsyncBaseParser(BaseParser):
         raise NotImplementedError()


-
+class MaintenanceNotificationsParser:
+    """Protocol defining maintenance push notification parsing functionality"""
+
+    @staticmethod
+    def parse_maintenance_start_msg(response, notification_type):
+        # Expected message format is: <notification_type> <seq_number> <time>
+        id = response[1]
+        ttl = response[2]
+        return notification_type(id, ttl)
+
+    @staticmethod
+    def parse_maintenance_completed_msg(response, notification_type):
+        # Expected message format is: <notification_type> <seq_number>
+        id = response[1]
+        return notification_type(id)
+
+    @staticmethod
+    def parse_moving_msg(response):
+        # Expected message format is: MOVING <seq_number> <time> <endpoint>
+        id = response[1]
+        ttl = response[2]
+        if response[3] is None:
+            host, port = None, None
+        else:
+            value = response[3]
+            if isinstance(value, bytes):
+                value = value.decode()
+            host, port = value.split(":")
+            port = int(port) if port is not None else None
+
+        return NodeMovingNotification(id, host, port, ttl)
+
+
+_INVALIDATION_MESSAGE = "invalidate"
+_MOVING_MESSAGE = "MOVING"
+_MIGRATING_MESSAGE = "MIGRATING"
+_MIGRATED_MESSAGE = "MIGRATED"
+_FAILING_OVER_MESSAGE = "FAILING_OVER"
+_FAILED_OVER_MESSAGE = "FAILED_OVER"
+
+_MAINTENANCE_MESSAGES = (
+    _MIGRATING_MESSAGE,
+    _MIGRATED_MESSAGE,
+    _FAILING_OVER_MESSAGE,
+    _FAILED_OVER_MESSAGE,
+)
+
+MSG_TYPE_TO_MAINT_NOTIFICATION_PARSER_MAPPING: dict[
+    str, tuple[type[MaintenanceNotification], Callable]
+] = {
+    _MIGRATING_MESSAGE: (
+        NodeMigratingNotification,
+        MaintenanceNotificationsParser.parse_maintenance_start_msg,
+    ),
+    _MIGRATED_MESSAGE: (
+        NodeMigratedNotification,
+        MaintenanceNotificationsParser.parse_maintenance_completed_msg,
+    ),
+    _FAILING_OVER_MESSAGE: (
+        NodeFailingOverNotification,
+        MaintenanceNotificationsParser.parse_maintenance_start_msg,
+    ),
+    _FAILED_OVER_MESSAGE: (
+        NodeFailedOverNotification,
+        MaintenanceNotificationsParser.parse_maintenance_completed_msg,
+    ),
+    _MOVING_MESSAGE: (
+        NodeMovingNotification,
+        MaintenanceNotificationsParser.parse_moving_msg,
+    ),
+}


 class PushNotificationsParser(Protocol):
@@ -166,16 +248,57 @@ class PushNotificationsParser(Protocol):

     pubsub_push_handler_func: Callable
     invalidation_push_handler_func: Optional[Callable] = None
+    node_moving_push_handler_func: Optional[Callable] = None
+    maintenance_push_handler_func: Optional[Callable] = None

     def handle_pubsub_push_response(self, response):
         """Handle pubsub push responses"""
         raise NotImplementedError()

     def handle_push_response(self, response, **kwargs):
-
+        msg_type = response[0]
+        if isinstance(msg_type, bytes):
+            msg_type = msg_type.decode()
+
+        if msg_type not in (
+            _INVALIDATION_MESSAGE,
+            *_MAINTENANCE_MESSAGES,
+            _MOVING_MESSAGE,
+        ):
             return self.pubsub_push_handler_func(response)
-
-
+
+        try:
+            if (
+                msg_type == _INVALIDATION_MESSAGE
+                and self.invalidation_push_handler_func
+            ):
+                return self.invalidation_push_handler_func(response)
+
+            if msg_type == _MOVING_MESSAGE and self.node_moving_push_handler_func:
+                parser_function = MSG_TYPE_TO_MAINT_NOTIFICATION_PARSER_MAPPING[
+                    msg_type
+                ][1]
+
+                notification = parser_function(response)
+                return self.node_moving_push_handler_func(notification)
+
+            if msg_type in _MAINTENANCE_MESSAGES and self.maintenance_push_handler_func:
+                parser_function = MSG_TYPE_TO_MAINT_NOTIFICATION_PARSER_MAPPING[
+                    msg_type
+                ][1]
+                notification_type = MSG_TYPE_TO_MAINT_NOTIFICATION_PARSER_MAPPING[
+                    msg_type
+                ][0]
+                notification = parser_function(response, notification_type)
+
+                if notification is not None:
+                    return self.maintenance_push_handler_func(notification)
+        except Exception as e:
+            logger.error(
+                "Error handling {} message ({}): {}".format(msg_type, response, e)
+            )
+
+        return None

     def set_pubsub_push_handler(self, pubsub_push_handler_func):
         self.pubsub_push_handler_func = pubsub_push_handler_func
@@ -183,12 +306,20 @@ class PushNotificationsParser(Protocol):
     def set_invalidation_push_handler(self, invalidation_push_handler_func):
         self.invalidation_push_handler_func = invalidation_push_handler_func

+    def set_node_moving_push_handler(self, node_moving_push_handler_func):
+        self.node_moving_push_handler_func = node_moving_push_handler_func
+
+    def set_maintenance_push_handler(self, maintenance_push_handler_func):
+        self.maintenance_push_handler_func = maintenance_push_handler_func
+

 class AsyncPushNotificationsParser(Protocol):
     """Protocol defining async RESP3-specific parsing functionality"""

     pubsub_push_handler_func: Callable
     invalidation_push_handler_func: Optional[Callable] = None
+    node_moving_push_handler_func: Optional[Callable[..., Awaitable[None]]] = None
+    maintenance_push_handler_func: Optional[Callable[..., Awaitable[None]]] = None

     async def handle_pubsub_push_response(self, response):
         """Handle pubsub push responses asynchronously"""
@@ -196,10 +327,52 @@ class AsyncPushNotificationsParser(Protocol):

     async def handle_push_response(self, response, **kwargs):
         """Handle push responses asynchronously"""
-
+
+        msg_type = response[0]
+        if isinstance(msg_type, bytes):
+            msg_type = msg_type.decode()
+
+        if msg_type not in (
+            _INVALIDATION_MESSAGE,
+            *_MAINTENANCE_MESSAGES,
+            _MOVING_MESSAGE,
+        ):
             return await self.pubsub_push_handler_func(response)
-
-
+
+        try:
+            if (
+                msg_type == _INVALIDATION_MESSAGE
+                and self.invalidation_push_handler_func
+            ):
+                return await self.invalidation_push_handler_func(response)
+
+            if isinstance(msg_type, bytes):
+                msg_type = msg_type.decode()
+
+            if msg_type == _MOVING_MESSAGE and self.node_moving_push_handler_func:
+                parser_function = MSG_TYPE_TO_MAINT_NOTIFICATION_PARSER_MAPPING[
+                    msg_type
+                ][1]
+                notification = parser_function(response)
+                return await self.node_moving_push_handler_func(notification)
+
+            if msg_type in _MAINTENANCE_MESSAGES and self.maintenance_push_handler_func:
+                parser_function = MSG_TYPE_TO_MAINT_NOTIFICATION_PARSER_MAPPING[
+                    msg_type
+                ][1]
+                notification_type = MSG_TYPE_TO_MAINT_NOTIFICATION_PARSER_MAPPING[
+                    msg_type
+                ][0]
+                notification = parser_function(response, notification_type)
+
+                if notification is not None:
+                    return await self.maintenance_push_handler_func(notification)
+        except Exception as e:
+            logger.error(
+                "Error handling {} message ({}): {}".format(msg_type, response, e)
+            )
+
+        return None

     def set_pubsub_push_handler(self, pubsub_push_handler_func):
         """Set the pubsub push handler function"""
@@ -209,6 +382,12 @@ class AsyncPushNotificationsParser(Protocol):
         """Set the invalidation push handler function"""
         self.invalidation_push_handler_func = invalidation_push_handler_func

+    def set_node_moving_push_handler(self, node_moving_push_handler_func):
+        self.node_moving_push_handler_func = node_moving_push_handler_func
+
+    def set_maintenance_push_handler(self, maintenance_push_handler_func):
+        self.maintenance_push_handler_func = maintenance_push_handler_func
+

 class _AsyncRESPBase(AsyncBaseParser):
     """Base class for async resp parsing"""
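To make the new parsing path above concrete, here is a minimal sketch (not part of the packaged code) of how the mapping and the static parsers turn decoded RESP3 push frames into notification objects. The sample frames and comments are assumptions; the imported names come from the redis/_parsers/base.py hunks above.

    from redis._parsers.base import (
        MSG_TYPE_TO_MAINT_NOTIFICATION_PARSER_MAPPING,
        MaintenanceNotificationsParser,
    )

    # MOVING frames carry: type, sequence number, ttl, "host:port" endpoint.
    moving_frame = ["MOVING", 12, 15, "10.0.0.5:6379"]
    moving = MaintenanceNotificationsParser.parse_moving_msg(moving_frame)
    # builds NodeMovingNotification(12, "10.0.0.5", 6379, 15)

    # Other maintenance frames carry a sequence number (and, for *start* messages,
    # a time value); handle_push_response looks up the notification class and the
    # parser function in the mapping.
    migrating_frame = ["MIGRATING", 7, 30]
    cls, parser = MSG_TYPE_TO_MAINT_NOTIFICATION_PARSER_MAPPING["MIGRATING"]
    migrating = parser(migrating_frame, cls)  # NodeMigratingNotification(7, 30)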
{redis-6.4.0 → redis-7.0.0b2}/redis/_parsers/hiredis.py

@@ -47,6 +47,8 @@ class _HiredisParser(BaseParser, PushNotificationsParser):
         self.socket_read_size = socket_read_size
         self._buffer = bytearray(socket_read_size)
         self.pubsub_push_handler_func = self.handle_pubsub_push_response
+        self.node_moving_push_handler_func = None
+        self.maintenance_push_handler_func = None
         self.invalidation_push_handler_func = None
         self._hiredis_PushNotificationType = None

@@ -141,12 +143,15 @@ class _HiredisParser(BaseParser, PushNotificationsParser):
             response, self._hiredis_PushNotificationType
         ):
             response = self.handle_push_response(response)
-
-
-
-                )
-            else:
+
+            # if this is a push request return the push response
+            if push_request:
                 return response
+
+            return self.read_response(
+                disable_decoding=disable_decoding,
+                push_request=push_request,
+            )
         return response

         if disable_decoding:
@@ -169,12 +174,13 @@ class _HiredisParser(BaseParser, PushNotificationsParser):
             response, self._hiredis_PushNotificationType
         ):
             response = self.handle_push_response(response)
-            if
-                return self.read_response(
-                    disable_decoding=disable_decoding, push_request=push_request
-                )
-            else:
+            if push_request:
                 return response
+            return self.read_response(
+                disable_decoding=disable_decoding,
+                push_request=push_request,
+            )
+
         elif (
             isinstance(response, list)
             and response
{redis-6.4.0 → redis-7.0.0b2}/redis/_parsers/resp3.py

@@ -18,6 +18,8 @@ class _RESP3Parser(_RESPBase, PushNotificationsParser):
     def __init__(self, socket_read_size):
         super().__init__(socket_read_size)
         self.pubsub_push_handler_func = self.handle_pubsub_push_response
+        self.node_moving_push_handler_func = None
+        self.maintenance_push_handler_func = None
         self.invalidation_push_handler_func = None

     def handle_pubsub_push_response(self, response):
@@ -117,17 +119,21 @@ class _RESP3Parser(_RESPBase, PushNotificationsParser):
                 for _ in range(int(response))
             ]
             response = self.handle_push_response(response)
-
-
-
-                )
-            else:
+
+            # if this is a push request return the push response
+            if push_request:
                 return response
+
+            return self._read_response(
+                disable_decoding=disable_decoding,
+                push_request=push_request,
+            )
         else:
             raise InvalidResponse(f"Protocol Error: {raw!r}")

         if isinstance(response, bytes) and disable_decoding is False:
             response = self.encoder.decode(response)
+
         return response

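A hedged wiring sketch for the handler setters introduced above. In the released package these handlers are presumably installed through the new redis/maint_notifications.py module; the standalone parser instance and print callbacks here are only for illustration.

    from redis._parsers.resp3 import _RESP3Parser

    parser = _RESP3Parser(socket_read_size=65536)

    def on_maintenance(notification):
        # MIGRATING / MIGRATED / FAILING_OVER / FAILED_OVER notifications land here.
        print("maintenance push:", notification)

    def on_moving(notification):
        # MOVING notifications (endpoint changes) land here.
        print("node moving:", notification)

    parser.set_maintenance_push_handler(on_maintenance)
    parser.set_node_moving_push_handler(on_moving)
    # While _read_response()/read_response() waits for a normal reply, matching push
    # frames are routed to these callbacks and the parser keeps reading until the
    # actual reply arrives (the push_request branches added above).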
{redis-6.4.0 → redis-7.0.0b2}/redis/asyncio/client.py

@@ -81,10 +81,11 @@ from redis.utils import (
 )

 if TYPE_CHECKING and SSL_AVAILABLE:
-    from ssl import TLSVersion, VerifyMode
+    from ssl import TLSVersion, VerifyFlags, VerifyMode
 else:
     TLSVersion = None
     VerifyMode = None
+    VerifyFlags = None

 PubSubHandler = Callable[[Dict[str, str]], Awaitable[None]]
 _KeyT = TypeVar("_KeyT", bound=KeyT)
@@ -238,6 +239,8 @@ class Redis(
         ssl_keyfile: Optional[str] = None,
         ssl_certfile: Optional[str] = None,
         ssl_cert_reqs: Union[str, VerifyMode] = "required",
+        ssl_include_verify_flags: Optional[List[VerifyFlags]] = None,
+        ssl_exclude_verify_flags: Optional[List[VerifyFlags]] = None,
         ssl_ca_certs: Optional[str] = None,
         ssl_ca_data: Optional[str] = None,
         ssl_check_hostname: bool = True,
@@ -347,6 +350,8 @@ class Redis(
                     "ssl_keyfile": ssl_keyfile,
                     "ssl_certfile": ssl_certfile,
                     "ssl_cert_reqs": ssl_cert_reqs,
+                    "ssl_include_verify_flags": ssl_include_verify_flags,
+                    "ssl_exclude_verify_flags": ssl_exclude_verify_flags,
                     "ssl_ca_certs": ssl_ca_certs,
                     "ssl_ca_data": ssl_ca_data,
                     "ssl_check_hostname": ssl_check_hostname,
@@ -387,6 +392,12 @@ class Redis(
         # on a set of redis commands
         self._single_conn_lock = asyncio.Lock()

+        # When used as an async context manager, we need to increment and decrement
+        # a usage counter so that we can close the connection pool when no one is
+        # using the client.
+        self._usage_counter = 0
+        self._usage_lock = asyncio.Lock()
+
     def __repr__(self):
         return (
             f"<{self.__class__.__module__}.{self.__class__.__name__}"
@@ -594,10 +605,47 @@ class Redis(
         )

     async def __aenter__(self: _RedisT) -> _RedisT:
-
+        """
+        Async context manager entry. Increments a usage counter so that the
+        connection pool is only closed (via aclose()) when no context is using
+        the client.
+        """
+        await self._increment_usage()
+        try:
+            # Initialize the client (i.e. establish connection, etc.)
+            return await self.initialize()
+        except Exception:
+            # If initialization fails, decrement the counter to keep it in sync
+            await self._decrement_usage()
+            raise
+
+    async def _increment_usage(self) -> int:
+        """
+        Helper coroutine to increment the usage counter while holding the lock.
+        Returns the new value of the usage counter.
+        """
+        async with self._usage_lock:
+            self._usage_counter += 1
+            return self._usage_counter
+
+    async def _decrement_usage(self) -> int:
+        """
+        Helper coroutine to decrement the usage counter while holding the lock.
+        Returns the new value of the usage counter.
+        """
+        async with self._usage_lock:
+            self._usage_counter -= 1
+            return self._usage_counter

     async def __aexit__(self, exc_type, exc_value, traceback):
-
+        """
+        Async context manager exit. Decrements a usage counter. If this is the
+        last exit (counter becomes zero), the client closes its connection pool.
+        """
+        current_usage = await asyncio.shield(self._decrement_usage())
+        if current_usage == 0:
+            # This was the last active context, so disconnect the pool.
+            await asyncio.shield(self.aclose())

     _DEL_MESSAGE = "Unclosed Redis client"

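A brief sketch of the behaviour the new usage counter gives the async client (a localhost server and the key name are assumptions): nested async with blocks share one client, and the connection pool is closed only when the outermost context exits.

    import asyncio

    import redis.asyncio as redis

    async def main():
        client = redis.Redis()          # assumes a Redis server on localhost:6379
        async with client:              # usage counter -> 1
            async with client:          # usage counter -> 2
                await client.set("k", "v")
            # inner exit: counter -> 1, the pool stays open
            print(await client.get("k"))
        # outer exit: counter -> 0, aclose() disconnects the pool

    asyncio.run(main())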
{redis-6.4.0 → redis-7.0.0b2}/redis/asyncio/cluster.py

@@ -86,10 +86,11 @@ from redis.utils import (
 )

 if SSL_AVAILABLE:
-    from ssl import TLSVersion, VerifyMode
+    from ssl import TLSVersion, VerifyFlags, VerifyMode
 else:
     TLSVersion = None
     VerifyMode = None
+    VerifyFlags = None

 TargetNodesT = TypeVar(
     "TargetNodesT", str, "ClusterNode", List["ClusterNode"], Dict[Any, "ClusterNode"]
@@ -299,6 +300,8 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand
         ssl_ca_certs: Optional[str] = None,
         ssl_ca_data: Optional[str] = None,
         ssl_cert_reqs: Union[str, VerifyMode] = "required",
+        ssl_include_verify_flags: Optional[List[VerifyFlags]] = None,
+        ssl_exclude_verify_flags: Optional[List[VerifyFlags]] = None,
         ssl_certfile: Optional[str] = None,
         ssl_check_hostname: bool = True,
         ssl_keyfile: Optional[str] = None,
@@ -358,6 +361,8 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand
                     "ssl_ca_certs": ssl_ca_certs,
                     "ssl_ca_data": ssl_ca_data,
                     "ssl_cert_reqs": ssl_cert_reqs,
+                    "ssl_include_verify_flags": ssl_include_verify_flags,
+                    "ssl_exclude_verify_flags": ssl_exclude_verify_flags,
                     "ssl_certfile": ssl_certfile,
                     "ssl_check_hostname": ssl_check_hostname,
                     "ssl_keyfile": ssl_keyfile,
@@ -431,6 +436,12 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand
         self._initialize = True
         self._lock: Optional[asyncio.Lock] = None

+        # When used as an async context manager, we need to increment and decrement
+        # a usage counter so that we can close the connection pool when no one is
+        # using the client.
+        self._usage_counter = 0
+        self._usage_lock = asyncio.Lock()
+
     async def initialize(self) -> "RedisCluster":
         """Get all nodes from startup nodes & creates connections if not initialized."""
         if self._initialize:
@@ -467,10 +478,47 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand
             await self.aclose()

     async def __aenter__(self) -> "RedisCluster":
-
+        """
+        Async context manager entry. Increments a usage counter so that the
+        connection pool is only closed (via aclose()) when no context is using
+        the client.
+        """
+        await self._increment_usage()
+        try:
+            # Initialize the client (i.e. establish connection, etc.)
+            return await self.initialize()
+        except Exception:
+            # If initialization fails, decrement the counter to keep it in sync
+            await self._decrement_usage()
+            raise

-    async def
-
+    async def _increment_usage(self) -> int:
+        """
+        Helper coroutine to increment the usage counter while holding the lock.
+        Returns the new value of the usage counter.
+        """
+        async with self._usage_lock:
+            self._usage_counter += 1
+            return self._usage_counter
+
+    async def _decrement_usage(self) -> int:
+        """
+        Helper coroutine to decrement the usage counter while holding the lock.
+        Returns the new value of the usage counter.
+        """
+        async with self._usage_lock:
+            self._usage_counter -= 1
+            return self._usage_counter
+
+    async def __aexit__(self, exc_type, exc_value, traceback):
+        """
+        Async context manager exit. Decrements a usage counter. If this is the
+        last exit (counter becomes zero), the client closes its connection pool.
+        """
+        current_usage = await asyncio.shield(self._decrement_usage())
+        if current_usage == 0:
+            # This was the last active context, so disconnect the pool.
+            await asyncio.shield(self.aclose())

     def __await__(self) -> Generator[Any, None, "RedisCluster"]:
         return self.initialize().__await__()
{redis-6.4.0 → redis-7.0.0b2}/redis/asyncio/connection.py

@@ -30,11 +30,12 @@ from ..utils import SSL_AVAILABLE

 if SSL_AVAILABLE:
     import ssl
-    from ssl import SSLContext, TLSVersion
+    from ssl import SSLContext, TLSVersion, VerifyFlags
 else:
     ssl = None
     TLSVersion = None
     SSLContext = None
+    VerifyFlags = None

 from ..auth.token import TokenInterface
 from ..event import AsyncAfterConnectionReleasedEvent, EventDispatcher
@@ -793,6 +794,8 @@ class SSLConnection(Connection):
         ssl_keyfile: Optional[str] = None,
         ssl_certfile: Optional[str] = None,
         ssl_cert_reqs: Union[str, ssl.VerifyMode] = "required",
+        ssl_include_verify_flags: Optional[List["ssl.VerifyFlags"]] = None,
+        ssl_exclude_verify_flags: Optional[List["ssl.VerifyFlags"]] = None,
         ssl_ca_certs: Optional[str] = None,
         ssl_ca_data: Optional[str] = None,
         ssl_check_hostname: bool = True,
@@ -807,6 +810,8 @@ class SSLConnection(Connection):
             keyfile=ssl_keyfile,
             certfile=ssl_certfile,
             cert_reqs=ssl_cert_reqs,
+            include_verify_flags=ssl_include_verify_flags,
+            exclude_verify_flags=ssl_exclude_verify_flags,
             ca_certs=ssl_ca_certs,
             ca_data=ssl_ca_data,
             check_hostname=ssl_check_hostname,
@@ -832,6 +837,14 @@ class SSLConnection(Connection):
     def cert_reqs(self):
         return self.ssl_context.cert_reqs

+    @property
+    def include_verify_flags(self):
+        return self.ssl_context.include_verify_flags
+
+    @property
+    def exclude_verify_flags(self):
+        return self.ssl_context.exclude_verify_flags
+
     @property
     def ca_certs(self):
         return self.ssl_context.ca_certs
@@ -854,6 +867,8 @@ class RedisSSLContext:
         "keyfile",
         "certfile",
         "cert_reqs",
+        "include_verify_flags",
+        "exclude_verify_flags",
         "ca_certs",
         "ca_data",
         "context",
@@ -867,6 +882,8 @@ class RedisSSLContext:
         keyfile: Optional[str] = None,
         certfile: Optional[str] = None,
         cert_reqs: Optional[Union[str, ssl.VerifyMode]] = None,
+        include_verify_flags: Optional[List["ssl.VerifyFlags"]] = None,
+        exclude_verify_flags: Optional[List["ssl.VerifyFlags"]] = None,
         ca_certs: Optional[str] = None,
         ca_data: Optional[str] = None,
         check_hostname: bool = False,
@@ -892,6 +909,8 @@ class RedisSSLContext:
             )
             cert_reqs = CERT_REQS[cert_reqs]
         self.cert_reqs = cert_reqs
+        self.include_verify_flags = include_verify_flags
+        self.exclude_verify_flags = exclude_verify_flags
         self.ca_certs = ca_certs
         self.ca_data = ca_data
         self.check_hostname = (
@@ -906,6 +925,12 @@
             context = ssl.create_default_context()
             context.check_hostname = self.check_hostname
             context.verify_mode = self.cert_reqs
+            if self.include_verify_flags:
+                for flag in self.include_verify_flags:
+                    context.verify_flags |= flag
+            if self.exclude_verify_flags:
+                for flag in self.exclude_verify_flags:
+                    context.verify_flags &= ~flag
             if self.certfile and self.keyfile:
                 context.load_cert_chain(certfile=self.certfile, keyfile=self.keyfile)
             if self.ca_certs or self.ca_data:
@@ -953,6 +978,20 @@ def to_bool(value) -> Optional[bool]:
     return bool(value)


+def parse_ssl_verify_flags(value):
+    # flags are passed in as a string representation of a list,
+    # e.g. VERIFY_X509_STRICT, VERIFY_X509_PARTIAL_CHAIN
+    verify_flags_str = value.replace("[", "").replace("]", "")
+
+    verify_flags = []
+    for flag in verify_flags_str.split(","):
+        flag = flag.strip()
+        if not hasattr(VerifyFlags, flag):
+            raise ValueError(f"Invalid ssl verify flag: {flag}")
+        verify_flags.append(getattr(VerifyFlags, flag))
+    return verify_flags
+
+
 URL_QUERY_ARGUMENT_PARSERS: Mapping[str, Callable[..., object]] = MappingProxyType(
     {
         "db": int,
@@ -963,6 +1002,8 @@ URL_QUERY_ARGUMENT_PARSERS: Mapping[str, Callable[..., object]] = MappingProxyTy
         "max_connections": int,
         "health_check_interval": int,
         "ssl_check_hostname": to_bool,
+        "ssl_include_verify_flags": parse_ssl_verify_flags,
+        "ssl_exclude_verify_flags": parse_ssl_verify_flags,
         "timeout": float,
     }
 )
@@ -1021,6 +1062,7 @@ def parse_url(url: str) -> ConnectKwargs:

     if parsed.scheme == "rediss":
         kwargs["connection_class"] = SSLConnection
+
     else:
         valid_schemes = "redis://, rediss://, unix://"
         raise ValueError(