redis 6.1.1__py3-none-any.whl → 6.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
redis/asyncio/cluster.py CHANGED
@@ -2,16 +2,23 @@ import asyncio
2
2
  import collections
3
3
  import random
4
4
  import socket
5
+ import threading
6
+ import time
5
7
  import warnings
8
+ from abc import ABC, abstractmethod
9
+ from copy import copy
10
+ from itertools import chain
6
11
  from typing import (
7
12
  Any,
8
13
  Callable,
14
+ Coroutine,
9
15
  Deque,
10
16
  Dict,
11
17
  Generator,
12
18
  List,
13
19
  Mapping,
14
20
  Optional,
21
+ Set,
15
22
  Tuple,
16
23
  Type,
17
24
  TypeVar,
@@ -53,7 +60,10 @@ from redis.exceptions import (
53
60
  ClusterDownError,
54
61
  ClusterError,
55
62
  ConnectionError,
63
+ CrossSlotTransactionError,
56
64
  DataError,
65
+ ExecAbortError,
66
+ InvalidPipelineStack,
57
67
  MaxConnectionsError,
58
68
  MovedError,
59
69
  RedisClusterException,
@@ -62,6 +72,7 @@ from redis.exceptions import (
62
72
  SlotNotCoveredError,
63
73
  TimeoutError,
64
74
  TryAgainError,
75
+ WatchError,
65
76
  )
66
77
  from redis.typing import AnyKeyT, EncodableT, KeyT
67
78
  from redis.utils import (
@@ -134,6 +145,14 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand
134
145
  | Enable read from replicas in READONLY mode and defines the load balancing
135
146
  strategy that will be used for cluster node selection.
136
147
  The data read from replicas is eventually consistent with the data in primary nodes.
148
+ :param dynamic_startup_nodes:
149
+ | Set the RedisCluster's startup nodes to all the discovered nodes.
150
+ If true (default value), the cluster's discovered nodes will be used to
151
+ determine the cluster nodes-slots mapping in the next topology refresh.
152
+ It will remove the initial passed startup nodes if their endpoints aren't
153
+ listed in the CLUSTER SLOTS output.
154
+ If you use dynamic DNS endpoints for startup nodes but CLUSTER SLOTS lists
155
+ specific IP addresses, it is best to set it to false.
137
156
  :param reinitialize_steps:
138
157
  | Specifies the number of MOVED errors that need to occur before reinitializing
139
158
  the whole cluster topology. If a MOVED error occurs and the cluster does not
@@ -250,6 +269,7 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand
250
269
  require_full_coverage: bool = True,
251
270
  read_from_replicas: bool = False,
252
271
  load_balancing_strategy: Optional[LoadBalancingStrategy] = None,
272
+ dynamic_startup_nodes: bool = True,
253
273
  reinitialize_steps: int = 5,
254
274
  cluster_error_retry_attempts: int = 3,
255
275
  max_connections: int = 2**31,
@@ -388,6 +408,7 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand
388
408
  startup_nodes,
389
409
  require_full_coverage,
390
410
  kwargs,
411
+ dynamic_startup_nodes=dynamic_startup_nodes,
391
412
  address_remap=address_remap,
392
413
  event_dispatcher=self._event_dispatcher,
393
414
  )
@@ -793,7 +814,13 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand
793
814
  moved = False
794
815
 
795
816
  return await target_node.execute_command(*args, **kwargs)
796
- except (BusyLoadingError, MaxConnectionsError):
817
+ except BusyLoadingError:
818
+ raise
819
+ except MaxConnectionsError:
820
+ # MaxConnectionsError indicates client-side resource exhaustion
821
+ # (too many connections in the pool), not a node failure.
822
+ # Don't treat this as a node failure - just re-raise the error
823
+ # without reinitializing the cluster.
797
824
  raise
798
825
  except (ConnectionError, TimeoutError):
799
826
  # Connection retries are being handled in the node's
@@ -860,10 +887,7 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand
860
887
  if shard_hint:
861
888
  raise RedisClusterException("shard_hint is deprecated in cluster mode")
862
889
 
863
- if transaction:
864
- raise RedisClusterException("transaction is deprecated in cluster mode")
865
-
866
- return ClusterPipeline(self)
890
+ return ClusterPipeline(self, transaction)
867
891
 
868
892
  def lock(
869
893
  self,
@@ -946,6 +970,30 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand
946
970
  raise_on_release_error=raise_on_release_error,
947
971
  )
948
972
 
973
+ async def transaction(
974
+ self, func: Coroutine[None, "ClusterPipeline", Any], *watches, **kwargs
975
+ ):
976
+ """
977
+ Convenience method for executing the callable `func` as a transaction
978
+ while watching all keys specified in `watches`. The 'func' callable
979
+ should expect a single argument which is a Pipeline object.
980
+ """
981
+ shard_hint = kwargs.pop("shard_hint", None)
982
+ value_from_callable = kwargs.pop("value_from_callable", False)
983
+ watch_delay = kwargs.pop("watch_delay", None)
984
+ async with self.pipeline(True, shard_hint) as pipe:
985
+ while True:
986
+ try:
987
+ if watches:
988
+ await pipe.watch(*watches)
989
+ func_value = await func(pipe)
990
+ exec_value = await pipe.execute()
991
+ return func_value if value_from_callable else exec_value
992
+ except WatchError:
993
+ if watch_delay is not None and watch_delay > 0:
994
+ time.sleep(watch_delay)
995
+ continue
996
+
949
997
 
950
998
  class ClusterNode:
951
999
  """
@@ -1067,6 +1115,12 @@ class ClusterNode:
1067
1115
 
1068
1116
  raise MaxConnectionsError()
1069
1117
 
1118
+ def release(self, connection: Connection) -> None:
1119
+ """
1120
+ Release connection back to free queue.
1121
+ """
1122
+ self._free.append(connection)
1123
+
1070
1124
  async def parse_response(
1071
1125
  self, connection: Connection, command: str, **kwargs: Any
1072
1126
  ) -> Any:
@@ -1162,6 +1216,7 @@ class ClusterNode:
1162
1216
 
1163
1217
  class NodesManager:
1164
1218
  __slots__ = (
1219
+ "_dynamic_startup_nodes",
1165
1220
  "_moved_exception",
1166
1221
  "_event_dispatcher",
1167
1222
  "connection_kwargs",
@@ -1179,6 +1234,7 @@ class NodesManager:
1179
1234
  startup_nodes: List["ClusterNode"],
1180
1235
  require_full_coverage: bool,
1181
1236
  connection_kwargs: Dict[str, Any],
1237
+ dynamic_startup_nodes: bool = True,
1182
1238
  address_remap: Optional[Callable[[Tuple[str, int]], Tuple[str, int]]] = None,
1183
1239
  event_dispatcher: Optional[EventDispatcher] = None,
1184
1240
  ) -> None:
@@ -1191,6 +1247,8 @@ class NodesManager:
1191
1247
  self.nodes_cache: Dict[str, "ClusterNode"] = {}
1192
1248
  self.slots_cache: Dict[int, List["ClusterNode"]] = {}
1193
1249
  self.read_load_balancer = LoadBalancer()
1250
+
1251
+ self._dynamic_startup_nodes: bool = dynamic_startup_nodes
1194
1252
  self._moved_exception: MovedError = None
1195
1253
  if event_dispatcher is None:
1196
1254
  self._event_dispatcher = EventDispatcher()
@@ -1233,6 +1291,9 @@ class NodesManager:
1233
1291
  task = asyncio.create_task(old[name].disconnect()) # noqa
1234
1292
  old[name] = node
1235
1293
 
1294
+ def update_moved_exception(self, exception):
1295
+ self._moved_exception = exception
1296
+
1236
1297
  def _update_moved_slots(self) -> None:
1237
1298
  e = self._moved_exception
1238
1299
  redirected_node = self.get_node(host=e.host, port=e.port)
@@ -1433,8 +1494,10 @@ class NodesManager:
1433
1494
  # Set the tmp variables to the real variables
1434
1495
  self.slots_cache = tmp_slots
1435
1496
  self.set_nodes(self.nodes_cache, tmp_nodes_cache, remove_old=True)
1436
- # Populate the startup nodes with all discovered nodes
1437
- self.set_nodes(self.startup_nodes, self.nodes_cache, remove_old=True)
1497
+
1498
+ if self._dynamic_startup_nodes:
1499
+ # Populate the startup nodes with all discovered nodes
1500
+ self.set_nodes(self.startup_nodes, self.nodes_cache, remove_old=True)
1438
1501
 
1439
1502
  # Set the default node
1440
1503
  self.default_node = self.get_nodes_by_server_type(PRIMARY)[0]
@@ -1498,41 +1561,47 @@ class ClusterPipeline(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterComm
1498
1561
  | Existing :class:`~.RedisCluster` client
1499
1562
  """
1500
1563
 
1501
- __slots__ = ("_command_stack", "_client")
1564
+ __slots__ = ("cluster_client", "_transaction", "_execution_strategy")
1502
1565
 
1503
- def __init__(self, client: RedisCluster) -> None:
1504
- self._client = client
1505
-
1506
- self._command_stack: List["PipelineCommand"] = []
1566
+ def __init__(
1567
+ self, client: RedisCluster, transaction: Optional[bool] = None
1568
+ ) -> None:
1569
+ self.cluster_client = client
1570
+ self._transaction = transaction
1571
+ self._execution_strategy: ExecutionStrategy = (
1572
+ PipelineStrategy(self)
1573
+ if not self._transaction
1574
+ else TransactionStrategy(self)
1575
+ )
1507
1576
 
1508
1577
  async def initialize(self) -> "ClusterPipeline":
1509
- if self._client._initialize:
1510
- await self._client.initialize()
1511
- self._command_stack = []
1578
+ await self._execution_strategy.initialize()
1512
1579
  return self
1513
1580
 
1514
1581
  async def __aenter__(self) -> "ClusterPipeline":
1515
1582
  return await self.initialize()
1516
1583
 
1517
1584
  async def __aexit__(self, exc_type: None, exc_value: None, traceback: None) -> None:
1518
- self._command_stack = []
1585
+ await self.reset()
1519
1586
 
1520
1587
  def __await__(self) -> Generator[Any, None, "ClusterPipeline"]:
1521
1588
  return self.initialize().__await__()
1522
1589
 
1523
1590
  def __enter__(self) -> "ClusterPipeline":
1524
- self._command_stack = []
1591
+ # TODO: Remove this method before 7.0.0
1592
+ self._execution_strategy._command_queue = []
1525
1593
  return self
1526
1594
 
1527
1595
  def __exit__(self, exc_type: None, exc_value: None, traceback: None) -> None:
1528
- self._command_stack = []
1596
+ # TODO: Remove this method before 7.0.0
1597
+ self._execution_strategy._command_queue = []
1529
1598
 
1530
1599
  def __bool__(self) -> bool:
1531
1600
  "Pipeline instances should always evaluate to True on Python 3+"
1532
1601
  return True
1533
1602
 
1534
1603
  def __len__(self) -> int:
1535
- return len(self._command_stack)
1604
+ return len(self._execution_strategy)
1536
1605
 
1537
1606
  def execute_command(
1538
1607
  self, *args: Union[KeyT, EncodableT], **kwargs: Any
@@ -1548,10 +1617,7 @@ class ClusterPipeline(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterComm
1548
1617
  or List[:class:`~.ClusterNode`] or Dict[Any, :class:`~.ClusterNode`]
1549
1618
  - Rest of the kwargs are passed to the Redis connection
1550
1619
  """
1551
- self._command_stack.append(
1552
- PipelineCommand(len(self._command_stack), *args, **kwargs)
1553
- )
1554
- return self
1620
+ return self._execution_strategy.execute_command(*args, **kwargs)
1555
1621
 
1556
1622
  async def execute(
1557
1623
  self, raise_on_error: bool = True, allow_redirections: bool = True
@@ -1571,34 +1637,294 @@ class ClusterPipeline(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterComm
1571
1637
  :raises RedisClusterException: if target_nodes is not provided & the command
1572
1638
  can't be mapped to a slot
1573
1639
  """
1574
- if not self._command_stack:
1640
+ try:
1641
+ return await self._execution_strategy.execute(
1642
+ raise_on_error, allow_redirections
1643
+ )
1644
+ finally:
1645
+ await self.reset()
1646
+
1647
+ def _split_command_across_slots(
1648
+ self, command: str, *keys: KeyT
1649
+ ) -> "ClusterPipeline":
1650
+ for slot_keys in self.cluster_client._partition_keys_by_slot(keys).values():
1651
+ self.execute_command(command, *slot_keys)
1652
+
1653
+ return self
1654
+
1655
+ async def reset(self):
1656
+ """
1657
+ Reset back to empty pipeline.
1658
+ """
1659
+ await self._execution_strategy.reset()
1660
+
1661
+ def multi(self):
1662
+ """
1663
+ Start a transactional block of the pipeline after WATCH commands
1664
+ are issued. End the transactional block with `execute`.
1665
+ """
1666
+ self._execution_strategy.multi()
1667
+
1668
+ async def discard(self):
1669
+ """ """
1670
+ await self._execution_strategy.discard()
1671
+
1672
+ async def watch(self, *names):
1673
+ """Watches the values at keys ``names``"""
1674
+ await self._execution_strategy.watch(*names)
1675
+
1676
+ async def unwatch(self):
1677
+ """Unwatches all previously specified keys"""
1678
+ await self._execution_strategy.unwatch()
1679
+
1680
+ async def unlink(self, *names):
1681
+ await self._execution_strategy.unlink(*names)
1682
+
1683
+ def mset_nonatomic(
1684
+ self, mapping: Mapping[AnyKeyT, EncodableT]
1685
+ ) -> "ClusterPipeline":
1686
+ return self._execution_strategy.mset_nonatomic(mapping)
1687
+
1688
+
1689
+ for command in PIPELINE_BLOCKED_COMMANDS:
1690
+ command = command.replace(" ", "_").lower()
1691
+ if command == "mset_nonatomic":
1692
+ continue
1693
+
1694
+ setattr(ClusterPipeline, command, block_pipeline_command(command))
1695
+
1696
+
1697
+ class PipelineCommand:
1698
+ def __init__(self, position: int, *args: Any, **kwargs: Any) -> None:
1699
+ self.args = args
1700
+ self.kwargs = kwargs
1701
+ self.position = position
1702
+ self.result: Union[Any, Exception] = None
1703
+
1704
+ def __repr__(self) -> str:
1705
+ return f"[{self.position}] {self.args} ({self.kwargs})"
1706
+
1707
+
1708
+ class ExecutionStrategy(ABC):
1709
+ @abstractmethod
1710
+ async def initialize(self) -> "ClusterPipeline":
1711
+ """
1712
+ Initialize the execution strategy.
1713
+
1714
+ See ClusterPipeline.initialize()
1715
+ """
1716
+ pass
1717
+
1718
+ @abstractmethod
1719
+ def execute_command(
1720
+ self, *args: Union[KeyT, EncodableT], **kwargs: Any
1721
+ ) -> "ClusterPipeline":
1722
+ """
1723
+ Append a raw command to the pipeline.
1724
+
1725
+ See ClusterPipeline.execute_command()
1726
+ """
1727
+ pass
1728
+
1729
+ @abstractmethod
1730
+ async def execute(
1731
+ self, raise_on_error: bool = True, allow_redirections: bool = True
1732
+ ) -> List[Any]:
1733
+ """
1734
+ Execute the pipeline.
1735
+
1736
+ It will retry the commands as specified by retries specified in :attr:`retry`
1737
+ & then raise an exception.
1738
+
1739
+ See ClusterPipeline.execute()
1740
+ """
1741
+ pass
1742
+
1743
+ @abstractmethod
1744
+ def mset_nonatomic(
1745
+ self, mapping: Mapping[AnyKeyT, EncodableT]
1746
+ ) -> "ClusterPipeline":
1747
+ """
1748
+ Executes multiple MSET commands according to the provided slot/pairs mapping.
1749
+
1750
+ See ClusterPipeline.mset_nonatomic()
1751
+ """
1752
+ pass
1753
+
1754
+ @abstractmethod
1755
+ async def reset(self):
1756
+ """
1757
+ Resets current execution strategy.
1758
+
1759
+ See: ClusterPipeline.reset()
1760
+ """
1761
+ pass
1762
+
1763
+ @abstractmethod
1764
+ def multi(self):
1765
+ """
1766
+ Starts transactional context.
1767
+
1768
+ See: ClusterPipeline.multi()
1769
+ """
1770
+ pass
1771
+
1772
+ @abstractmethod
1773
+ async def watch(self, *names):
1774
+ """
1775
+ Watch given keys.
1776
+
1777
+ See: ClusterPipeline.watch()
1778
+ """
1779
+ pass
1780
+
1781
+ @abstractmethod
1782
+ async def unwatch(self):
1783
+ """
1784
+ Unwatches all previously specified keys
1785
+
1786
+ See: ClusterPipeline.unwatch()
1787
+ """
1788
+ pass
1789
+
1790
+ @abstractmethod
1791
+ async def discard(self):
1792
+ pass
1793
+
1794
+ @abstractmethod
1795
+ async def unlink(self, *names):
1796
+ """
1797
+ "Unlink a key specified by ``names``"
1798
+
1799
+ See: ClusterPipeline.unlink()
1800
+ """
1801
+ pass
1802
+
1803
+ @abstractmethod
1804
+ def __len__(self) -> int:
1805
+ pass
1806
+
1807
+
1808
+ class AbstractStrategy(ExecutionStrategy):
1809
+ def __init__(self, pipe: ClusterPipeline) -> None:
1810
+ self._pipe: ClusterPipeline = pipe
1811
+ self._command_queue: List["PipelineCommand"] = []
1812
+
1813
+ async def initialize(self) -> "ClusterPipeline":
1814
+ if self._pipe.cluster_client._initialize:
1815
+ await self._pipe.cluster_client.initialize()
1816
+ self._command_queue = []
1817
+ return self._pipe
1818
+
1819
+ def execute_command(
1820
+ self, *args: Union[KeyT, EncodableT], **kwargs: Any
1821
+ ) -> "ClusterPipeline":
1822
+ self._command_queue.append(
1823
+ PipelineCommand(len(self._command_queue), *args, **kwargs)
1824
+ )
1825
+ return self._pipe
1826
+
1827
+ def _annotate_exception(self, exception, number, command):
1828
+ """
1829
+ Provides extra context to the exception prior to it being handled
1830
+ """
1831
+ cmd = " ".join(map(safe_str, command))
1832
+ msg = (
1833
+ f"Command # {number} ({truncate_text(cmd)}) of pipeline "
1834
+ f"caused error: {exception.args[0]}"
1835
+ )
1836
+ exception.args = (msg,) + exception.args[1:]
1837
+
1838
+ @abstractmethod
1839
+ def mset_nonatomic(
1840
+ self, mapping: Mapping[AnyKeyT, EncodableT]
1841
+ ) -> "ClusterPipeline":
1842
+ pass
1843
+
1844
+ @abstractmethod
1845
+ async def execute(
1846
+ self, raise_on_error: bool = True, allow_redirections: bool = True
1847
+ ) -> List[Any]:
1848
+ pass
1849
+
1850
+ @abstractmethod
1851
+ async def reset(self):
1852
+ pass
1853
+
1854
+ @abstractmethod
1855
+ def multi(self):
1856
+ pass
1857
+
1858
+ @abstractmethod
1859
+ async def watch(self, *names):
1860
+ pass
1861
+
1862
+ @abstractmethod
1863
+ async def unwatch(self):
1864
+ pass
1865
+
1866
+ @abstractmethod
1867
+ async def discard(self):
1868
+ pass
1869
+
1870
+ @abstractmethod
1871
+ async def unlink(self, *names):
1872
+ pass
1873
+
1874
+ def __len__(self) -> int:
1875
+ return len(self._command_queue)
1876
+
1877
+
1878
+ class PipelineStrategy(AbstractStrategy):
1879
+ def __init__(self, pipe: ClusterPipeline) -> None:
1880
+ super().__init__(pipe)
1881
+
1882
+ def mset_nonatomic(
1883
+ self, mapping: Mapping[AnyKeyT, EncodableT]
1884
+ ) -> "ClusterPipeline":
1885
+ encoder = self._pipe.cluster_client.encoder
1886
+
1887
+ slots_pairs = {}
1888
+ for pair in mapping.items():
1889
+ slot = key_slot(encoder.encode(pair[0]))
1890
+ slots_pairs.setdefault(slot, []).extend(pair)
1891
+
1892
+ for pairs in slots_pairs.values():
1893
+ self.execute_command("MSET", *pairs)
1894
+
1895
+ return self._pipe
1896
+
1897
+ async def execute(
1898
+ self, raise_on_error: bool = True, allow_redirections: bool = True
1899
+ ) -> List[Any]:
1900
+ if not self._command_queue:
1575
1901
  return []
1576
1902
 
1577
1903
  try:
1578
- retry_attempts = self._client.retry.get_retries()
1904
+ retry_attempts = self._pipe.cluster_client.retry.get_retries()
1579
1905
  while True:
1580
1906
  try:
1581
- if self._client._initialize:
1582
- await self._client.initialize()
1907
+ if self._pipe.cluster_client._initialize:
1908
+ await self._pipe.cluster_client.initialize()
1583
1909
  return await self._execute(
1584
- self._client,
1585
- self._command_stack,
1910
+ self._pipe.cluster_client,
1911
+ self._command_queue,
1586
1912
  raise_on_error=raise_on_error,
1587
1913
  allow_redirections=allow_redirections,
1588
1914
  )
1589
1915
 
1590
- except self.__class__.ERRORS_ALLOW_RETRY as e:
1916
+ except RedisCluster.ERRORS_ALLOW_RETRY as e:
1591
1917
  if retry_attempts > 0:
1592
1918
  # Try again with the new cluster setup. All other errors
1593
1919
  # should be raised.
1594
1920
  retry_attempts -= 1
1595
- await self._client.aclose()
1921
+ await self._pipe.cluster_client.aclose()
1596
1922
  await asyncio.sleep(0.25)
1597
1923
  else:
1598
1924
  # All other errors should be raised.
1599
1925
  raise e
1600
1926
  finally:
1601
- self._command_stack = []
1927
+ await self.reset()
1602
1928
 
1603
1929
  async def _execute(
1604
1930
  self,
@@ -1678,50 +2004,402 @@ class ClusterPipeline(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterComm
1678
2004
  for cmd in default_node[1]:
1679
2005
  # Check if it has a command that failed with a relevant
1680
2006
  # exception
1681
- if type(cmd.result) in self.__class__.ERRORS_ALLOW_RETRY:
2007
+ if type(cmd.result) in RedisCluster.ERRORS_ALLOW_RETRY:
1682
2008
  client.replace_default_node()
1683
2009
  break
1684
2010
 
1685
2011
  return [cmd.result for cmd in stack]
1686
2012
 
1687
- def _split_command_across_slots(
1688
- self, command: str, *keys: KeyT
1689
- ) -> "ClusterPipeline":
1690
- for slot_keys in self._client._partition_keys_by_slot(keys).values():
1691
- self.execute_command(command, *slot_keys)
2013
+ async def reset(self):
2014
+ """
2015
+ Reset back to empty pipeline.
2016
+ """
2017
+ self._command_queue = []
1692
2018
 
1693
- return self
2019
+ def multi(self):
2020
+ raise RedisClusterException(
2021
+ "method multi() is not supported outside of transactional context"
2022
+ )
2023
+
2024
+ async def watch(self, *names):
2025
+ raise RedisClusterException(
2026
+ "method watch() is not supported outside of transactional context"
2027
+ )
2028
+
2029
+ async def unwatch(self):
2030
+ raise RedisClusterException(
2031
+ "method unwatch() is not supported outside of transactional context"
2032
+ )
2033
+
2034
+ async def discard(self):
2035
+ raise RedisClusterException(
2036
+ "method discard() is not supported outside of transactional context"
2037
+ )
2038
+
2039
+ async def unlink(self, *names):
2040
+ if len(names) != 1:
2041
+ raise RedisClusterException(
2042
+ "unlinking multiple keys is not implemented in pipeline command"
2043
+ )
2044
+
2045
+ return self.execute_command("UNLINK", names[0])
2046
+
2047
+
2048
+ class TransactionStrategy(AbstractStrategy):
2049
+ NO_SLOTS_COMMANDS = {"UNWATCH"}
2050
+ IMMEDIATE_EXECUTE_COMMANDS = {"WATCH", "UNWATCH"}
2051
+ UNWATCH_COMMANDS = {"DISCARD", "EXEC", "UNWATCH"}
2052
+ SLOT_REDIRECT_ERRORS = (AskError, MovedError)
2053
+ CONNECTION_ERRORS = (
2054
+ ConnectionError,
2055
+ OSError,
2056
+ ClusterDownError,
2057
+ SlotNotCoveredError,
2058
+ )
2059
+
2060
+ def __init__(self, pipe: ClusterPipeline) -> None:
2061
+ super().__init__(pipe)
2062
+ self._explicit_transaction = False
2063
+ self._watching = False
2064
+ self._pipeline_slots: Set[int] = set()
2065
+ self._transaction_node: Optional[ClusterNode] = None
2066
+ self._transaction_connection: Optional[Connection] = None
2067
+ self._executing = False
2068
+ self._retry = copy(self._pipe.cluster_client.retry)
2069
+ self._retry.update_supported_errors(
2070
+ RedisCluster.ERRORS_ALLOW_RETRY + self.SLOT_REDIRECT_ERRORS
2071
+ )
2072
+
2073
+ def _get_client_and_connection_for_transaction(
2074
+ self,
2075
+ ) -> Tuple[ClusterNode, Connection]:
2076
+ """
2077
+ Find a connection for a pipeline transaction.
2078
+
2079
+ For running an atomic transaction, watch keys ensure that contents have not been
2080
+ altered as long as the watch commands for those keys were sent over the same
2081
+ connection. So once we start watching a key, we fetch a connection to the
2082
+ node that owns that slot and reuse it.
2083
+ """
2084
+ if not self._pipeline_slots:
2085
+ raise RedisClusterException(
2086
+ "At least a command with a key is needed to identify a node"
2087
+ )
2088
+
2089
+ node: ClusterNode = self._pipe.cluster_client.nodes_manager.get_node_from_slot(
2090
+ list(self._pipeline_slots)[0], False
2091
+ )
2092
+ self._transaction_node = node
2093
+
2094
+ if not self._transaction_connection:
2095
+ connection: Connection = self._transaction_node.acquire_connection()
2096
+ self._transaction_connection = connection
2097
+
2098
+ return self._transaction_node, self._transaction_connection
2099
+
2100
+ def execute_command(self, *args: Union[KeyT, EncodableT], **kwargs: Any) -> "Any":
2101
+ # Given the limitation of ClusterPipeline sync API, we have to run it in thread.
2102
+ response = None
2103
+ error = None
2104
+
2105
+ def runner():
2106
+ nonlocal response
2107
+ nonlocal error
2108
+ try:
2109
+ response = asyncio.run(self._execute_command(*args, **kwargs))
2110
+ except Exception as e:
2111
+ error = e
2112
+
2113
+ thread = threading.Thread(target=runner)
2114
+ thread.start()
2115
+ thread.join()
2116
+
2117
+ if error:
2118
+ raise error
2119
+
2120
+ return response
2121
+
2122
+ async def _execute_command(
2123
+ self, *args: Union[KeyT, EncodableT], **kwargs: Any
2124
+ ) -> Any:
2125
+ if self._pipe.cluster_client._initialize:
2126
+ await self._pipe.cluster_client.initialize()
2127
+
2128
+ slot_number: Optional[int] = None
2129
+ if args[0] not in self.NO_SLOTS_COMMANDS:
2130
+ slot_number = await self._pipe.cluster_client._determine_slot(*args)
2131
+
2132
+ if (
2133
+ self._watching or args[0] in self.IMMEDIATE_EXECUTE_COMMANDS
2134
+ ) and not self._explicit_transaction:
2135
+ if args[0] == "WATCH":
2136
+ self._validate_watch()
2137
+
2138
+ if slot_number is not None:
2139
+ if self._pipeline_slots and slot_number not in self._pipeline_slots:
2140
+ raise CrossSlotTransactionError(
2141
+ "Cannot watch or send commands on different slots"
2142
+ )
2143
+
2144
+ self._pipeline_slots.add(slot_number)
2145
+ elif args[0] not in self.NO_SLOTS_COMMANDS:
2146
+ raise RedisClusterException(
2147
+ f"Cannot identify slot number for command: {args[0]},"
2148
+ "it cannot be triggered in a transaction"
2149
+ )
2150
+
2151
+ return self._immediate_execute_command(*args, **kwargs)
2152
+ else:
2153
+ if slot_number is not None:
2154
+ self._pipeline_slots.add(slot_number)
2155
+
2156
+ return super().execute_command(*args, **kwargs)
2157
+
2158
+ def _validate_watch(self):
2159
+ if self._explicit_transaction:
2160
+ raise RedisError("Cannot issue a WATCH after a MULTI")
2161
+
2162
+ self._watching = True
2163
+
2164
+ async def _immediate_execute_command(self, *args, **options):
2165
+ return await self._retry.call_with_retry(
2166
+ lambda: self._get_connection_and_send_command(*args, **options),
2167
+ self._reinitialize_on_error,
2168
+ )
2169
+
2170
+ async def _get_connection_and_send_command(self, *args, **options):
2171
+ redis_node, connection = self._get_client_and_connection_for_transaction()
2172
+ return await self._send_command_parse_response(
2173
+ connection, redis_node, args[0], *args, **options
2174
+ )
2175
+
2176
+ async def _send_command_parse_response(
2177
+ self,
2178
+ connection: Connection,
2179
+ redis_node: ClusterNode,
2180
+ command_name,
2181
+ *args,
2182
+ **options,
2183
+ ):
2184
+ """
2185
+ Send a command and parse the response
2186
+ """
2187
+
2188
+ await connection.send_command(*args)
2189
+ output = await redis_node.parse_response(connection, command_name, **options)
2190
+
2191
+ if command_name in self.UNWATCH_COMMANDS:
2192
+ self._watching = False
2193
+ return output
2194
+
2195
+ async def _reinitialize_on_error(self, error):
2196
+ if self._watching:
2197
+ if type(error) in self.SLOT_REDIRECT_ERRORS and self._executing:
2198
+ raise WatchError("Slot rebalancing occurred while watching keys")
2199
+
2200
+ if (
2201
+ type(error) in self.SLOT_REDIRECT_ERRORS
2202
+ or type(error) in self.CONNECTION_ERRORS
2203
+ ):
2204
+ if self._transaction_connection:
2205
+ self._transaction_connection = None
2206
+
2207
+ self._pipe.cluster_client.reinitialize_counter += 1
2208
+ if (
2209
+ self._pipe.cluster_client.reinitialize_steps
2210
+ and self._pipe.cluster_client.reinitialize_counter
2211
+ % self._pipe.cluster_client.reinitialize_steps
2212
+ == 0
2213
+ ):
2214
+ await self._pipe.cluster_client.nodes_manager.initialize()
2215
+ self.reinitialize_counter = 0
2216
+ else:
2217
+ self._pipe.cluster_client.nodes_manager.update_moved_exception(error)
2218
+
2219
+ self._executing = False
2220
+
2221
+ def _raise_first_error(self, responses, stack):
2222
+ """
2223
+ Raise the first exception on the stack
2224
+ """
2225
+ for r, cmd in zip(responses, stack):
2226
+ if isinstance(r, Exception):
2227
+ self._annotate_exception(r, cmd.position + 1, cmd.args)
2228
+ raise r
1694
2229
 
1695
2230
  def mset_nonatomic(
1696
2231
  self, mapping: Mapping[AnyKeyT, EncodableT]
1697
2232
  ) -> "ClusterPipeline":
1698
- encoder = self._client.encoder
2233
+ raise NotImplementedError("Method is not supported in transactional context.")
1699
2234
 
1700
- slots_pairs = {}
1701
- for pair in mapping.items():
1702
- slot = key_slot(encoder.encode(pair[0]))
1703
- slots_pairs.setdefault(slot, []).extend(pair)
2235
+ async def execute(
2236
+ self, raise_on_error: bool = True, allow_redirections: bool = True
2237
+ ) -> List[Any]:
2238
+ stack = self._command_queue
2239
+ if not stack and (not self._watching or not self._pipeline_slots):
2240
+ return []
1704
2241
 
1705
- for pairs in slots_pairs.values():
1706
- self.execute_command("MSET", *pairs)
2242
+ return await self._execute_transaction_with_retries(stack, raise_on_error)
1707
2243
 
1708
- return self
2244
+ async def _execute_transaction_with_retries(
2245
+ self, stack: List["PipelineCommand"], raise_on_error: bool
2246
+ ):
2247
+ return await self._retry.call_with_retry(
2248
+ lambda: self._execute_transaction(stack, raise_on_error),
2249
+ self._reinitialize_on_error,
2250
+ )
1709
2251
 
2252
+ async def _execute_transaction(
2253
+ self, stack: List["PipelineCommand"], raise_on_error: bool
2254
+ ):
2255
+ if len(self._pipeline_slots) > 1:
2256
+ raise CrossSlotTransactionError(
2257
+ "All keys involved in a cluster transaction must map to the same slot"
2258
+ )
1710
2259
 
1711
- for command in PIPELINE_BLOCKED_COMMANDS:
1712
- command = command.replace(" ", "_").lower()
1713
- if command == "mset_nonatomic":
1714
- continue
2260
+ self._executing = True
1715
2261
 
1716
- setattr(ClusterPipeline, command, block_pipeline_command(command))
2262
+ redis_node, connection = self._get_client_and_connection_for_transaction()
1717
2263
 
2264
+ stack = chain(
2265
+ [PipelineCommand(0, "MULTI")],
2266
+ stack,
2267
+ [PipelineCommand(0, "EXEC")],
2268
+ )
2269
+ commands = [c.args for c in stack if EMPTY_RESPONSE not in c.kwargs]
2270
+ packed_commands = connection.pack_commands(commands)
2271
+ await connection.send_packed_command(packed_commands)
2272
+ errors = []
2273
+
2274
+ # parse off the response for MULTI
2275
+ # NOTE: we need to handle ResponseErrors here and continue
2276
+ # so that we read all the additional command messages from
2277
+ # the socket
2278
+ try:
2279
+ await redis_node.parse_response(connection, "MULTI")
2280
+ except ResponseError as e:
2281
+ self._annotate_exception(e, 0, "MULTI")
2282
+ errors.append(e)
2283
+ except self.CONNECTION_ERRORS as cluster_error:
2284
+ self._annotate_exception(cluster_error, 0, "MULTI")
2285
+ raise
1718
2286
 
1719
- class PipelineCommand:
1720
- def __init__(self, position: int, *args: Any, **kwargs: Any) -> None:
1721
- self.args = args
1722
- self.kwargs = kwargs
1723
- self.position = position
1724
- self.result: Union[Any, Exception] = None
2287
+ # and all the other commands
2288
+ for i, command in enumerate(self._command_queue):
2289
+ if EMPTY_RESPONSE in command.kwargs:
2290
+ errors.append((i, command.kwargs[EMPTY_RESPONSE]))
2291
+ else:
2292
+ try:
2293
+ _ = await redis_node.parse_response(connection, "_")
2294
+ except self.SLOT_REDIRECT_ERRORS as slot_error:
2295
+ self._annotate_exception(slot_error, i + 1, command.args)
2296
+ errors.append(slot_error)
2297
+ except self.CONNECTION_ERRORS as cluster_error:
2298
+ self._annotate_exception(cluster_error, i + 1, command.args)
2299
+ raise
2300
+ except ResponseError as e:
2301
+ self._annotate_exception(e, i + 1, command.args)
2302
+ errors.append(e)
2303
+
2304
+ response = None
2305
+ # parse the EXEC.
2306
+ try:
2307
+ response = await redis_node.parse_response(connection, "EXEC")
2308
+ except ExecAbortError:
2309
+ if errors:
2310
+ raise errors[0]
2311
+ raise
1725
2312
 
1726
- def __repr__(self) -> str:
1727
- return f"[{self.position}] {self.args} ({self.kwargs})"
2313
+ self._executing = False
2314
+
2315
+ # EXEC clears any watched keys
2316
+ self._watching = False
2317
+
2318
+ if response is None:
2319
+ raise WatchError("Watched variable changed.")
2320
+
2321
+ # put any parse errors into the response
2322
+ for i, e in errors:
2323
+ response.insert(i, e)
2324
+
2325
+ if len(response) != len(self._command_queue):
2326
+ raise InvalidPipelineStack(
2327
+ "Unexpected response length for cluster pipeline EXEC."
2328
+ " Command stack was {} but response had length {}".format(
2329
+ [c.args[0] for c in self._command_queue], len(response)
2330
+ )
2331
+ )
2332
+
2333
+ # find any errors in the response and raise if necessary
2334
+ if raise_on_error or len(errors) > 0:
2335
+ self._raise_first_error(
2336
+ response,
2337
+ self._command_queue,
2338
+ )
2339
+
2340
+ # We have to run response callbacks manually
2341
+ data = []
2342
+ for r, cmd in zip(response, self._command_queue):
2343
+ if not isinstance(r, Exception):
2344
+ command_name = cmd.args[0]
2345
+ if command_name in self._pipe.cluster_client.response_callbacks:
2346
+ r = self._pipe.cluster_client.response_callbacks[command_name](
2347
+ r, **cmd.kwargs
2348
+ )
2349
+ data.append(r)
2350
+ return data
2351
+
2352
+ async def reset(self):
2353
+ self._command_queue = []
2354
+
2355
+ # make sure to reset the connection state in the event that we were
2356
+ # watching something
2357
+ if self._transaction_connection:
2358
+ try:
2359
+ if self._watching:
2360
+ # call this manually since our unwatch or
2361
+ # immediate_execute_command methods can call reset()
2362
+ await self._transaction_connection.send_command("UNWATCH")
2363
+ await self._transaction_connection.read_response()
2364
+ # we can safely return the connection to the pool here since we're
2365
+ # sure we're no longer WATCHing anything
2366
+ self._transaction_node.release(self._transaction_connection)
2367
+ self._transaction_connection = None
2368
+ except self.CONNECTION_ERRORS:
2369
+ # disconnect will also remove any previous WATCHes
2370
+ if self._transaction_connection:
2371
+ await self._transaction_connection.disconnect()
2372
+
2373
+ # clean up the other instance attributes
2374
+ self._transaction_node = None
2375
+ self._watching = False
2376
+ self._explicit_transaction = False
2377
+ self._pipeline_slots = set()
2378
+ self._executing = False
2379
+
2380
+ def multi(self):
2381
+ if self._explicit_transaction:
2382
+ raise RedisError("Cannot issue nested calls to MULTI")
2383
+ if self._command_queue:
2384
+ raise RedisError(
2385
+ "Commands without an initial WATCH have already been issued"
2386
+ )
2387
+ self._explicit_transaction = True
2388
+
2389
+ async def watch(self, *names):
2390
+ if self._explicit_transaction:
2391
+ raise RedisError("Cannot issue a WATCH after a MULTI")
2392
+
2393
+ return await self.execute_command("WATCH", *names)
2394
+
2395
+ async def unwatch(self):
2396
+ if self._watching:
2397
+ return await self.execute_command("UNWATCH")
2398
+
2399
+ return True
2400
+
2401
+ async def discard(self):
2402
+ await self.reset()
2403
+
2404
+ async def unlink(self, *names):
2405
+ return self.execute_command("UNLINK", *names)