redis 7.0.0b2__py3-none-any.whl → 7.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. redis/__init__.py +1 -1
  2. redis/_parsers/base.py +6 -0
  3. redis/_parsers/helpers.py +64 -6
  4. redis/asyncio/client.py +14 -5
  5. redis/asyncio/cluster.py +5 -1
  6. redis/asyncio/connection.py +19 -1
  7. redis/asyncio/http/__init__.py +0 -0
  8. redis/asyncio/http/http_client.py +265 -0
  9. redis/asyncio/multidb/__init__.py +0 -0
  10. redis/asyncio/multidb/client.py +530 -0
  11. redis/asyncio/multidb/command_executor.py +339 -0
  12. redis/asyncio/multidb/config.py +210 -0
  13. redis/asyncio/multidb/database.py +69 -0
  14. redis/asyncio/multidb/event.py +84 -0
  15. redis/asyncio/multidb/failover.py +125 -0
  16. redis/asyncio/multidb/failure_detector.py +38 -0
  17. redis/asyncio/multidb/healthcheck.py +285 -0
  18. redis/background.py +204 -0
  19. redis/client.py +49 -27
  20. redis/cluster.py +9 -1
  21. redis/commands/core.py +64 -29
  22. redis/commands/json/commands.py +2 -2
  23. redis/commands/search/__init__.py +2 -2
  24. redis/commands/search/aggregation.py +24 -26
  25. redis/commands/search/commands.py +10 -10
  26. redis/commands/search/field.py +2 -2
  27. redis/commands/search/query.py +12 -12
  28. redis/connection.py +1613 -1263
  29. redis/data_structure.py +81 -0
  30. redis/event.py +84 -10
  31. redis/exceptions.py +8 -0
  32. redis/http/__init__.py +0 -0
  33. redis/http/http_client.py +425 -0
  34. redis/maint_notifications.py +18 -7
  35. redis/multidb/__init__.py +0 -0
  36. redis/multidb/circuit.py +144 -0
  37. redis/multidb/client.py +526 -0
  38. redis/multidb/command_executor.py +350 -0
  39. redis/multidb/config.py +207 -0
  40. redis/multidb/database.py +130 -0
  41. redis/multidb/event.py +89 -0
  42. redis/multidb/exception.py +17 -0
  43. redis/multidb/failover.py +125 -0
  44. redis/multidb/failure_detector.py +104 -0
  45. redis/multidb/healthcheck.py +282 -0
  46. redis/retry.py +14 -1
  47. redis/utils.py +34 -0
  48. {redis-7.0.0b2.dist-info → redis-7.0.1.dist-info}/METADATA +17 -4
  49. {redis-7.0.0b2.dist-info → redis-7.0.1.dist-info}/RECORD +51 -25
  50. {redis-7.0.0b2.dist-info → redis-7.0.1.dist-info}/WHEEL +0 -0
  51. {redis-7.0.0b2.dist-info → redis-7.0.1.dist-info}/licenses/LICENSE +0 -0
