redis 6.4.0__tar.gz → 7.0.0b2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (165)
  1. {redis-6.4.0 → redis-7.0.0b2}/PKG-INFO +1 -1
  2. {redis-6.4.0 → redis-7.0.0b2}/redis/__init__.py +1 -1
  3. {redis-6.4.0 → redis-7.0.0b2}/redis/_parsers/base.py +187 -8
  4. {redis-6.4.0 → redis-7.0.0b2}/redis/_parsers/hiredis.py +16 -10
  5. {redis-6.4.0 → redis-7.0.0b2}/redis/_parsers/resp3.py +11 -5
  6. {redis-6.4.0 → redis-7.0.0b2}/redis/asyncio/client.py +51 -3
  7. {redis-6.4.0 → redis-7.0.0b2}/redis/asyncio/cluster.py +52 -4
  8. {redis-6.4.0 → redis-7.0.0b2}/redis/asyncio/connection.py +43 -1
  9. {redis-6.4.0 → redis-7.0.0b2}/redis/cache.py +1 -0
  10. {redis-6.4.0 → redis-7.0.0b2}/redis/client.py +72 -13
  11. {redis-6.4.0 → redis-7.0.0b2}/redis/cluster.py +5 -2
  12. {redis-6.4.0 → redis-7.0.0b2}/redis/commands/core.py +285 -285
  13. {redis-6.4.0 → redis-7.0.0b2}/redis/commands/helpers.py +0 -20
  14. {redis-6.4.0 → redis-7.0.0b2}/redis/commands/search/query.py +12 -12
  15. {redis-6.4.0 → redis-7.0.0b2}/redis/commands/vectorset/commands.py +43 -25
  16. {redis-6.4.0 → redis-7.0.0b2}/redis/commands/vectorset/utils.py +40 -4
  17. {redis-6.4.0 → redis-7.0.0b2}/redis/connection.py +884 -60
  18. redis-7.0.0b2/redis/maint_notifications.py +799 -0
  19. redis-7.0.0b2/tests/test_asyncio/test_ssl.py +143 -0
  20. redis-7.0.0b2/tests/test_asyncio/test_usage_counter.py +16 -0
  21. {redis-6.4.0 → redis-7.0.0b2}/tests/test_asyncio/test_vsets.py +113 -1
  22. {redis-6.4.0 → redis-7.0.0b2}/tests/test_cluster.py +16 -0
  23. {redis-6.4.0 → redis-7.0.0b2}/tests/test_connection_pool.py +71 -3
  24. redis-7.0.0b2/tests/test_maint_notifications.py +893 -0
  25. redis-7.0.0b2/tests/test_maint_notifications_handling.py +2226 -0
  26. redis-7.0.0b2/tests/test_scenario/__init__.py +0 -0
  27. redis-7.0.0b2/tests/test_scenario/conftest.py +125 -0
  28. redis-7.0.0b2/tests/test_scenario/fault_injector_client.py +150 -0
  29. redis-7.0.0b2/tests/test_scenario/maint_notifications_helpers.py +326 -0
  30. redis-7.0.0b2/tests/test_scenario/test_maint_notifications.py +1111 -0
  31. {redis-6.4.0 → redis-7.0.0b2}/tests/test_ssl.py +97 -0
  32. {redis-6.4.0 → redis-7.0.0b2}/tests/test_vsets.py +111 -1
  33. redis-6.4.0/tests/test_asyncio/test_ssl.py +0 -56
  34. {redis-6.4.0 → redis-7.0.0b2}/.gitignore +0 -0
  35. {redis-6.4.0 → redis-7.0.0b2}/LICENSE +0 -0
  36. {redis-6.4.0 → redis-7.0.0b2}/README.md +0 -0
  37. {redis-6.4.0 → redis-7.0.0b2}/dev_requirements.txt +0 -0
  38. {redis-6.4.0 → redis-7.0.0b2}/pyproject.toml +0 -0
  39. {redis-6.4.0 → redis-7.0.0b2}/redis/_parsers/__init__.py +0 -0
  40. {redis-6.4.0 → redis-7.0.0b2}/redis/_parsers/commands.py +0 -0
  41. {redis-6.4.0 → redis-7.0.0b2}/redis/_parsers/encoders.py +0 -0
  42. {redis-6.4.0 → redis-7.0.0b2}/redis/_parsers/helpers.py +0 -0
  43. {redis-6.4.0 → redis-7.0.0b2}/redis/_parsers/resp2.py +0 -0
  44. {redis-6.4.0 → redis-7.0.0b2}/redis/_parsers/socket.py +0 -0
  45. {redis-6.4.0 → redis-7.0.0b2}/redis/asyncio/__init__.py +0 -0
  46. {redis-6.4.0 → redis-7.0.0b2}/redis/asyncio/lock.py +0 -0
  47. {redis-6.4.0 → redis-7.0.0b2}/redis/asyncio/retry.py +0 -0
  48. {redis-6.4.0 → redis-7.0.0b2}/redis/asyncio/sentinel.py +0 -0
  49. {redis-6.4.0 → redis-7.0.0b2}/redis/asyncio/utils.py +0 -0
  50. {redis-6.4.0 → redis-7.0.0b2}/redis/auth/__init__.py +0 -0
  51. {redis-6.4.0 → redis-7.0.0b2}/redis/auth/err.py +0 -0
  52. {redis-6.4.0 → redis-7.0.0b2}/redis/auth/idp.py +0 -0
  53. {redis-6.4.0 → redis-7.0.0b2}/redis/auth/token.py +0 -0
  54. {redis-6.4.0 → redis-7.0.0b2}/redis/auth/token_manager.py +0 -0
  55. {redis-6.4.0 → redis-7.0.0b2}/redis/backoff.py +0 -0
  56. {redis-6.4.0 → redis-7.0.0b2}/redis/commands/__init__.py +0 -0
  57. {redis-6.4.0 → redis-7.0.0b2}/redis/commands/bf/__init__.py +0 -0
  58. {redis-6.4.0 → redis-7.0.0b2}/redis/commands/bf/commands.py +0 -0
  59. {redis-6.4.0 → redis-7.0.0b2}/redis/commands/bf/info.py +0 -0
  60. {redis-6.4.0 → redis-7.0.0b2}/redis/commands/cluster.py +0 -0
  61. {redis-6.4.0 → redis-7.0.0b2}/redis/commands/json/__init__.py +0 -0
  62. {redis-6.4.0 → redis-7.0.0b2}/redis/commands/json/_util.py +0 -0
  63. {redis-6.4.0 → redis-7.0.0b2}/redis/commands/json/commands.py +0 -0
  64. {redis-6.4.0 → redis-7.0.0b2}/redis/commands/json/decoders.py +0 -0
  65. {redis-6.4.0 → redis-7.0.0b2}/redis/commands/json/path.py +0 -0
  66. {redis-6.4.0 → redis-7.0.0b2}/redis/commands/redismodules.py +0 -0
  67. {redis-6.4.0 → redis-7.0.0b2}/redis/commands/search/__init__.py +0 -0
  68. {redis-6.4.0 → redis-7.0.0b2}/redis/commands/search/_util.py +0 -0
  69. {redis-6.4.0 → redis-7.0.0b2}/redis/commands/search/aggregation.py +0 -0
  70. {redis-6.4.0 → redis-7.0.0b2}/redis/commands/search/commands.py +0 -0
  71. {redis-6.4.0 → redis-7.0.0b2}/redis/commands/search/dialect.py +0 -0
  72. {redis-6.4.0 → redis-7.0.0b2}/redis/commands/search/document.py +0 -0
  73. {redis-6.4.0 → redis-7.0.0b2}/redis/commands/search/field.py +0 -0
  74. {redis-6.4.0 → redis-7.0.0b2}/redis/commands/search/index_definition.py +0 -0
  75. {redis-6.4.0 → redis-7.0.0b2}/redis/commands/search/profile_information.py +0 -0
  76. {redis-6.4.0 → redis-7.0.0b2}/redis/commands/search/querystring.py +0 -0
  77. {redis-6.4.0 → redis-7.0.0b2}/redis/commands/search/reducers.py +0 -0
  78. {redis-6.4.0 → redis-7.0.0b2}/redis/commands/search/result.py +0 -0
  79. {redis-6.4.0 → redis-7.0.0b2}/redis/commands/search/suggestion.py +0 -0
  80. {redis-6.4.0 → redis-7.0.0b2}/redis/commands/sentinel.py +0 -0
  81. {redis-6.4.0 → redis-7.0.0b2}/redis/commands/timeseries/__init__.py +0 -0
  82. {redis-6.4.0 → redis-7.0.0b2}/redis/commands/timeseries/commands.py +0 -0
  83. {redis-6.4.0 → redis-7.0.0b2}/redis/commands/timeseries/info.py +0 -0
  84. {redis-6.4.0 → redis-7.0.0b2}/redis/commands/timeseries/utils.py +0 -0
  85. {redis-6.4.0 → redis-7.0.0b2}/redis/commands/vectorset/__init__.py +1 -1
  86. {redis-6.4.0 → redis-7.0.0b2}/redis/crc.py +0 -0
  87. {redis-6.4.0 → redis-7.0.0b2}/redis/credentials.py +0 -0
  88. {redis-6.4.0 → redis-7.0.0b2}/redis/event.py +0 -0
  89. {redis-6.4.0 → redis-7.0.0b2}/redis/exceptions.py +0 -0
  90. {redis-6.4.0 → redis-7.0.0b2}/redis/lock.py +0 -0
  91. {redis-6.4.0 → redis-7.0.0b2}/redis/ocsp.py +0 -0
  92. {redis-6.4.0 → redis-7.0.0b2}/redis/py.typed +0 -0
  93. {redis-6.4.0 → redis-7.0.0b2}/redis/retry.py +0 -0
  94. {redis-6.4.0 → redis-7.0.0b2}/redis/sentinel.py +0 -0
  95. {redis-6.4.0 → redis-7.0.0b2}/redis/typing.py +0 -0
  96. {redis-6.4.0 → redis-7.0.0b2}/redis/utils.py +0 -0
  97. {redis-6.4.0 → redis-7.0.0b2}/tests/__init__.py +0 -0
  98. {redis-6.4.0 → redis-7.0.0b2}/tests/conftest.py +0 -0
  99. {redis-6.4.0 → redis-7.0.0b2}/tests/entraid_utils.py +0 -0
  100. {redis-6.4.0 → redis-7.0.0b2}/tests/mocks.py +0 -0
  101. {redis-6.4.0 → redis-7.0.0b2}/tests/ssl_utils.py +0 -0
  102. {redis-6.4.0 → redis-7.0.0b2}/tests/test_asyncio/__init__.py +0 -0
  103. {redis-6.4.0 → redis-7.0.0b2}/tests/test_asyncio/compat.py +0 -0
  104. {redis-6.4.0 → redis-7.0.0b2}/tests/test_asyncio/conftest.py +0 -0
  105. {redis-6.4.0 → redis-7.0.0b2}/tests/test_asyncio/mocks.py +0 -0
  106. {redis-6.4.0 → redis-7.0.0b2}/tests/test_asyncio/test_bloom.py +0 -0
  107. {redis-6.4.0 → redis-7.0.0b2}/tests/test_asyncio/test_cluster.py +0 -0
  108. {redis-6.4.0 → redis-7.0.0b2}/tests/test_asyncio/test_cluster_transaction.py +0 -0
  109. {redis-6.4.0 → redis-7.0.0b2}/tests/test_asyncio/test_commands.py +0 -0
  110. {redis-6.4.0 → redis-7.0.0b2}/tests/test_asyncio/test_connect.py +0 -0
  111. {redis-6.4.0 → redis-7.0.0b2}/tests/test_asyncio/test_connection.py +0 -0
  112. {redis-6.4.0 → redis-7.0.0b2}/tests/test_asyncio/test_connection_pool.py +0 -0
  113. {redis-6.4.0 → redis-7.0.0b2}/tests/test_asyncio/test_credentials.py +0 -0
  114. {redis-6.4.0 → redis-7.0.0b2}/tests/test_asyncio/test_cwe_404.py +0 -0
  115. {redis-6.4.0 → redis-7.0.0b2}/tests/test_asyncio/test_encoding.py +0 -0
  116. {redis-6.4.0 → redis-7.0.0b2}/tests/test_asyncio/test_hash.py +0 -0
  117. {redis-6.4.0 → redis-7.0.0b2}/tests/test_asyncio/test_json.py +0 -0
  118. {redis-6.4.0 → redis-7.0.0b2}/tests/test_asyncio/test_lock.py +0 -0
  119. {redis-6.4.0 → redis-7.0.0b2}/tests/test_asyncio/test_monitor.py +0 -0
  120. {redis-6.4.0 → redis-7.0.0b2}/tests/test_asyncio/test_pipeline.py +0 -0
  121. {redis-6.4.0 → redis-7.0.0b2}/tests/test_asyncio/test_pubsub.py +0 -0
  122. {redis-6.4.0 → redis-7.0.0b2}/tests/test_asyncio/test_retry.py +0 -0
  123. {redis-6.4.0 → redis-7.0.0b2}/tests/test_asyncio/test_scripting.py +0 -0
  124. {redis-6.4.0 → redis-7.0.0b2}/tests/test_asyncio/test_search.py +0 -0
  125. {redis-6.4.0 → redis-7.0.0b2}/tests/test_asyncio/test_sentinel.py +0 -0
  126. {redis-6.4.0 → redis-7.0.0b2}/tests/test_asyncio/test_sentinel_managed_connection.py +0 -0
  127. {redis-6.4.0 → redis-7.0.0b2}/tests/test_asyncio/test_timeseries.py +0 -0
  128. {redis-6.4.0 → redis-7.0.0b2}/tests/test_asyncio/test_utils.py +0 -0
  129. {redis-6.4.0 → redis-7.0.0b2}/tests/test_asyncio/testdata/jsontestdata.py +0 -0
  130. {redis-6.4.0 → redis-7.0.0b2}/tests/test_asyncio/testdata/titles.csv +0 -0
  131. {redis-6.4.0 → redis-7.0.0b2}/tests/test_asyncio/testdata/will_play_text.csv.bz2 +0 -0
  132. {redis-6.4.0 → redis-7.0.0b2}/tests/test_auth/__init__.py +0 -0
  133. {redis-6.4.0 → redis-7.0.0b2}/tests/test_auth/test_token.py +0 -0
  134. {redis-6.4.0 → redis-7.0.0b2}/tests/test_auth/test_token_manager.py +0 -0
  135. {redis-6.4.0 → redis-7.0.0b2}/tests/test_backoff.py +0 -0
  136. {redis-6.4.0 → redis-7.0.0b2}/tests/test_bloom.py +0 -0
  137. {redis-6.4.0 → redis-7.0.0b2}/tests/test_cache.py +0 -0
  138. {redis-6.4.0 → redis-7.0.0b2}/tests/test_cluster_transaction.py +0 -0
  139. {redis-6.4.0 → redis-7.0.0b2}/tests/test_command_parser.py +0 -0
  140. {redis-6.4.0 → redis-7.0.0b2}/tests/test_commands.py +0 -0
  141. {redis-6.4.0 → redis-7.0.0b2}/tests/test_connect.py +0 -0
  142. {redis-6.4.0 → redis-7.0.0b2}/tests/test_connection.py +0 -0
  143. {redis-6.4.0 → redis-7.0.0b2}/tests/test_credentials.py +0 -0
  144. {redis-6.4.0 → redis-7.0.0b2}/tests/test_encoding.py +0 -0
  145. {redis-6.4.0 → redis-7.0.0b2}/tests/test_function.py +0 -0
  146. {redis-6.4.0 → redis-7.0.0b2}/tests/test_hash.py +0 -0
  147. {redis-6.4.0 → redis-7.0.0b2}/tests/test_helpers.py +0 -0
  148. {redis-6.4.0 → redis-7.0.0b2}/tests/test_json.py +0 -0
  149. {redis-6.4.0 → redis-7.0.0b2}/tests/test_lock.py +0 -0
  150. {redis-6.4.0 → redis-7.0.0b2}/tests/test_max_connections_error.py +0 -0
  151. {redis-6.4.0 → redis-7.0.0b2}/tests/test_monitor.py +0 -0
  152. {redis-6.4.0 → redis-7.0.0b2}/tests/test_multiprocessing.py +0 -0
  153. {redis-6.4.0 → redis-7.0.0b2}/tests/test_parsers/test_helpers.py +0 -0
  154. {redis-6.4.0 → redis-7.0.0b2}/tests/test_pipeline.py +0 -0
  155. {redis-6.4.0 → redis-7.0.0b2}/tests/test_pubsub.py +0 -0
  156. {redis-6.4.0 → redis-7.0.0b2}/tests/test_retry.py +0 -0
  157. {redis-6.4.0 → redis-7.0.0b2}/tests/test_scripting.py +0 -0
  158. {redis-6.4.0 → redis-7.0.0b2}/tests/test_search.py +0 -0
  159. {redis-6.4.0 → redis-7.0.0b2}/tests/test_sentinel.py +0 -0
  160. {redis-6.4.0 → redis-7.0.0b2}/tests/test_sentinel_managed_connection.py +0 -0
  161. {redis-6.4.0 → redis-7.0.0b2}/tests/test_timeseries.py +0 -0
  162. {redis-6.4.0 → redis-7.0.0b2}/tests/test_utils.py +0 -0
  163. {redis-6.4.0 → redis-7.0.0b2}/tests/testdata/jsontestdata.py +0 -0
  164. {redis-6.4.0 → redis-7.0.0b2}/tests/testdata/titles.csv +0 -0
  165. {redis-6.4.0 → redis-7.0.0b2}/tests/testdata/will_play_text.csv.bz2 +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: redis
3
- Version: 6.4.0
3
+ Version: 7.0.0b2
4
4
  Summary: Python client for Redis database and key-value store
5
5
  Project-URL: Changes, https://github.com/redis/redis-py/releases
6
6
  Project-URL: Code, https://github.com/redis/redis-py
@@ -46,7 +46,7 @@ def int_or_str(value):
46
46
  return value
47
47
 
48
48
 
49
- __version__ = "6.4.0"
49
+ __version__ = "7.0.0b2"
50
50
  VERSION = tuple(map(int_or_str, __version__.split(".")))
51
51
 
52
52
 
@@ -1,7 +1,17 @@
1
+ import logging
1
2
  import sys
2
3
  from abc import ABC
3
4
  from asyncio import IncompleteReadError, StreamReader, TimeoutError
4
- from typing import Callable, List, Optional, Protocol, Union
5
+ from typing import Awaitable, Callable, List, Optional, Protocol, Union
6
+
7
+ from redis.maint_notifications import (
8
+ MaintenanceNotification,
9
+ NodeFailedOverNotification,
10
+ NodeFailingOverNotification,
11
+ NodeMigratedNotification,
12
+ NodeMigratingNotification,
13
+ NodeMovingNotification,
14
+ )
5
15
 
6
16
  if sys.version_info.major >= 3 and sys.version_info.minor >= 11:
7
17
  from asyncio import timeout as async_timeout
@@ -50,6 +60,8 @@ NO_AUTH_SET_ERROR = {
50
60
  "Client sent AUTH, but no password is set": AuthenticationError,
51
61
  }
52
62
 
63
+ logger = logging.getLogger(__name__)
64
+
53
65
 
54
66
  class BaseParser(ABC):
55
67
  EXCEPTION_CLASSES = {
@@ -158,7 +170,77 @@ class AsyncBaseParser(BaseParser):
158
170
  raise NotImplementedError()
159
171
 
160
172
 
161
- _INVALIDATION_MESSAGE = [b"invalidate", "invalidate"]
173
+ class MaintenanceNotificationsParser:
174
+ """Protocol defining maintenance push notification parsing functionality"""
175
+
176
+ @staticmethod
177
+ def parse_maintenance_start_msg(response, notification_type):
178
+ # Expected message format is: <notification_type> <seq_number> <time>
179
+ id = response[1]
180
+ ttl = response[2]
181
+ return notification_type(id, ttl)
182
+
183
+ @staticmethod
184
+ def parse_maintenance_completed_msg(response, notification_type):
185
+ # Expected message format is: <notification_type> <seq_number>
186
+ id = response[1]
187
+ return notification_type(id)
188
+
189
+ @staticmethod
190
+ def parse_moving_msg(response):
191
+ # Expected message format is: MOVING <seq_number> <time> <endpoint>
192
+ id = response[1]
193
+ ttl = response[2]
194
+ if response[3] is None:
195
+ host, port = None, None
196
+ else:
197
+ value = response[3]
198
+ if isinstance(value, bytes):
199
+ value = value.decode()
200
+ host, port = value.split(":")
201
+ port = int(port) if port is not None else None
202
+
203
+ return NodeMovingNotification(id, host, port, ttl)
204
+
205
+
206
+ _INVALIDATION_MESSAGE = "invalidate"
207
+ _MOVING_MESSAGE = "MOVING"
208
+ _MIGRATING_MESSAGE = "MIGRATING"
209
+ _MIGRATED_MESSAGE = "MIGRATED"
210
+ _FAILING_OVER_MESSAGE = "FAILING_OVER"
211
+ _FAILED_OVER_MESSAGE = "FAILED_OVER"
212
+
213
+ _MAINTENANCE_MESSAGES = (
214
+ _MIGRATING_MESSAGE,
215
+ _MIGRATED_MESSAGE,
216
+ _FAILING_OVER_MESSAGE,
217
+ _FAILED_OVER_MESSAGE,
218
+ )
219
+
220
+ MSG_TYPE_TO_MAINT_NOTIFICATION_PARSER_MAPPING: dict[
221
+ str, tuple[type[MaintenanceNotification], Callable]
222
+ ] = {
223
+ _MIGRATING_MESSAGE: (
224
+ NodeMigratingNotification,
225
+ MaintenanceNotificationsParser.parse_maintenance_start_msg,
226
+ ),
227
+ _MIGRATED_MESSAGE: (
228
+ NodeMigratedNotification,
229
+ MaintenanceNotificationsParser.parse_maintenance_completed_msg,
230
+ ),
231
+ _FAILING_OVER_MESSAGE: (
232
+ NodeFailingOverNotification,
233
+ MaintenanceNotificationsParser.parse_maintenance_start_msg,
234
+ ),
235
+ _FAILED_OVER_MESSAGE: (
236
+ NodeFailedOverNotification,
237
+ MaintenanceNotificationsParser.parse_maintenance_completed_msg,
238
+ ),
239
+ _MOVING_MESSAGE: (
240
+ NodeMovingNotification,
241
+ MaintenanceNotificationsParser.parse_moving_msg,
242
+ ),
243
+ }
162
244
 
163
245
 
164
246
  class PushNotificationsParser(Protocol):
@@ -166,16 +248,57 @@ class PushNotificationsParser(Protocol):
166
248
 
167
249
  pubsub_push_handler_func: Callable
168
250
  invalidation_push_handler_func: Optional[Callable] = None
251
+ node_moving_push_handler_func: Optional[Callable] = None
252
+ maintenance_push_handler_func: Optional[Callable] = None
169
253
 
170
254
  def handle_pubsub_push_response(self, response):
171
255
  """Handle pubsub push responses"""
172
256
  raise NotImplementedError()
173
257
 
174
258
  def handle_push_response(self, response, **kwargs):
175
- if response[0] not in _INVALIDATION_MESSAGE:
259
+ msg_type = response[0]
260
+ if isinstance(msg_type, bytes):
261
+ msg_type = msg_type.decode()
262
+
263
+ if msg_type not in (
264
+ _INVALIDATION_MESSAGE,
265
+ *_MAINTENANCE_MESSAGES,
266
+ _MOVING_MESSAGE,
267
+ ):
176
268
  return self.pubsub_push_handler_func(response)
177
- if self.invalidation_push_handler_func:
178
- return self.invalidation_push_handler_func(response)
269
+
270
+ try:
271
+ if (
272
+ msg_type == _INVALIDATION_MESSAGE
273
+ and self.invalidation_push_handler_func
274
+ ):
275
+ return self.invalidation_push_handler_func(response)
276
+
277
+ if msg_type == _MOVING_MESSAGE and self.node_moving_push_handler_func:
278
+ parser_function = MSG_TYPE_TO_MAINT_NOTIFICATION_PARSER_MAPPING[
279
+ msg_type
280
+ ][1]
281
+
282
+ notification = parser_function(response)
283
+ return self.node_moving_push_handler_func(notification)
284
+
285
+ if msg_type in _MAINTENANCE_MESSAGES and self.maintenance_push_handler_func:
286
+ parser_function = MSG_TYPE_TO_MAINT_NOTIFICATION_PARSER_MAPPING[
287
+ msg_type
288
+ ][1]
289
+ notification_type = MSG_TYPE_TO_MAINT_NOTIFICATION_PARSER_MAPPING[
290
+ msg_type
291
+ ][0]
292
+ notification = parser_function(response, notification_type)
293
+
294
+ if notification is not None:
295
+ return self.maintenance_push_handler_func(notification)
296
+ except Exception as e:
297
+ logger.error(
298
+ "Error handling {} message ({}): {}".format(msg_type, response, e)
299
+ )
300
+
301
+ return None
179
302
 
180
303
  def set_pubsub_push_handler(self, pubsub_push_handler_func):
181
304
  self.pubsub_push_handler_func = pubsub_push_handler_func
@@ -183,12 +306,20 @@ class PushNotificationsParser(Protocol):
183
306
  def set_invalidation_push_handler(self, invalidation_push_handler_func):
184
307
  self.invalidation_push_handler_func = invalidation_push_handler_func
185
308
 
309
+ def set_node_moving_push_handler(self, node_moving_push_handler_func):
310
+ self.node_moving_push_handler_func = node_moving_push_handler_func
311
+
312
+ def set_maintenance_push_handler(self, maintenance_push_handler_func):
313
+ self.maintenance_push_handler_func = maintenance_push_handler_func
314
+
186
315
 
187
316
  class AsyncPushNotificationsParser(Protocol):
188
317
  """Protocol defining async RESP3-specific parsing functionality"""
189
318
 
190
319
  pubsub_push_handler_func: Callable
191
320
  invalidation_push_handler_func: Optional[Callable] = None
321
+ node_moving_push_handler_func: Optional[Callable[..., Awaitable[None]]] = None
322
+ maintenance_push_handler_func: Optional[Callable[..., Awaitable[None]]] = None
192
323
 
193
324
  async def handle_pubsub_push_response(self, response):
194
325
  """Handle pubsub push responses asynchronously"""
@@ -196,10 +327,52 @@ class AsyncPushNotificationsParser(Protocol):
196
327
 
197
328
  async def handle_push_response(self, response, **kwargs):
198
329
  """Handle push responses asynchronously"""
199
- if response[0] not in _INVALIDATION_MESSAGE:
330
+
331
+ msg_type = response[0]
332
+ if isinstance(msg_type, bytes):
333
+ msg_type = msg_type.decode()
334
+
335
+ if msg_type not in (
336
+ _INVALIDATION_MESSAGE,
337
+ *_MAINTENANCE_MESSAGES,
338
+ _MOVING_MESSAGE,
339
+ ):
200
340
  return await self.pubsub_push_handler_func(response)
201
- if self.invalidation_push_handler_func:
202
- return await self.invalidation_push_handler_func(response)
341
+
342
+ try:
343
+ if (
344
+ msg_type == _INVALIDATION_MESSAGE
345
+ and self.invalidation_push_handler_func
346
+ ):
347
+ return await self.invalidation_push_handler_func(response)
348
+
349
+ if isinstance(msg_type, bytes):
350
+ msg_type = msg_type.decode()
351
+
352
+ if msg_type == _MOVING_MESSAGE and self.node_moving_push_handler_func:
353
+ parser_function = MSG_TYPE_TO_MAINT_NOTIFICATION_PARSER_MAPPING[
354
+ msg_type
355
+ ][1]
356
+ notification = parser_function(response)
357
+ return await self.node_moving_push_handler_func(notification)
358
+
359
+ if msg_type in _MAINTENANCE_MESSAGES and self.maintenance_push_handler_func:
360
+ parser_function = MSG_TYPE_TO_MAINT_NOTIFICATION_PARSER_MAPPING[
361
+ msg_type
362
+ ][1]
363
+ notification_type = MSG_TYPE_TO_MAINT_NOTIFICATION_PARSER_MAPPING[
364
+ msg_type
365
+ ][0]
366
+ notification = parser_function(response, notification_type)
367
+
368
+ if notification is not None:
369
+ return await self.maintenance_push_handler_func(notification)
370
+ except Exception as e:
371
+ logger.error(
372
+ "Error handling {} message ({}): {}".format(msg_type, response, e)
373
+ )
374
+
375
+ return None
203
376
 
204
377
  def set_pubsub_push_handler(self, pubsub_push_handler_func):
205
378
  """Set the pubsub push handler function"""
@@ -209,6 +382,12 @@ class AsyncPushNotificationsParser(Protocol):
209
382
  """Set the invalidation push handler function"""
210
383
  self.invalidation_push_handler_func = invalidation_push_handler_func
211
384
 
385
+ def set_node_moving_push_handler(self, node_moving_push_handler_func):
386
+ self.node_moving_push_handler_func = node_moving_push_handler_func
387
+
388
+ def set_maintenance_push_handler(self, maintenance_push_handler_func):
389
+ self.maintenance_push_handler_func = maintenance_push_handler_func
390
+
212
391
 
213
392
  class _AsyncRESPBase(AsyncBaseParser):
214
393
  """Base class for async resp parsing"""
@@ -47,6 +47,8 @@ class _HiredisParser(BaseParser, PushNotificationsParser):
47
47
  self.socket_read_size = socket_read_size
48
48
  self._buffer = bytearray(socket_read_size)
49
49
  self.pubsub_push_handler_func = self.handle_pubsub_push_response
50
+ self.node_moving_push_handler_func = None
51
+ self.maintenance_push_handler_func = None
50
52
  self.invalidation_push_handler_func = None
51
53
  self._hiredis_PushNotificationType = None
52
54
 
@@ -141,12 +143,15 @@ class _HiredisParser(BaseParser, PushNotificationsParser):
141
143
  response, self._hiredis_PushNotificationType
142
144
  ):
143
145
  response = self.handle_push_response(response)
144
- if not push_request:
145
- return self.read_response(
146
- disable_decoding=disable_decoding, push_request=push_request
147
- )
148
- else:
146
+
147
+ # if this is a push request return the push response
148
+ if push_request:
149
149
  return response
150
+
151
+ return self.read_response(
152
+ disable_decoding=disable_decoding,
153
+ push_request=push_request,
154
+ )
150
155
  return response
151
156
 
152
157
  if disable_decoding:
@@ -169,12 +174,13 @@ class _HiredisParser(BaseParser, PushNotificationsParser):
169
174
  response, self._hiredis_PushNotificationType
170
175
  ):
171
176
  response = self.handle_push_response(response)
172
- if not push_request:
173
- return self.read_response(
174
- disable_decoding=disable_decoding, push_request=push_request
175
- )
176
- else:
177
+ if push_request:
177
178
  return response
179
+ return self.read_response(
180
+ disable_decoding=disable_decoding,
181
+ push_request=push_request,
182
+ )
183
+
178
184
  elif (
179
185
  isinstance(response, list)
180
186
  and response
@@ -18,6 +18,8 @@ class _RESP3Parser(_RESPBase, PushNotificationsParser):
18
18
  def __init__(self, socket_read_size):
19
19
  super().__init__(socket_read_size)
20
20
  self.pubsub_push_handler_func = self.handle_pubsub_push_response
21
+ self.node_moving_push_handler_func = None
22
+ self.maintenance_push_handler_func = None
21
23
  self.invalidation_push_handler_func = None
22
24
 
23
25
  def handle_pubsub_push_response(self, response):
@@ -117,17 +119,21 @@ class _RESP3Parser(_RESPBase, PushNotificationsParser):
117
119
  for _ in range(int(response))
118
120
  ]
119
121
  response = self.handle_push_response(response)
120
- if not push_request:
121
- return self._read_response(
122
- disable_decoding=disable_decoding, push_request=push_request
123
- )
124
- else:
122
+
123
+ # if this is a push request return the push response
124
+ if push_request:
125
125
  return response
126
+
127
+ return self._read_response(
128
+ disable_decoding=disable_decoding,
129
+ push_request=push_request,
130
+ )
126
131
  else:
127
132
  raise InvalidResponse(f"Protocol Error: {raw!r}")
128
133
 
129
134
  if isinstance(response, bytes) and disable_decoding is False:
130
135
  response = self.encoder.decode(response)
136
+
131
137
  return response
132
138
 
133
139
 
@@ -81,10 +81,11 @@ from redis.utils import (
81
81
  )
82
82
 
83
83
  if TYPE_CHECKING and SSL_AVAILABLE:
84
- from ssl import TLSVersion, VerifyMode
84
+ from ssl import TLSVersion, VerifyFlags, VerifyMode
85
85
  else:
86
86
  TLSVersion = None
87
87
  VerifyMode = None
88
+ VerifyFlags = None
88
89
 
89
90
  PubSubHandler = Callable[[Dict[str, str]], Awaitable[None]]
90
91
  _KeyT = TypeVar("_KeyT", bound=KeyT)
@@ -238,6 +239,8 @@ class Redis(
238
239
  ssl_keyfile: Optional[str] = None,
239
240
  ssl_certfile: Optional[str] = None,
240
241
  ssl_cert_reqs: Union[str, VerifyMode] = "required",
242
+ ssl_include_verify_flags: Optional[List[VerifyFlags]] = None,
243
+ ssl_exclude_verify_flags: Optional[List[VerifyFlags]] = None,
241
244
  ssl_ca_certs: Optional[str] = None,
242
245
  ssl_ca_data: Optional[str] = None,
243
246
  ssl_check_hostname: bool = True,
@@ -347,6 +350,8 @@ class Redis(
347
350
  "ssl_keyfile": ssl_keyfile,
348
351
  "ssl_certfile": ssl_certfile,
349
352
  "ssl_cert_reqs": ssl_cert_reqs,
353
+ "ssl_include_verify_flags": ssl_include_verify_flags,
354
+ "ssl_exclude_verify_flags": ssl_exclude_verify_flags,
350
355
  "ssl_ca_certs": ssl_ca_certs,
351
356
  "ssl_ca_data": ssl_ca_data,
352
357
  "ssl_check_hostname": ssl_check_hostname,
@@ -387,6 +392,12 @@ class Redis(
387
392
  # on a set of redis commands
388
393
  self._single_conn_lock = asyncio.Lock()
389
394
 
395
+ # When used as an async context manager, we need to increment and decrement
396
+ # a usage counter so that we can close the connection pool when no one is
397
+ # using the client.
398
+ self._usage_counter = 0
399
+ self._usage_lock = asyncio.Lock()
400
+
390
401
  def __repr__(self):
391
402
  return (
392
403
  f"<{self.__class__.__module__}.{self.__class__.__name__}"
@@ -594,10 +605,47 @@ class Redis(
594
605
  )
595
606
 
596
607
  async def __aenter__(self: _RedisT) -> _RedisT:
597
- return await self.initialize()
608
+ """
609
+ Async context manager entry. Increments a usage counter so that the
610
+ connection pool is only closed (via aclose()) when no context is using
611
+ the client.
612
+ """
613
+ await self._increment_usage()
614
+ try:
615
+ # Initialize the client (i.e. establish connection, etc.)
616
+ return await self.initialize()
617
+ except Exception:
618
+ # If initialization fails, decrement the counter to keep it in sync
619
+ await self._decrement_usage()
620
+ raise
621
+
622
+ async def _increment_usage(self) -> int:
623
+ """
624
+ Helper coroutine to increment the usage counter while holding the lock.
625
+ Returns the new value of the usage counter.
626
+ """
627
+ async with self._usage_lock:
628
+ self._usage_counter += 1
629
+ return self._usage_counter
630
+
631
+ async def _decrement_usage(self) -> int:
632
+ """
633
+ Helper coroutine to decrement the usage counter while holding the lock.
634
+ Returns the new value of the usage counter.
635
+ """
636
+ async with self._usage_lock:
637
+ self._usage_counter -= 1
638
+ return self._usage_counter
598
639
 
599
640
  async def __aexit__(self, exc_type, exc_value, traceback):
600
- await self.aclose()
641
+ """
642
+ Async context manager exit. Decrements a usage counter. If this is the
643
+ last exit (counter becomes zero), the client closes its connection pool.
644
+ """
645
+ current_usage = await asyncio.shield(self._decrement_usage())
646
+ if current_usage == 0:
647
+ # This was the last active context, so disconnect the pool.
648
+ await asyncio.shield(self.aclose())
601
649
 
602
650
  _DEL_MESSAGE = "Unclosed Redis client"
603
651
 
@@ -86,10 +86,11 @@ from redis.utils import (
86
86
  )
87
87
 
88
88
  if SSL_AVAILABLE:
89
- from ssl import TLSVersion, VerifyMode
89
+ from ssl import TLSVersion, VerifyFlags, VerifyMode
90
90
  else:
91
91
  TLSVersion = None
92
92
  VerifyMode = None
93
+ VerifyFlags = None
93
94
 
94
95
  TargetNodesT = TypeVar(
95
96
  "TargetNodesT", str, "ClusterNode", List["ClusterNode"], Dict[Any, "ClusterNode"]
@@ -299,6 +300,8 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand
299
300
  ssl_ca_certs: Optional[str] = None,
300
301
  ssl_ca_data: Optional[str] = None,
301
302
  ssl_cert_reqs: Union[str, VerifyMode] = "required",
303
+ ssl_include_verify_flags: Optional[List[VerifyFlags]] = None,
304
+ ssl_exclude_verify_flags: Optional[List[VerifyFlags]] = None,
302
305
  ssl_certfile: Optional[str] = None,
303
306
  ssl_check_hostname: bool = True,
304
307
  ssl_keyfile: Optional[str] = None,
@@ -358,6 +361,8 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand
358
361
  "ssl_ca_certs": ssl_ca_certs,
359
362
  "ssl_ca_data": ssl_ca_data,
360
363
  "ssl_cert_reqs": ssl_cert_reqs,
364
+ "ssl_include_verify_flags": ssl_include_verify_flags,
365
+ "ssl_exclude_verify_flags": ssl_exclude_verify_flags,
361
366
  "ssl_certfile": ssl_certfile,
362
367
  "ssl_check_hostname": ssl_check_hostname,
363
368
  "ssl_keyfile": ssl_keyfile,
@@ -431,6 +436,12 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand
431
436
  self._initialize = True
432
437
  self._lock: Optional[asyncio.Lock] = None
433
438
 
439
+ # When used as an async context manager, we need to increment and decrement
440
+ # a usage counter so that we can close the connection pool when no one is
441
+ # using the client.
442
+ self._usage_counter = 0
443
+ self._usage_lock = asyncio.Lock()
444
+
434
445
  async def initialize(self) -> "RedisCluster":
435
446
  """Get all nodes from startup nodes & creates connections if not initialized."""
436
447
  if self._initialize:
@@ -467,10 +478,47 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand
467
478
  await self.aclose()
468
479
 
469
480
  async def __aenter__(self) -> "RedisCluster":
470
- return await self.initialize()
481
+ """
482
+ Async context manager entry. Increments a usage counter so that the
483
+ connection pool is only closed (via aclose()) when no context is using
484
+ the client.
485
+ """
486
+ await self._increment_usage()
487
+ try:
488
+ # Initialize the client (i.e. establish connection, etc.)
489
+ return await self.initialize()
490
+ except Exception:
491
+ # If initialization fails, decrement the counter to keep it in sync
492
+ await self._decrement_usage()
493
+ raise
471
494
 
472
- async def __aexit__(self, exc_type: None, exc_value: None, traceback: None) -> None:
473
- await self.aclose()
495
+ async def _increment_usage(self) -> int:
496
+ """
497
+ Helper coroutine to increment the usage counter while holding the lock.
498
+ Returns the new value of the usage counter.
499
+ """
500
+ async with self._usage_lock:
501
+ self._usage_counter += 1
502
+ return self._usage_counter
503
+
504
+ async def _decrement_usage(self) -> int:
505
+ """
506
+ Helper coroutine to decrement the usage counter while holding the lock.
507
+ Returns the new value of the usage counter.
508
+ """
509
+ async with self._usage_lock:
510
+ self._usage_counter -= 1
511
+ return self._usage_counter
512
+
513
+ async def __aexit__(self, exc_type, exc_value, traceback):
514
+ """
515
+ Async context manager exit. Decrements a usage counter. If this is the
516
+ last exit (counter becomes zero), the client closes its connection pool.
517
+ """
518
+ current_usage = await asyncio.shield(self._decrement_usage())
519
+ if current_usage == 0:
520
+ # This was the last active context, so disconnect the pool.
521
+ await asyncio.shield(self.aclose())
474
522
 
475
523
  def __await__(self) -> Generator[Any, None, "RedisCluster"]:
476
524
  return self.initialize().__await__()
@@ -30,11 +30,12 @@ from ..utils import SSL_AVAILABLE
30
30
 
31
31
  if SSL_AVAILABLE:
32
32
  import ssl
33
- from ssl import SSLContext, TLSVersion
33
+ from ssl import SSLContext, TLSVersion, VerifyFlags
34
34
  else:
35
35
  ssl = None
36
36
  TLSVersion = None
37
37
  SSLContext = None
38
+ VerifyFlags = None
38
39
 
39
40
  from ..auth.token import TokenInterface
40
41
  from ..event import AsyncAfterConnectionReleasedEvent, EventDispatcher
@@ -793,6 +794,8 @@ class SSLConnection(Connection):
793
794
  ssl_keyfile: Optional[str] = None,
794
795
  ssl_certfile: Optional[str] = None,
795
796
  ssl_cert_reqs: Union[str, ssl.VerifyMode] = "required",
797
+ ssl_include_verify_flags: Optional[List["ssl.VerifyFlags"]] = None,
798
+ ssl_exclude_verify_flags: Optional[List["ssl.VerifyFlags"]] = None,
796
799
  ssl_ca_certs: Optional[str] = None,
797
800
  ssl_ca_data: Optional[str] = None,
798
801
  ssl_check_hostname: bool = True,
@@ -807,6 +810,8 @@ class SSLConnection(Connection):
807
810
  keyfile=ssl_keyfile,
808
811
  certfile=ssl_certfile,
809
812
  cert_reqs=ssl_cert_reqs,
813
+ include_verify_flags=ssl_include_verify_flags,
814
+ exclude_verify_flags=ssl_exclude_verify_flags,
810
815
  ca_certs=ssl_ca_certs,
811
816
  ca_data=ssl_ca_data,
812
817
  check_hostname=ssl_check_hostname,
@@ -832,6 +837,14 @@ class SSLConnection(Connection):
832
837
  def cert_reqs(self):
833
838
  return self.ssl_context.cert_reqs
834
839
 
840
@property
def include_verify_flags(self):
    """Return the ``include_verify_flags`` value stored on ``self.ssl_context``."""
    context = self.ssl_context
    return context.include_verify_flags
843
+
844
@property
def exclude_verify_flags(self):
    """Return the ``exclude_verify_flags`` value stored on ``self.ssl_context``."""
    context = self.ssl_context
    return context.exclude_verify_flags
847
+
835
848
  @property
836
849
  def ca_certs(self):
837
850
  return self.ssl_context.ca_certs
@@ -854,6 +867,8 @@ class RedisSSLContext:
854
867
  "keyfile",
855
868
  "certfile",
856
869
  "cert_reqs",
870
+ "include_verify_flags",
871
+ "exclude_verify_flags",
857
872
  "ca_certs",
858
873
  "ca_data",
859
874
  "context",
@@ -867,6 +882,8 @@ class RedisSSLContext:
867
882
  keyfile: Optional[str] = None,
868
883
  certfile: Optional[str] = None,
869
884
  cert_reqs: Optional[Union[str, ssl.VerifyMode]] = None,
885
+ include_verify_flags: Optional[List["ssl.VerifyFlags"]] = None,
886
+ exclude_verify_flags: Optional[List["ssl.VerifyFlags"]] = None,
870
887
  ca_certs: Optional[str] = None,
871
888
  ca_data: Optional[str] = None,
872
889
  check_hostname: bool = False,
@@ -892,6 +909,8 @@ class RedisSSLContext:
892
909
  )
893
910
  cert_reqs = CERT_REQS[cert_reqs]
894
911
  self.cert_reqs = cert_reqs
912
+ self.include_verify_flags = include_verify_flags
913
+ self.exclude_verify_flags = exclude_verify_flags
895
914
  self.ca_certs = ca_certs
896
915
  self.ca_data = ca_data
897
916
  self.check_hostname = (
@@ -906,6 +925,12 @@ class RedisSSLContext:
906
925
  context = ssl.create_default_context()
907
926
  context.check_hostname = self.check_hostname
908
927
  context.verify_mode = self.cert_reqs
928
+ if self.include_verify_flags:
929
+ for flag in self.include_verify_flags:
930
+ context.verify_flags |= flag
931
+ if self.exclude_verify_flags:
932
+ for flag in self.exclude_verify_flags:
933
+ context.verify_flags &= ~flag
909
934
  if self.certfile and self.keyfile:
910
935
  context.load_cert_chain(certfile=self.certfile, keyfile=self.keyfile)
911
936
  if self.ca_certs or self.ca_data:
@@ -953,6 +978,20 @@ def to_bool(value) -> Optional[bool]:
953
978
  return bool(value)
954
979
 
955
980
 
981
def parse_ssl_verify_flags(value):
    """
    Parse a ``ssl_include_verify_flags`` / ``ssl_exclude_verify_flags`` URL
    query argument into a list of ``ssl.VerifyFlags`` members.

    The value is a string representation of a list of flag names, e.g.
    ``"[VERIFY_X509_STRICT, VERIFY_X509_PARTIAL_CHAIN]"`` or just
    ``"VERIFY_X509_STRICT"``. Square brackets, surrounding quotes (as produced
    by ``str(["VERIFY_X509_STRICT"])``) and whitespace are ignored, as are
    empty entries (an empty value or a trailing comma).

    Raises:
        ValueError: if the ``ssl`` module is unavailable, or if a name is not
            a member of ``ssl.VerifyFlags``.
    """
    if VerifyFlags is None:
        raise ValueError(
            "ssl verify flags were given but the ssl module is not available"
        )

    verify_flags_str = value.replace("[", "").replace("]", "")

    verify_flags = []
    for raw in verify_flags_str.split(","):
        flag = raw.strip().strip("'\"").strip()
        if not flag:
            continue
        try:
            # Enum lookup by member name. Unlike hasattr()/getattr(), this
            # rejects non-member class attributes such as "mro" or "__doc__",
            # which would otherwise fail later when OR-ed into verify_flags.
            verify_flags.append(VerifyFlags[flag])
        except KeyError:
            raise ValueError(f"Invalid ssl verify flag: {flag}") from None
    return verify_flags
993
+
994
+
956
995
  URL_QUERY_ARGUMENT_PARSERS: Mapping[str, Callable[..., object]] = MappingProxyType(
957
996
  {
958
997
  "db": int,
@@ -963,6 +1002,8 @@ URL_QUERY_ARGUMENT_PARSERS: Mapping[str, Callable[..., object]] = MappingProxyTy
963
1002
  "max_connections": int,
964
1003
  "health_check_interval": int,
965
1004
  "ssl_check_hostname": to_bool,
1005
+ "ssl_include_verify_flags": parse_ssl_verify_flags,
1006
+ "ssl_exclude_verify_flags": parse_ssl_verify_flags,
966
1007
  "timeout": float,
967
1008
  }
968
1009
  )
@@ -1021,6 +1062,7 @@ def parse_url(url: str) -> ConnectKwargs:
1021
1062
 
1022
1063
  if parsed.scheme == "rediss":
1023
1064
  kwargs["connection_class"] = SSLConnection
1065
+
1024
1066
  else:
1025
1067
  valid_schemes = "redis://, rediss://, unix://"
1026
1068
  raise ValueError(
@@ -50,6 +50,7 @@ class EvictionPolicyInterface(ABC):
50
50
  pass
51
51
 
52
52
  @cache.setter
53
+ @abstractmethod
53
54
  def cache(self, value):
54
55
  pass
55
56