redis 6.3.0__py3-none-any.whl → 7.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. redis/__init__.py +1 -2
  2. redis/_parsers/base.py +193 -8
  3. redis/_parsers/helpers.py +64 -6
  4. redis/_parsers/hiredis.py +16 -10
  5. redis/_parsers/resp3.py +11 -5
  6. redis/asyncio/client.py +65 -8
  7. redis/asyncio/cluster.py +57 -14
  8. redis/asyncio/connection.py +62 -2
  9. redis/asyncio/http/__init__.py +0 -0
  10. redis/asyncio/http/http_client.py +265 -0
  11. redis/asyncio/multidb/__init__.py +0 -0
  12. redis/asyncio/multidb/client.py +530 -0
  13. redis/asyncio/multidb/command_executor.py +339 -0
  14. redis/asyncio/multidb/config.py +210 -0
  15. redis/asyncio/multidb/database.py +69 -0
  16. redis/asyncio/multidb/event.py +84 -0
  17. redis/asyncio/multidb/failover.py +125 -0
  18. redis/asyncio/multidb/failure_detector.py +38 -0
  19. redis/asyncio/multidb/healthcheck.py +285 -0
  20. redis/background.py +204 -0
  21. redis/cache.py +1 -0
  22. redis/client.py +99 -22
  23. redis/cluster.py +14 -3
  24. redis/commands/core.py +348 -313
  25. redis/commands/helpers.py +0 -20
  26. redis/commands/json/_util.py +4 -2
  27. redis/commands/json/commands.py +2 -2
  28. redis/commands/search/__init__.py +2 -2
  29. redis/commands/search/aggregation.py +28 -30
  30. redis/commands/search/commands.py +13 -13
  31. redis/commands/search/field.py +2 -2
  32. redis/commands/search/query.py +23 -23
  33. redis/commands/vectorset/__init__.py +1 -1
  34. redis/commands/vectorset/commands.py +50 -25
  35. redis/commands/vectorset/utils.py +40 -4
  36. redis/connection.py +1258 -90
  37. redis/data_structure.py +81 -0
  38. redis/event.py +88 -14
  39. redis/exceptions.py +8 -0
  40. redis/http/__init__.py +0 -0
  41. redis/http/http_client.py +425 -0
  42. redis/maint_notifications.py +810 -0
  43. redis/multidb/__init__.py +0 -0
  44. redis/multidb/circuit.py +144 -0
  45. redis/multidb/client.py +526 -0
  46. redis/multidb/command_executor.py +350 -0
  47. redis/multidb/config.py +207 -0
  48. redis/multidb/database.py +130 -0
  49. redis/multidb/event.py +89 -0
  50. redis/multidb/exception.py +17 -0
  51. redis/multidb/failover.py +125 -0
  52. redis/multidb/failure_detector.py +104 -0
  53. redis/multidb/healthcheck.py +282 -0
  54. redis/retry.py +14 -1
  55. redis/utils.py +34 -0
  56. {redis-6.3.0.dist-info → redis-7.0.0.dist-info}/METADATA +7 -4
  57. redis-7.0.0.dist-info/RECORD +105 -0
  58. redis-6.3.0.dist-info/RECORD +0 -78
  59. {redis-6.3.0.dist-info → redis-7.0.0.dist-info}/WHEEL +0 -0
  60. {redis-6.3.0.dist-info → redis-7.0.0.dist-info}/licenses/LICENSE +0 -0
redis/__init__.py CHANGED
@@ -46,8 +46,7 @@ def int_or_str(value):
46
46
  return value
47
47
 
48
48
 
49
# Package version string (bumped to 7.0.0 in this release).
__version__ = "7.0.0"

# Tuple form of __version__: each dot-separated component is passed
# through int_or_str.
VERSION = tuple(map(int_or_str, __version__.split(".")))
52
51
 
53
52
 
redis/_parsers/base.py CHANGED
@@ -1,7 +1,17 @@
1
+ import logging
1
2
  import sys
2
3
  from abc import ABC
3
4
  from asyncio import IncompleteReadError, StreamReader, TimeoutError