redis/client.py CHANGED
@@ -58,7 +58,6 @@ from redis.exceptions import (
58
58
  from redis.lock import Lock
59
59
  from redis.maint_notifications import (
60
60
  MaintNotificationsConfig,
61
- MaintNotificationsPoolHandler,
62
61
  )
63
62
  from redis.retry import Retry
64
63
  from redis.utils import (
@@ -278,6 +277,17 @@ class Redis(RedisModuleCommands, CoreCommands, SentinelCommands):
278
277
  single_connection_client:
279
278
  if `True`, connection pool is not used. In that case `Redis`
280
279
  instance use is not thread safe.
280
+ decode_responses:
281
+ if `True`, the response will be decoded to utf-8.
282
+ Argument is ignored when connection_pool is provided.
283
+ maint_notifications_config:
284
+ configuration the pool to support maintenance notifications - see
285
+ `redis.maint_notifications.MaintNotificationsConfig` for details.
286
+ Only supported with RESP3
287
+ If not provided and protocol is RESP3, the maintenance notifications
288
+ will be enabled by default (logic is included in the connection pool
289
+ initialization).
290
+ Argument is ignored when connection_pool is provided.
281
291
  """
282
292
  if event_dispatcher is None:
283
293
  self._event_dispatcher = EventDispatcher()
@@ -354,6 +364,22 @@ class Redis(RedisModuleCommands, CoreCommands, SentinelCommands):
354
364
  "cache_config": cache_config,
355
365
  }
356
366
  )
367
+ maint_notifications_enabled = (
368
+ maint_notifications_config and maint_notifications_config.enabled
369
+ )
370
+ if maint_notifications_enabled and protocol not in [
371
+ 3,
372
+ "3",
373
+ ]:
374
+ raise RedisError(
375
+ "Maintenance notifications handlers on connection are only supported with RESP version 3"
376
+ )
377
+ if maint_notifications_config:
378
+ kwargs.update(
379
+ {
380
+ "maint_notifications_config": maint_notifications_config,
381
+ }
382
+ )
357
383
  connection_pool = ConnectionPool(**kwargs)
358
384
  self._event_dispatcher.dispatch(
359
385
  AfterPooledConnectionsInstantiationEvent(
@@ -377,23 +403,6 @@ class Redis(RedisModuleCommands, CoreCommands, SentinelCommands):
377
403
  ]:
378
404
  raise RedisError("Client caching is only supported with RESP version 3")
379
405
 
380
- if maint_notifications_config and self.connection_pool.get_protocol() not in [
381
- 3,
382
- "3",
383
- ]:
384
- raise RedisError(
385
- "Push handlers on connection are only supported with RESP version 3"
386
- )
387
- if maint_notifications_config and maint_notifications_config.enabled:
388
- self.maint_notifications_pool_handler = MaintNotificationsPoolHandler(
389
- self.connection_pool, maint_notifications_config
390
- )
391
- self.connection_pool.set_maint_notifications_pool_handler(
392
- self.maint_notifications_pool_handler
393
- )
394
- else:
395
- self.maint_notifications_pool_handler = None
396
-
397
406
  self.single_connection_lock = threading.RLock()
398
407
  self.connection = None
399
408
  self._single_connection_client = single_connection_client
@@ -591,15 +600,9 @@ class Redis(RedisModuleCommands, CoreCommands, SentinelCommands):
591
600
  return Monitor(self.connection_pool)
592
601
 
593
602
  def client(self):
594
- maint_notifications_config = (
595
- None
596
- if self.maint_notifications_pool_handler is None
597
- else self.maint_notifications_pool_handler.config
598
- )
599
603
  return self.__class__(
600
604
  connection_pool=self.connection_pool,
601
605
  single_connection_client=True,
602
- maint_notifications_config=maint_notifications_config,
603
606
  )
604
607
 
605
608
  def __enter__(self):
@@ -1186,7 +1189,10 @@ class PubSub:
1186
1189
 
1187
1190
  def ping(self, message: Union[str, None] = None) -> bool:
1188
1191
  """
1189
- Ping the Redis server
1192
+ Ping the Redis server to test connectivity.
1193
+
1194
+ Sends a PING command to the Redis server and returns True if the server
1195
+ responds with "PONG".
1190
1196
  """
1191
1197
  args = ["PING", message] if message is not None else ["PING"]
1192
1198
  return self.execute_command(*args)
@@ -1271,6 +1277,8 @@ class PubSub:
1271
1277
  sleep_time: float = 0.0,
1272
1278
  daemon: bool = False,
1273
1279
  exception_handler: Optional[Callable] = None,
1280
+ pubsub=None,
1281
+ sharded_pubsub: bool = False,
1274
1282
  ) -> "PubSubWorkerThread":
1275
1283
  for channel, handler in self.channels.items():
1276
1284
  if handler is None:
@@ -1284,8 +1292,13 @@ class PubSub:
1284
1292
  f"Shard Channel: '{s_channel}' has no handler registered"
1285
1293
  )
1286
1294
 
1295
+ pubsub = self if pubsub is None else pubsub
1287
1296
  thread = PubSubWorkerThread(
1288
- self, sleep_time, daemon=daemon, exception_handler=exception_handler
1297
+ pubsub,
1298
+ sleep_time,
1299
+ daemon=daemon,
1300
+ exception_handler=exception_handler,
1301
+ sharded_pubsub=sharded_pubsub,
1289
1302
  )
1290
1303
  thread.start()
1291
1304
  return thread
@@ -1300,12 +1313,14 @@ class PubSubWorkerThread(threading.Thread):
1300
1313
  exception_handler: Union[
1301
1314
  Callable[[Exception, "PubSub", "PubSubWorkerThread"], None], None
1302
1315
  ] = None,
1316
+ sharded_pubsub: bool = False,
1303
1317
  ):
1304
1318
  super().__init__()
1305
1319
  self.daemon = daemon
1306
1320
  self.pubsub = pubsub
1307
1321
  self.sleep_time = sleep_time
1308
1322
  self.exception_handler = exception_handler
1323
+ self.sharded_pubsub = sharded_pubsub
1309
1324
  self._running = threading.Event()
1310
1325
 
1311
1326
  def run(self) -> None:
@@ -1316,7 +1331,14 @@ class PubSubWorkerThread(threading.Thread):
1316
1331
  sleep_time = self.sleep_time
1317
1332
  while self._running.is_set():
1318
1333
  try:
1319
- pubsub.get_message(ignore_subscribe_messages=True, timeout=sleep_time)
1334
+ if not self.sharded_pubsub:
1335
+ pubsub.get_message(
1336
+ ignore_subscribe_messages=True, timeout=sleep_time
1337
+ )
1338
+ else:
1339
+ pubsub.get_sharded_message(
1340
+ ignore_subscribe_messages=True, timeout=sleep_time
1341
+ )
1320
1342
  except BaseException as e:
1321
1343
  if self.exception_handler is None:
1322
1344
  raise
redis/cluster.py CHANGED
@@ -50,6 +50,7 @@ from redis.exceptions import (
50
50
  WatchError,
51
51
  )
52
52
  from redis.lock import Lock
53
+ from redis.maint_notifications import MaintNotificationsConfig
53
54
  from redis.retry import Retry
54
55
  from redis.utils import (
55
56
  deprecated_args,
@@ -695,6 +696,7 @@ class RedisCluster(AbstractRedisCluster, RedisClusterCommands):
695
696
  self._event_dispatcher = EventDispatcher()
696
697
  else:
697
698
  self._event_dispatcher = event_dispatcher
699
+ self.startup_nodes = startup_nodes
698
700
  self.nodes_manager = NodesManager(
699
701
  startup_nodes=startup_nodes,
700
702
  from_url=from_url,
@@ -1662,6 +1664,11 @@ class NodesManager:
1662
1664
  backoff=NoBackoff(), retries=0, supported_errors=(ConnectionError,)
1663
1665
  )
1664
1666
 
1667
+ protocol = kwargs.get("protocol", None)
1668
+ if protocol in [3, "3"]:
1669
+ kwargs.update(
1670
+ {"maint_notifications_config": MaintNotificationsConfig(enabled=False)}
1671
+ )
1665
1672
  if self.from_url:
1666
1673
  # Create a redis node with a costumed connection pool
1667
1674
  kwargs.update({"host": host})
@@ -3164,7 +3171,8 @@ class TransactionStrategy(AbstractStrategy):
3164
3171
  self._nodes_manager.initialize()
3165
3172
  self.reinitialize_counter = 0
3166
3173
  else:
3167
- self._nodes_manager.update_moved_exception(error)
3174
+ if isinstance(error, AskError):
3175
+ self._nodes_manager.update_moved_exception(error)
3168
3176
 
3169
3177
  self._executing = False
3170
3178
 
redis/commands/core.py CHANGED
@@ -830,7 +830,7 @@ class ManagementCommands(CommandsProtocol):
830
830
 
831
831
  return self.execute_command("COMMAND LIST", *pieces)
832
832
 
833
- def command_getkeysandflags(self, *args: List[str]) -> List[Union[str, List[str]]]:
833
+ def command_getkeysandflags(self, *args: str) -> List[Union[str, List[str]]]:
834
834
  """
835
835
  Returns array of keys from a full Redis command and their usage flags.
836
836
 
@@ -848,7 +848,7 @@ class ManagementCommands(CommandsProtocol):
848
848
  )
849
849
 
850
850
  def config_get(
851
- self, pattern: PatternT = "*", *args: List[PatternT], **kwargs
851
+ self, pattern: PatternT = "*", *args: PatternT, **kwargs
852
852
  ) -> ResponseT:
853
853
  """
854
854
  Return a dictionary of configuration based on the ``pattern``
@@ -861,7 +861,7 @@ class ManagementCommands(CommandsProtocol):
861
861
  self,
862
862
  name: KeyT,
863
863
  value: EncodableT,
864
- *args: List[Union[KeyT, EncodableT]],
864
+ *args: Union[KeyT, EncodableT],
865
865
  **kwargs,
866
866
  ) -> ResponseT:
867
867
  """Set config item ``name`` with ``value``
@@ -987,9 +987,7 @@ class ManagementCommands(CommandsProtocol):
987
987
  """
988
988
  return self.execute_command("SELECT", index, **kwargs)
989
989
 
990
- def info(
991
- self, section: Optional[str] = None, *args: List[str], **kwargs
992
- ) -> ResponseT:
990
+ def info(self, section: Optional[str] = None, *args: str, **kwargs) -> ResponseT:
993
991
  """
994
992
  Returns a dictionary containing information about the Redis server
995
993
 
@@ -1210,11 +1208,18 @@ class ManagementCommands(CommandsProtocol):
1210
1208
  """
1211
1209
  return self.execute_command("LATENCY RESET", *events)
1212
1210
 
1213
- def ping(self, **kwargs) -> ResponseT:
1211
+ def ping(self, **kwargs) -> Union[Awaitable[bool], bool]:
1214
1212
  """
1215
- Ping the Redis server
1213
+ Ping the Redis server to test connectivity.
1214
+
1215
+ Sends a PING command to the Redis server and returns True if the server
1216
+ responds with "PONG".
1216
1217
 
1217
- For more information, see https://redis.io/commands/ping
1218
+ This command is useful for:
1219
+ - Testing whether a connection is still alive
1220
+ - Verifying the server's ability to serve data
1221
+
1222
+ For more information on the underlying ping command see https://redis.io/commands/ping
1218
1223
  """
1219
1224
  return self.execute_command("PING", **kwargs)
1220
1225
 
@@ -2599,7 +2604,7 @@ class ListCommands(CommandsProtocol):
2599
2604
  self,
2600
2605
  timeout: float,
2601
2606
  numkeys: int,
2602
- *args: List[str],
2607
+ *args: str,
2603
2608
  direction: str,
2604
2609
  count: Optional[int] = 1,
2605
2610
  ) -> Optional[list]:
@@ -2612,14 +2617,14 @@ class ListCommands(CommandsProtocol):
2612
2617
 
2613
2618
  For more information, see https://redis.io/commands/blmpop
2614
2619
  """
2615
- args = [timeout, numkeys, *args, direction, "COUNT", count]
2620
+ cmd_args = [timeout, numkeys, *args, direction, "COUNT", count]
2616
2621
 
2617
- return self.execute_command("BLMPOP", *args)
2622
+ return self.execute_command("BLMPOP", *cmd_args)
2618
2623
 
2619
2624
  def lmpop(
2620
2625
  self,
2621
2626
  num_keys: int,
2622
- *args: List[str],
2627
+ *args: str,
2623
2628
  direction: str,
2624
2629
  count: Optional[int] = 1,
2625
2630
  ) -> Union[Awaitable[list], list]:
@@ -2629,11 +2634,11 @@ class ListCommands(CommandsProtocol):
2629
2634
 
2630
2635
  For more information, see https://redis.io/commands/lmpop
2631
2636
  """
2632
- args = [num_keys] + list(args) + [direction]
2637
+ cmd_args = [num_keys] + list(args) + [direction]
2633
2638
  if count != 1:
2634
- args.extend(["COUNT", count])
2639
+ cmd_args.extend(["COUNT", count])
2635
2640
 
2636
- return self.execute_command("LMPOP", *args)
2641
+ return self.execute_command("LMPOP", *cmd_args)
2637
2642
 
2638
2643
  def lindex(
2639
2644
  self, name: str, index: int
@@ -4771,6 +4776,7 @@ class SortedSetCommands(CommandsProtocol):
4771
4776
  name: KeyT,
4772
4777
  value: EncodableT,
4773
4778
  withscore: bool = False,
4779
+ score_cast_func: Union[type, Callable] = float,
4774
4780
  ) -> ResponseT:
4775
4781
  """
4776
4782
  Returns a 0-based value indicating the rank of ``value`` in sorted set
@@ -4778,11 +4784,17 @@ class SortedSetCommands(CommandsProtocol):
4778
4784
  The optional WITHSCORE argument supplements the command's
4779
4785
  reply with the score of the element returned.
4780
4786
 
4787
+ ``score_cast_func`` a callable used to cast the score return value
4788
+
4781
4789
  For more information, see https://redis.io/commands/zrank
4782
4790
  """
4791
+ pieces = ["ZRANK", name, value]
4783
4792
  if withscore:
4784
- return self.execute_command("ZRANK", name, value, "WITHSCORE", keys=[name])
4785
- return self.execute_command("ZRANK", name, value, keys=[name])
4793
+ pieces.append("WITHSCORE")
4794
+
4795
+ options = {"withscore": withscore, "score_cast_func": score_cast_func}
4796
+
4797
+ return self.execute_command(*pieces, **options)
4786
4798
 
4787
4799
  def zrem(self, name: KeyT, *values: FieldT) -> ResponseT:
4788
4800
  """
@@ -4830,6 +4842,7 @@ class SortedSetCommands(CommandsProtocol):
4830
4842
  name: KeyT,
4831
4843
  value: EncodableT,
4832
4844
  withscore: bool = False,
4845
+ score_cast_func: Union[type, Callable] = float,
4833
4846
  ) -> ResponseT:
4834
4847
  """
4835
4848
  Returns a 0-based value indicating the descending rank of
@@ -4837,13 +4850,17 @@ class SortedSetCommands(CommandsProtocol):
4837
4850
  The optional ``withscore`` argument supplements the command's
4838
4851
  reply with the score of the element returned.
4839
4852
 
4853
+ ``score_cast_func`` a callable used to cast the score return value
4854
+
4840
4855
  For more information, see https://redis.io/commands/zrevrank
4841
4856
  """
4857
+ pieces = ["ZREVRANK", name, value]
4842
4858
  if withscore:
4843
- return self.execute_command(
4844
- "ZREVRANK", name, value, "WITHSCORE", keys=[name]
4845
- )
4846
- return self.execute_command("ZREVRANK", name, value, keys=[name])
4859
+ pieces.append("WITHSCORE")
4860
+
4861
+ options = {"withscore": withscore, "score_cast_func": score_cast_func}
4862
+
4863
+ return self.execute_command(*pieces, **options)
4847
4864
 
4848
4865
  def zscore(self, name: KeyT, value: EncodableT) -> ResponseT:
4849
4866
  """
@@ -4858,6 +4875,7 @@ class SortedSetCommands(CommandsProtocol):
4858
4875
  keys: Union[Sequence[KeyT], Mapping[AnyKeyT, float]],
4859
4876
  aggregate: Optional[str] = None,
4860
4877
  withscores: bool = False,
4878
+ score_cast_func: Union[type, Callable] = float,
4861
4879
  ) -> ResponseT:
4862
4880
  """
4863
4881
  Return the union of multiple sorted sets specified by ``keys``.
@@ -4865,9 +4883,18 @@ class SortedSetCommands(CommandsProtocol):
4865
4883
  Scores will be aggregated based on the ``aggregate``, or SUM if
4866
4884
  none is provided.
4867
4885
 
4886
+ ``score_cast_func`` a callable used to cast the score return value
4887
+
4868
4888
  For more information, see https://redis.io/commands/zunion
4869
4889
  """
4870
- return self._zaggregate("ZUNION", None, keys, aggregate, withscores=withscores)
4890
+ return self._zaggregate(
4891
+ "ZUNION",
4892
+ None,
4893
+ keys,
4894
+ aggregate,
4895
+ withscores=withscores,
4896
+ score_cast_func=score_cast_func,
4897
+ )
4871
4898
 
4872
4899
  def zunionstore(
4873
4900
  self,
@@ -5856,12 +5883,16 @@ class ScriptCommands(CommandsProtocol):
5856
5883
  """
5857
5884
 
5858
5885
  def _eval(
5859
- self, command: str, script: str, numkeys: int, *keys_and_args: str
5886
+ self,
5887
+ command: str,
5888
+ script: str,
5889
+ numkeys: int,
5890
+ *keys_and_args: Union[KeyT, EncodableT],
5860
5891
  ) -> Union[Awaitable[str], str]:
5861
5892
  return self.execute_command(command, script, numkeys, *keys_and_args)
5862
5893
 
5863
5894
  def eval(
5864
- self, script: str, numkeys: int, *keys_and_args: str
5895
+ self, script: str, numkeys: int, *keys_and_args: Union[KeyT, EncodableT]
5865
5896
  ) -> Union[Awaitable[str], str]:
5866
5897
  """
5867
5898
  Execute the Lua ``script``, specifying the ``numkeys`` the script
@@ -5876,7 +5907,7 @@ class ScriptCommands(CommandsProtocol):
5876
5907
  return self._eval("EVAL", script, numkeys, *keys_and_args)
5877
5908
 
5878
5909
  def eval_ro(
5879
- self, script: str, numkeys: int, *keys_and_args: str
5910
+ self, script: str, numkeys: int, *keys_and_args: Union[KeyT, EncodableT]
5880
5911
  ) -> Union[Awaitable[str], str]:
5881
5912
  """
5882
5913
  The read-only variant of the EVAL command
@@ -5890,12 +5921,16 @@ class ScriptCommands(CommandsProtocol):
5890
5921
  return self._eval("EVAL_RO", script, numkeys, *keys_and_args)
5891
5922
 
5892
5923
  def _evalsha(
5893
- self, command: str, sha: str, numkeys: int, *keys_and_args: list
5924
+ self,
5925
+ command: str,
5926
+ sha: str,
5927
+ numkeys: int,
5928
+ *keys_and_args: Union[KeyT, EncodableT],
5894
5929
  ) -> Union[Awaitable[str], str]:
5895
5930
  return self.execute_command(command, sha, numkeys, *keys_and_args)
5896
5931
 
5897
5932
  def evalsha(
5898
- self, sha: str, numkeys: int, *keys_and_args: str
5933
+ self, sha: str, numkeys: int, *keys_and_args: Union[KeyT, EncodableT]
5899
5934
  ) -> Union[Awaitable[str], str]:
5900
5935
  """
5901
5936
  Use the ``sha`` to execute a Lua script already registered via EVAL
@@ -5911,7 +5946,7 @@ class ScriptCommands(CommandsProtocol):
5911
5946
  return self._evalsha("EVALSHA", sha, numkeys, *keys_and_args)
5912
5947
 
5913
5948
  def evalsha_ro(
5914
- self, sha: str, numkeys: int, *keys_and_args: str
5949
+ self, sha: str, numkeys: int, *keys_and_args: Union[KeyT, EncodableT]
5915
5950
  ) -> Union[Awaitable[str], str]:
5916
5951
  """
5917
5952
  The read-only variant of the EVALSHA command
redis/commands/json/commands.py CHANGED
@@ -14,7 +14,7 @@ class JSONCommands:
14
14
  """json commands."""
15
15
 
16
16
  def arrappend(
17
- self, name: str, path: Optional[str] = Path.root_path(), *args: List[JsonType]
17
+ self, name: str, path: Optional[str] = Path.root_path(), *args: JsonType
18
18
  ) -> List[Optional[int]]:
19
19
  """Append the objects ``args`` to the array under the
20
20
  ``path` in key ``name``.
@@ -52,7 +52,7 @@ class JSONCommands:
52
52
  return self.execute_command("JSON.ARRINDEX", *pieces, keys=[name])
53
53
 
54
54
  def arrinsert(
55
- self, name: str, path: str, index: int, *args: List[JsonType]
55
+ self, name: str, path: str, index: int, *args: JsonType
56
56
  ) -> List[Optional[int]]:
57
57
  """Insert the objects ``args`` to the array at index ``index``
58
58
  under the ``path` in key ``name``.
redis/commands/search/__init__.py CHANGED
@@ -1,4 +1,4 @@
1
- import redis
1
+ from redis.client import Pipeline as RedisPipeline
2
2
 
3
3
  from ...asyncio.client import Pipeline as AsyncioPipeline
4
4
  from .commands import (
@@ -181,7 +181,7 @@ class AsyncSearch(Search, AsyncSearchCommands):
181
181
  return p
182
182
 
183
183
 
184
- class Pipeline(SearchCommands, redis.client.Pipeline):
184
+ class Pipeline(SearchCommands, RedisPipeline):
185
185
  """Pipeline for the module."""
186
186
 
187
187
 
redis/commands/search/aggregation.py CHANGED
@@ -1,4 +1,4 @@
1
- from typing import List, Union
1
+ from typing import List, Optional, Tuple, Union
2
2
 
3
3
  from redis.commands.search.dialect import DEFAULT_DIALECT
4
4
 
@@ -27,9 +27,9 @@ class Reducer:
27
27
  NAME = None
28
28
 
29
29
  def __init__(self, *args: str) -> None:
30
- self._args = args
31
- self._field = None
32
- self._alias = None
30
+ self._args: Tuple[str, ...] = args
31
+ self._field: Optional[str] = None
32
+ self._alias: Optional[str] = None
33
33
 
34
34
  def alias(self, alias: str) -> "Reducer":
35
35
  """
@@ -49,13 +49,14 @@ class Reducer:
49
49
  if alias is FIELDNAME:
50
50
  if not self._field:
51
51
  raise ValueError("Cannot use FIELDNAME alias with no field")
52
- # Chop off initial '@'
53
- alias = self._field[1:]
52
+ else:
53
+ # Chop off initial '@'
54
+ alias = self._field[1:]
54
55
  self._alias = alias
55
56
  return self
56
57
 
57
58
  @property
58
- def args(self) -> List[str]:
59
+ def args(self) -> Tuple[str, ...]:
59
60
  return self._args
60
61
 
61
62
 
@@ -64,7 +65,7 @@ class SortDirection:
64
65
  This special class is used to indicate sort direction.
65
66
  """
66
67
 
67
- DIRSTRING = None
68
+ DIRSTRING: Optional[str] = None
68
69
 
69
70
  def __init__(self, field: str) -> None:
70
71
  self.field = field
@@ -104,17 +105,17 @@ class AggregateRequest:
104
105
  All member methods (except `build_args()`)
105
106
  return the object itself, making them useful for chaining.
106
107
  """
107
- self._query = query
108
- self._aggregateplan = []
109
- self._loadfields = []
110
- self._loadall = False
111
- self._max = 0
112
- self._with_schema = False
113
- self._verbatim = False
114
- self._cursor = []
115
- self._dialect = DEFAULT_DIALECT
116
- self._add_scores = False
117
- self._scorer = "TFIDF"
108
+ self._query: str = query
109
+ self._aggregateplan: List[str] = []
110
+ self._loadfields: List[str] = []
111
+ self._loadall: bool = False
112
+ self._max: int = 0
113
+ self._with_schema: bool = False
114
+ self._verbatim: bool = False
115
+ self._cursor: List[str] = []
116
+ self._dialect: int = DEFAULT_DIALECT
117
+ self._add_scores: bool = False
118
+ self._scorer: str = "TFIDF"
118
119
 
119
120
  def load(self, *fields: str) -> "AggregateRequest":
120
121
  """
@@ -133,7 +134,7 @@ class AggregateRequest:
133
134
  return self
134
135
 
135
136
  def group_by(
136
- self, fields: List[str], *reducers: Union[Reducer, List[Reducer]]
137
+ self, fields: Union[str, List[str]], *reducers: Reducer
137
138
  ) -> "AggregateRequest":
138
139
  """
139
140
  Specify by which fields to group the aggregation.
@@ -147,7 +148,6 @@ class AggregateRequest:
147
148
  `aggregation` module.
148
149
  """
149
150
  fields = [fields] if isinstance(fields, str) else fields
150
- reducers = [reducers] if isinstance(reducers, Reducer) else reducers
151
151
 
152
152
  ret = ["GROUPBY", str(len(fields)), *fields]
153
153
  for reducer in reducers:
@@ -251,12 +251,10 @@ class AggregateRequest:
251
251
  .sort_by(Desc("@paid"), max=10)
252
252
  ```
253
253
  """
254
- if isinstance(fields, (str, SortDirection)):
255
- fields = [fields]
256
254
 
257
255
  fields_args = []
258
256
  for f in fields:
259
- if isinstance(f, SortDirection):
257
+ if isinstance(f, (Asc, Desc)):
260
258
  fields_args += [f.field, f.DIRSTRING]
261
259
  else:
262
260
  fields_args += [f]
@@ -356,7 +354,7 @@ class AggregateRequest:
356
354
  ret.extend(self._loadfields)
357
355
 
358
356
  if self._dialect:
359
- ret.extend(["DIALECT", self._dialect])
357
+ ret.extend(["DIALECT", str(self._dialect)])
360
358
 
361
359
  ret.extend(self._aggregateplan)
362
360
 
@@ -393,7 +391,7 @@ class AggregateResult:
393
391
  self.cursor = cursor
394
392
  self.schema = schema
395
393
 
396
- def __repr__(self) -> (str, str):
394
+ def __repr__(self) -> str:
397
395
  cid = self.cursor.cid if self.cursor else -1
398
396
  return (
399
397
  f"<{self.__class__.__name__} at 0x{id(self):x} "
redis/commands/search/commands.py CHANGED
@@ -221,7 +221,7 @@ class SearchCommands:
221
221
 
222
222
  return self.execute_command(*args)
223
223
 
224
- def alter_schema_add(self, fields: List[str]):
224
+ def alter_schema_add(self, fields: Union[Field, List[Field]]):
225
225
  """
226
226
  Alter the existing search index by adding new fields. The index
227
227
  must already exist.
@@ -336,11 +336,11 @@ class SearchCommands:
336
336
  doc_id: str,
337
337
  nosave: bool = False,
338
338
  score: float = 1.0,
339
- payload: bool = None,
339
+ payload: Optional[bool] = None,
340
340
  replace: bool = False,
341
341
  partial: bool = False,
342
342
  language: Optional[str] = None,
343
- no_create: str = False,
343
+ no_create: bool = False,
344
344
  **fields: List[str],
345
345
  ):
346
346
  """
@@ -464,7 +464,7 @@ class SearchCommands:
464
464
  return self._parse_results(INFO_CMD, res)
465
465
 
466
466
  def get_params_args(
467
- self, query_params: Union[Dict[str, Union[str, int, float, bytes]], None]
467
+ self, query_params: Optional[Dict[str, Union[str, int, float, bytes]]]
468
468
  ):
469
469
  if query_params is None:
470
470
  return []
@@ -478,7 +478,7 @@ class SearchCommands:
478
478
  return args
479
479
 
480
480
  def _mk_query_args(
481
- self, query, query_params: Union[Dict[str, Union[str, int, float, bytes]], None]
481
+ self, query, query_params: Optional[Dict[str, Union[str, int, float, bytes]]]
482
482
  ):
483
483
  args = [self.index_name]
484
484
 
@@ -528,7 +528,7 @@ class SearchCommands:
528
528
  def explain(
529
529
  self,
530
530
  query: Union[str, Query],
531
- query_params: Dict[str, Union[str, int, float]] = None,
531
+ query_params: Optional[Dict[str, Union[str, int, float, bytes]]] = None,
532
532
  ):
533
533
  """Returns the execution plan for a complex query.
534
534
 
@@ -543,7 +543,7 @@ class SearchCommands:
543
543
  def aggregate(
544
544
  self,
545
545
  query: Union[AggregateRequest, Cursor],
546
- query_params: Dict[str, Union[str, int, float]] = None,
546
+ query_params: Optional[Dict[str, Union[str, int, float, bytes]]] = None,
547
547
  ):
548
548
  """
549
549
  Issue an aggregation query.
@@ -598,7 +598,7 @@ class SearchCommands:
598
598
  self,
599
599
  query: Union[Query, AggregateRequest],
600
600
  limited: bool = False,
601
- query_params: Optional[Dict[str, Union[str, int, float]]] = None,
601
+ query_params: Optional[Dict[str, Union[str, int, float, bytes]]] = None,
602
602
  ):
603
603
  """
604
604
  Performs a search or aggregate command and collects performance
@@ -936,7 +936,7 @@ class AsyncSearchCommands(SearchCommands):
936
936
  async def search(
937
937
  self,
938
938
  query: Union[str, Query],
939
- query_params: Dict[str, Union[str, int, float]] = None,
939
+ query_params: Optional[Dict[str, Union[str, int, float, bytes]]] = None,
940
940
  ):
941
941
  """
942
942
  Search the index for a given query, and return a result of documents
@@ -968,7 +968,7 @@ class AsyncSearchCommands(SearchCommands):
968
968
  async def aggregate(
969
969
  self,
970
970
  query: Union[AggregateResult, Cursor],
971
- query_params: Dict[str, Union[str, int, float]] = None,
971
+ query_params: Optional[Dict[str, Union[str, int, float, bytes]]] = None,
972
972
  ):
973
973
  """
974
974
  Issue an aggregation query.