4
- from typing import Callable, List, Optional, Protocol, Union
5
+ from typing import Awaitable, Callable, List, Optional, Protocol, Union
6
+
7
+ from redis.maint_notifications import (
8
+ MaintenanceNotification,
9
+ NodeFailedOverNotification,
10
+ NodeFailingOverNotification,
11
+ NodeMigratedNotification,
12
+ NodeMigratingNotification,
13
+ NodeMovingNotification,
14
+ )
5
15
 
6
16
  if sys.version_info.major >= 3 and sys.version_info.minor >= 11:
7
17
  from asyncio import timeout as async_timeout
@@ -17,6 +27,7 @@ from ..exceptions import (
17
27
  ClusterDownError,
18
28
  ConnectionError,
19
29
  ExecAbortError,
30
+ ExternalAuthProviderError,
20
31
  MasterDownError,
21
32
  ModuleError,
22
33
  MovedError,
@@ -50,6 +61,12 @@ NO_AUTH_SET_ERROR = {
50
61
  "Client sent AUTH, but no password is set": AuthenticationError,
51
62
  }
52
63
 
64
# Server error-message text mapped to ExternalAuthProviderError; this dict is
# spread into the nested error-message mapping of BaseParser.EXCEPTION_CLASSES.
EXTERNAL_AUTH_PROVIDER_ERROR = {"problem with LDAP service": ExternalAuthProviderError}

# Module logger, used by the push-notification handlers to report failures.
logger = logging.getLogger(__name__)
69
+
53
70
 
54
71
  class BaseParser(ABC):
55
72
  EXCEPTION_CLASSES = {
@@ -69,6 +86,7 @@ class BaseParser(ABC):
69
86
  NO_SUCH_MODULE_ERROR: ModuleError,
70
87
  MODULE_UNLOAD_NOT_POSSIBLE_ERROR: ModuleError,
71
88
  **NO_AUTH_SET_ERROR,
89
+ **EXTERNAL_AUTH_PROVIDER_ERROR,
72
90
  },
73
91
  "OOM": OutOfMemoryError,
74
92
  "WRONGPASS": AuthenticationError,
@@ -158,7 +176,77 @@ class AsyncBaseParser(BaseParser):
158
176
  raise NotImplementedError()
159
177
 
160
178
 
161
- _INVALIDATION_MESSAGE = [b"invalidate", "invalidate"]
179
class MaintenanceNotificationsParser:
    """Static helpers that build notification objects from maintenance pushes.

    Every helper receives the raw push ``response`` (a list whose first element
    is the message type) and returns the corresponding notification instance.
    """

    @staticmethod
    def parse_maintenance_start_msg(response, notification_type):
        """Parse ``<notification_type> <seq_number> <time>`` messages.

        Returns ``notification_type(seq_number, time)``.
        """
        notification_id = response[1]
        ttl = response[2]
        return notification_type(notification_id, ttl)

    @staticmethod
    def parse_maintenance_completed_msg(response, notification_type):
        """Parse ``<notification_type> <seq_number>`` messages.

        Returns ``notification_type(seq_number)``.
        """
        notification_id = response[1]
        return notification_type(notification_id)

    @staticmethod
    def parse_moving_msg(response):
        """Parse ``MOVING <seq_number> <time> <endpoint>`` messages.

        ``endpoint`` is either ``None`` or a ``host:port`` value (str or bytes).
        The port is split off at the *last* colon so hosts that contain colons
        themselves (IPv6 literals) parse correctly; ``[...]`` brackets around
        the host are removed.

        Returns ``NodeMovingNotification(seq_number, host, port, time)`` where
        ``host``/``port`` are ``None`` when no endpoint was sent.

        Raises:
            ValueError: if the endpoint has no ``:`` or a non-integer port.
        """
        notification_id = response[1]
        ttl = response[2]
        endpoint = response[3]
        if endpoint is None:
            host, port = None, None
        else:
            if isinstance(endpoint, bytes):
                endpoint = endpoint.decode()
            # rsplit: "2001:db8::1:6379" must give host "2001:db8::1";
            # a plain split(":") raised "too many values to unpack".
            host, port = endpoint.rsplit(":", 1)
            if host.startswith("[") and host.endswith("]"):
                host = host[1:-1]
            port = int(port)
        return NodeMovingNotification(notification_id, host, port, ttl)
210
+
211
+
212
# Push message type names; handlers compare against these after decoding
# a bytes message type to str.
_INVALIDATION_MESSAGE = "invalidate"
_MOVING_MESSAGE = "MOVING"
_MIGRATING_MESSAGE = "MIGRATING"
_MIGRATED_MESSAGE = "MIGRATED"
_FAILING_OVER_MESSAGE = "FAILING_OVER"
_FAILED_OVER_MESSAGE = "FAILED_OVER"

# Types dispatched to maintenance_push_handler_func. MOVING is intentionally
# absent: it is routed to node_moving_push_handler_func instead.
_MAINTENANCE_MESSAGES = (
    _MIGRATING_MESSAGE,
    _MIGRATED_MESSAGE,
    _FAILING_OVER_MESSAGE,
    _FAILED_OVER_MESSAGE,
)

# message type -> (notification class, parser function).
# "Start" messages (MIGRATING / FAILING_OVER) carry a sequence number and a
# time; "completed" messages (MIGRATED / FAILED_OVER) carry only the sequence
# number; MOVING uses its own parser, which ignores the class element.
MSG_TYPE_TO_MAINT_NOTIFICATION_PARSER_MAPPING: dict[
    str, tuple[type[MaintenanceNotification], Callable]
] = {
    _MIGRATING_MESSAGE: (
        NodeMigratingNotification,
        MaintenanceNotificationsParser.parse_maintenance_start_msg,
    ),
    _MIGRATED_MESSAGE: (
        NodeMigratedNotification,
        MaintenanceNotificationsParser.parse_maintenance_completed_msg,
    ),
    _FAILING_OVER_MESSAGE: (
        NodeFailingOverNotification,
        MaintenanceNotificationsParser.parse_maintenance_start_msg,
    ),
    _FAILED_OVER_MESSAGE: (
        NodeFailedOverNotification,
        MaintenanceNotificationsParser.parse_maintenance_completed_msg,
    ),
    _MOVING_MESSAGE: (
        NodeMovingNotification,
        MaintenanceNotificationsParser.parse_moving_msg,
    ),
}
162
250
 
163
251
 
164
252
  class PushNotificationsParser(Protocol):
@@ -166,16 +254,57 @@ class PushNotificationsParser(Protocol):
166
254
 
167
255
  pubsub_push_handler_func: Callable
168
256
  invalidation_push_handler_func: Optional[Callable] = None
257
+ node_moving_push_handler_func: Optional[Callable] = None
258
+ maintenance_push_handler_func: Optional[Callable] = None
169
259
 
def handle_pubsub_push_response(self, response):
    """Handle pubsub push responses.

    Placeholder on the protocol: it always raises ``NotImplementedError``;
    concrete parsers supply their own implementation.
    """
    raise NotImplementedError()
173
263
 
def handle_push_response(self, response, **kwargs):
    """Route one RESP3 push message to the handler registered for its type.

    ``response[0]`` is the message type; bytes are decoded to str first.

    * Types other than invalidation, MOVING and the maintenance messages go to
      ``pubsub_push_handler_func`` and its result is returned.
    * ``invalidate`` messages go, unparsed, to
      ``invalidation_push_handler_func``.
    * MOVING and maintenance messages are first parsed into notification
      objects via ``MSG_TYPE_TO_MAINT_NOTIFICATION_PARSER_MAPPING`` and then
      passed to ``node_moving_push_handler_func`` /
      ``maintenance_push_handler_func``.

    Returns ``None`` when no handler is registered for a special type. Any
    exception raised while parsing or handling a special type is logged and
    swallowed, also returning ``None``.
    """
    msg_type = response[0]
    if isinstance(msg_type, bytes):
        msg_type = msg_type.decode()

    if msg_type not in (
        _INVALIDATION_MESSAGE,
        *_MAINTENANCE_MESSAGES,
        _MOVING_MESSAGE,
    ):
        return self.pubsub_push_handler_func(response)

    try:
        if msg_type == _INVALIDATION_MESSAGE:
            if self.invalidation_push_handler_func:
                return self.invalidation_push_handler_func(response)
        elif msg_type == _MOVING_MESSAGE:
            if self.node_moving_push_handler_func:
                _, parse = MSG_TYPE_TO_MAINT_NOTIFICATION_PARSER_MAPPING[msg_type]
                return self.node_moving_push_handler_func(parse(response))
        elif self.maintenance_push_handler_func:
            # Only _MAINTENANCE_MESSAGES can reach this branch.
            notification_type, parse = MSG_TYPE_TO_MAINT_NOTIFICATION_PARSER_MAPPING[
                msg_type
            ]
            notification = parse(response, notification_type)
            if notification is not None:
                return self.maintenance_push_handler_func(notification)
    except Exception as e:
        # Lazy %-style args: the message is only formatted if it is emitted.
        logger.error("Error handling %s message (%s): %s", msg_type, response, e)

    return None
179
308
 
def set_pubsub_push_handler(self, pubsub_push_handler_func):
    """Register the callable that receives push messages of any type not
    handled by a dedicated (invalidation / MOVING / maintenance) handler."""
    self.pubsub_push_handler_func = pubsub_push_handler_func
@@ -183,12 +312,20 @@ class PushNotificationsParser(Protocol):
def set_invalidation_push_handler(self, invalidation_push_handler_func):
    """Register the callable that receives raw ``invalidate`` push messages."""
    self.invalidation_push_handler_func = invalidation_push_handler_func


def set_node_moving_push_handler(self, node_moving_push_handler_func):
    """Register the callable that receives parsed MOVING notifications."""
    self.node_moving_push_handler_func = node_moving_push_handler_func


def set_maintenance_push_handler(self, maintenance_push_handler_func):
    """Register the callable that receives parsed MIGRATING, MIGRATED,
    FAILING_OVER and FAILED_OVER notifications."""
    self.maintenance_push_handler_func = maintenance_push_handler_func
320
+
186
321
 
187
322
  class AsyncPushNotificationsParser(Protocol):
188
323
  """Protocol defining async RESP3-specific parsing functionality"""
189
324
 
190
325
  pubsub_push_handler_func: Callable
191
326
  invalidation_push_handler_func: Optional[Callable] = None
327
+ node_moving_push_handler_func: Optional[Callable[..., Awaitable[None]]] = None
328
+ maintenance_push_handler_func: Optional[Callable[..., Awaitable[None]]] = None
192
329
 
193
330
  async def handle_pubsub_push_response(self, response):
194
331
  """Handle pubsub push responses asynchronously"""
@@ -196,10 +333,52 @@ class AsyncPushNotificationsParser(Protocol):
196
333
 
async def handle_push_response(self, response, **kwargs):
    """Handle push responses asynchronously.

    Routes one RESP3 push message by its type (``response[0]``, decoded from
    bytes to str once, up front) and awaits the matching handler:

    * unknown types -> ``pubsub_push_handler_func(response)``;
    * ``invalidate`` -> ``invalidation_push_handler_func(response)``;
    * MOVING / maintenance types -> parsed into a notification via
      ``MSG_TYPE_TO_MAINT_NOTIFICATION_PARSER_MAPPING``, then passed to
      ``node_moving_push_handler_func`` / ``maintenance_push_handler_func``.

    Returns ``None`` when no handler is registered for a special type.
    ``Exception`` subclasses raised while parsing/handling a special type are
    logged and swallowed; cancellation (a ``BaseException``) still propagates.
    """
    msg_type = response[0]
    if isinstance(msg_type, bytes):
        msg_type = msg_type.decode()

    if msg_type not in (
        _INVALIDATION_MESSAGE,
        *_MAINTENANCE_MESSAGES,
        _MOVING_MESSAGE,
    ):
        return await self.pubsub_push_handler_func(response)

    # msg_type is guaranteed to be str from here on; no second decode needed.
    try:
        if msg_type == _INVALIDATION_MESSAGE:
            if self.invalidation_push_handler_func:
                return await self.invalidation_push_handler_func(response)
        elif msg_type == _MOVING_MESSAGE:
            if self.node_moving_push_handler_func:
                _, parse = MSG_TYPE_TO_MAINT_NOTIFICATION_PARSER_MAPPING[msg_type]
                return await self.node_moving_push_handler_func(parse(response))
        elif self.maintenance_push_handler_func:
            # Only _MAINTENANCE_MESSAGES can reach this branch.
            notification_type, parse = MSG_TYPE_TO_MAINT_NOTIFICATION_PARSER_MAPPING[
                msg_type
            ]
            notification = parse(response, notification_type)
            if notification is not None:
                return await self.maintenance_push_handler_func(notification)
    except Exception as e:
        # Lazy %-style args: the message is only formatted if it is emitted.
        logger.error("Error handling %s message (%s): %s", msg_type, response, e)

    return None
203
382
 
204
383
  def set_pubsub_push_handler(self, pubsub_push_handler_func):
205
384
  """Set the pubsub push handler function"""
@@ -209,6 +388,12 @@ class AsyncPushNotificationsParser(Protocol):
209
388
  """Set the invalidation push handler function"""
210
389
  self.invalidation_push_handler_func = invalidation_push_handler_func
211
390
 
391
def set_node_moving_push_handler(self, node_moving_push_handler_func):
    """Register the coroutine function awaited with parsed MOVING notifications."""
    self.node_moving_push_handler_func = node_moving_push_handler_func


def set_maintenance_push_handler(self, maintenance_push_handler_func):
    """Register the coroutine function awaited with parsed MIGRATING,
    MIGRATED, FAILING_OVER and FAILED_OVER notifications."""
    self.maintenance_push_handler_func = maintenance_push_handler_func
396
+
212
397
 
213
398
  class _AsyncRESPBase(AsyncBaseParser):
214
399
  """Base class for async resp parsing"""
redis/_parsers/helpers.py CHANGED
@@ -224,6 +224,39 @@ def zset_score_pairs(response, **options):
224
224
  return list(zip(it, map(score_cast_func, it)))
225
225
 
226
226
 
227
def zset_score_for_rank(response, **options):
    """
    Post-process a ZRANK / ZREVRANK reply.

    If the ``withscore`` option is truthy and the reply is non-empty, the reply
    is a ``[rank, score]`` pair and is returned as
    ``[rank, score_cast_func(score)]`` (``score_cast_func`` defaults to
    ``float``). Otherwise the reply is returned unchanged.
    """
    if not response or not options.get("withscore"):
        return response
    score_cast_func = options.get("score_cast_func", float)
    return [response[0], score_cast_func(response[1])]
236
+
237
+
238
def zset_score_pairs_resp3(response, **options):
    """
    Cast the score of every ``[member, score]`` pair in a RESP3 reply.

    Casting happens only when the reply is non-empty and ``withscores`` was
    requested; ``score_cast_func`` defaults to ``float``. In every other case
    the reply is handed back untouched.
    """
    if response and options.get("withscores"):
        cast = options.get("score_cast_func", float)
        return [[member, cast(score)] for member, score in response]
    return response
247
+
248
+
249
def zset_score_for_rank_resp3(response, **options):
    """
    Post-process a RESP3 ZRANK / ZREVRANK reply.

    If the ``withscore`` option is truthy and the reply is non-empty, the reply
    is a ``[rank, score]`` pair and is returned as
    ``[rank, score_cast_func(score)]`` (``score_cast_func`` defaults to
    ``float``). Otherwise the reply is returned unchanged.
    """
    if not response or not options.get("withscore"):
        return response
    score_cast_func = options.get("score_cast_func", float)
    return [response[0], score_cast_func(response[1])]
258
+
259
+
227
260
  def sort_return_tuples(response, **options):
228
261
  """
229
262
  If ``groups`` is specified, return the response as a list of
@@ -349,8 +382,22 @@ def parse_zadd(response, **options):
def parse_client_list(response, **options):
    """Parse ``CLIENT LIST`` output into a list of dicts, one per client line.

    Each line is a space-separated sequence of ``key=value`` fields:

    * values may contain ``=``; only the first ``=`` in a field splits it;
    * values may contain spaces (e.g. a Unix socket path such as
      ``/tmp/redis sock/redis.sock`` in ``addr``/``laddr``), so a token without
      ``=`` is appended, space-joined, to the preceding field's value;
    * tokens appearing before any ``key=value`` field have nothing to attach
      to and are ignored (previously ``KeyError(None)``), so blank lines are
      skipped instead of crashing.
    """
    clients = []
    for line in str_if_bytes(response).splitlines():
        client = {}
        last_key = None
        for token in line.split(" "):
            if "=" in token:
                # Values might contain '='; split on the first one only.
                key, value = token.split("=", 1)
                client[key] = value
                last_key = key
            elif last_key is not None:
                # Continuation of a value that itself contains spaces.
                client[last_key] += " " + token
        if client:
            clients.append(client)
    return clients
355
402
 
356
403
 
@@ -797,10 +844,14 @@ _RedisCallbacksRESP2 = {
797
844
  "SDIFF SINTER SMEMBERS SUNION", lambda r: r and set(r) or set()
798
845
  ),
799
846
  **string_keys_to_dict(
800
- "ZDIFF ZINTER ZPOPMAX ZPOPMIN ZRANGE ZRANGEBYSCORE ZRANK ZREVRANGE "
801
- "ZREVRANGEBYSCORE ZREVRANK ZUNION",
847
+ "ZDIFF ZINTER ZPOPMAX ZPOPMIN ZRANGE ZRANGEBYSCORE ZREVRANGE "
848
+ "ZREVRANGEBYSCORE ZUNION",
802
849
  zset_score_pairs,
803
850
  ),
851
+ **string_keys_to_dict(
852
+ "ZREVRANK ZRANK",
853
+ zset_score_for_rank,
854
+ ),
804
855
  **string_keys_to_dict("ZINCRBY ZSCORE", float_or_none),
805
856
  **string_keys_to_dict("BGREWRITEAOF BGSAVE", lambda r: True),
806
857
  **string_keys_to_dict("BLPOP BRPOP", lambda r: r and tuple(r) or None),
@@ -844,10 +895,17 @@ _RedisCallbacksRESP3 = {
844
895
  "SDIFF SINTER SMEMBERS SUNION", lambda r: r and set(r) or set()
845
896
  ),
846
897
  **string_keys_to_dict(
847
- "ZRANGE ZINTER ZPOPMAX ZPOPMIN ZRANGEBYSCORE ZREVRANGE ZREVRANGEBYSCORE "
848
- "ZUNION HGETALL XREADGROUP",
898
+ "ZRANGE ZINTER ZPOPMAX ZPOPMIN HGETALL XREADGROUP",
849
899
  lambda r, **kwargs: r,
850
900
  ),
901
+ **string_keys_to_dict(
902
+ "ZRANGE ZRANGEBYSCORE ZREVRANGE ZREVRANGEBYSCORE ZUNION",
903
+ zset_score_pairs_resp3,
904
+ ),
905
+ **string_keys_to_dict(
906
+ "ZREVRANK ZRANK",
907
+ zset_score_for_rank_resp3,
908
+ ),
851
909
  **string_keys_to_dict("XREAD XREADGROUP", parse_xread_resp3),
852
910
  "ACL LOG": lambda r: (
853
911
  [
redis/_parsers/hiredis.py CHANGED
@@ -47,6 +47,8 @@ class _HiredisParser(BaseParser, PushNotificationsParser):
47
47
  self.socket_read_size = socket_read_size
48
48
  self._buffer = bytearray(socket_read_size)
49
49
  self.pubsub_push_handler_func = self.handle_pubsub_push_response
50
+ self.node_moving_push_handler_func = None
51
+ self.maintenance_push_handler_func = None
50
52
  self.invalidation_push_handler_func = None
51
53
  self._hiredis_PushNotificationType = None
52
54
 
@@ -141,12 +143,15 @@ class _HiredisParser(BaseParser, PushNotificationsParser):
141
143
  response, self._hiredis_PushNotificationType
142
144
  ):
143
145
  response = self.handle_push_response(response)
144
- if not push_request:
145
- return self.read_response(
146
- disable_decoding=disable_decoding, push_request=push_request
147
- )
148
- else:
146
+
147
+ # if this is a push request return the push response
148
+ if push_request:
149
149
  return response
150
+
151
+ return self.read_response(
152
+ disable_decoding=disable_decoding,
153
+ push_request=push_request,
154
+ )
150
155
  return response
151
156
 
152
157
  if disable_decoding:
@@ -169,12 +174,13 @@ class _HiredisParser(BaseParser, PushNotificationsParser):
169
174
  response, self._hiredis_PushNotificationType
170
175
  ):
171
176
  response = self.handle_push_response(response)
172
- if not push_request:
173
- return self.read_response(
174
- disable_decoding=disable_decoding, push_request=push_request
175
- )
176
- else:
177
+ if push_request:
177
178
  return response
179
+ return self.read_response(
180
+ disable_decoding=disable_decoding,
181
+ push_request=push_request,
182
+ )
183
+
178
184
  elif (
179
185
  isinstance(response, list)
180
186
  and response
redis/_parsers/resp3.py CHANGED
@@ -18,6 +18,8 @@ class _RESP3Parser(_RESPBase, PushNotificationsParser):
18
18
  def __init__(self, socket_read_size):
19
19
  super().__init__(socket_read_size)
20
20
  self.pubsub_push_handler_func = self.handle_pubsub_push_response
21
+ self.node_moving_push_handler_func = None
22
+ self.maintenance_push_handler_func = None
21
23
  self.invalidation_push_handler_func = None
22
24
 
23
25
  def handle_pubsub_push_response(self, response):
@@ -117,17 +119,21 @@ class _RESP3Parser(_RESPBase, PushNotificationsParser):
117
119
  for _ in range(int(response))
118
120
  ]
119
121
  response = self.handle_push_response(response)
120
- if not push_request:
121
- return self._read_response(
122
- disable_decoding=disable_decoding, push_request=push_request
123
- )
124
- else:
122
+
123
+ # if this is a push request return the push response
124
+ if push_request:
125
125
  return response
126
+
127
+ return self._read_response(
128
+ disable_decoding=disable_decoding,
129
+ push_request=push_request,
130
+ )
126
131
  else:
127
132
  raise InvalidResponse(f"Protocol Error: {raw!r}")
128
133
 
129
134
  if isinstance(response, bytes) and disable_decoding is False:
130
135
  response = self.encoder.decode(response)
136
+
131
137
  return response
132
138
 
133
139
 
redis/asyncio/client.py CHANGED
@@ -81,10 +81,11 @@ from redis.utils import (
81
81
  )
82
82
 
83
83
  if TYPE_CHECKING and SSL_AVAILABLE:
84
- from ssl import TLSVersion, VerifyMode
84
+ from ssl import TLSVersion, VerifyFlags, VerifyMode
85
85
  else:
86
86
  TLSVersion = None
87
87
  VerifyMode = None
88
+ VerifyFlags = None
88
89
 
89
90
  PubSubHandler = Callable[[Dict[str, str]], Awaitable[None]]
90
91
  _KeyT = TypeVar("_KeyT", bound=KeyT)
@@ -238,6 +239,8 @@ class Redis(
238
239
  ssl_keyfile: Optional[str] = None,
239
240
  ssl_certfile: Optional[str] = None,
240
241
  ssl_cert_reqs: Union[str, VerifyMode] = "required",
242
+ ssl_include_verify_flags: Optional[List[VerifyFlags]] = None,
243
+ ssl_exclude_verify_flags: Optional[List[VerifyFlags]] = None,
241
244
  ssl_ca_certs: Optional[str] = None,
242
245
  ssl_ca_data: Optional[str] = None,
243
246
  ssl_check_hostname: bool = True,
@@ -347,6 +350,8 @@ class Redis(
347
350
  "ssl_keyfile": ssl_keyfile,
348
351
  "ssl_certfile": ssl_certfile,
349
352
  "ssl_cert_reqs": ssl_cert_reqs,
353
+ "ssl_include_verify_flags": ssl_include_verify_flags,
354
+ "ssl_exclude_verify_flags": ssl_exclude_verify_flags,
350
355
  "ssl_ca_certs": ssl_ca_certs,
351
356
  "ssl_ca_data": ssl_ca_data,
352
357
  "ssl_check_hostname": ssl_check_hostname,
@@ -387,6 +392,12 @@ class Redis(
387
392
  # on a set of redis commands
388
393
  self._single_conn_lock = asyncio.Lock()
389
394
 
395
+ # When used as an async context manager, we need to increment and decrement
396
+ # a usage counter so that we can close the connection pool when no one is
397
+ # using the client.
398
+ self._usage_counter = 0
399
+ self._usage_lock = asyncio.Lock()
400
+
390
401
  def __repr__(self):
391
402
  return (
392
403
  f"<{self.__class__.__module__}.{self.__class__.__name__}"
@@ -594,10 +605,47 @@ class Redis(
594
605
  )
595
606
 
async def __aenter__(self: _RedisT) -> _RedisT:
    """
    Async context manager entry. Increments a usage counter so that the
    connection pool is only closed (via aclose()) when no context is using
    the client.

    Raises whatever ``initialize()`` raises, after undoing the increment.
    """
    await self._increment_usage()
    try:
        # Initialize the client (i.e. establish connection, etc.)
        return await self.initialize()
    except BaseException:
        # Catch BaseException, not Exception: asyncio.CancelledError derives
        # from BaseException, and a cancelled initialize() would otherwise
        # leave the counter permanently incremented so __aexit__ of other
        # contexts could never bring it back to zero and close the pool.
        await asyncio.shield(self._decrement_usage())
        raise


async def _increment_usage(self) -> int:
    """
    Increment the usage counter while holding ``_usage_lock``.
    Returns the new value of the usage counter.
    """
    async with self._usage_lock:
        self._usage_counter += 1
        return self._usage_counter


async def _decrement_usage(self) -> int:
    """
    Decrement the usage counter while holding ``_usage_lock``.
    Returns the new value of the usage counter.
    """
    async with self._usage_lock:
        self._usage_counter -= 1
        return self._usage_counter


async def __aexit__(self, exc_type, exc_value, traceback):
    """
    Async context manager exit. Decrements the usage counter; when this was
    the last active context (counter reaches zero) the client is closed via
    ``aclose()``. Both steps are shielded so cancellation of the exiting task
    does not abort them midway.
    """
    remaining = await asyncio.shield(self._decrement_usage())
    if remaining == 0:
        # This was the last active context, so disconnect the pool.
        await asyncio.shield(self.aclose())
601
649
 
602
650
  _DEL_MESSAGE = "Unclosed Redis client"
603
651
 
@@ -1113,9 +1161,12 @@ class PubSub:
1113
1161
  return await self.handle_message(response, ignore_subscribe_messages)
1114
1162
  return None
1115
1163
 
1116
def ping(self, message=None) -> Awaitable[bool]:
    """
    Ping the Redis server to test connectivity.

    Issues ``PING`` (with ``message`` appended when it is not ``None``) via
    ``execute_command`` and returns the resulting awaitable. The resolved
    value depends on the configured response callbacks (annotated as bool).
    """
    command = ["PING"]
    if message is not None:
        command.append(message)
    return self.execute_command(*command)
@@ -1191,6 +1242,7 @@ class PubSub:
1191
1242
  *,
1192
1243
  exception_handler: Optional["PSWorkerThreadExcHandlerT"] = None,
1193
1244
  poll_timeout: float = 1.0,
1245
+ pubsub=None,
1194
1246
  ) -> None:
1195
1247
  """Process pub/sub messages using registered callbacks.
1196
1248
 
@@ -1215,9 +1267,14 @@ class PubSub:
1215
1267
  await self.connect()
1216
1268
  while True:
1217
1269
  try:
1218
- await self.get_message(
1219
- ignore_subscribe_messages=True, timeout=poll_timeout
1220
- )
1270
+ if pubsub is None:
1271
+ await self.get_message(
1272
+ ignore_subscribe_messages=True, timeout=poll_timeout
1273
+ )
1274
+ else:
1275
+ await pubsub.get_message(
1276
+ ignore_subscribe_messages=True, timeout=poll_timeout
1277
+ )
1221
1278
  except asyncio.CancelledError:
1222
1279
  raise
1223
1280
  except BaseException as e: