redis 6.1.0__py3-none-any.whl → 6.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
redis/asyncio/cluster.py CHANGED
@@ -2,16 +2,23 @@ import asyncio
2
2
  import collections
3
3
  import random
4
4
  import socket
5
+ import threading
6
+ import time
5
7
  import warnings
8
+ from abc import ABC, abstractmethod
9
+ from copy import copy
10
+ from itertools import chain
6
11
  from typing import (
7
12
  Any,
8
13
  Callable,
14
+ Coroutine,
9
15
  Deque,
10
16
  Dict,
11
17
  Generator,
12
18
  List,
13
19
  Mapping,
14
20
  Optional,
21
+ Set,
15
22
  Tuple,
16
23
  Type,
17
24
  TypeVar,
@@ -53,7 +60,10 @@ from redis.exceptions import (
53
60
  ClusterDownError,
54
61
  ClusterError,
55
62
  ConnectionError,
63
+ CrossSlotTransactionError,
56
64
  DataError,
65
+ ExecAbortError,
66
+ InvalidPipelineStack,
57
67
  MaxConnectionsError,
58
68
  MovedError,
59
69
  RedisClusterException,
@@ -62,6 +72,7 @@ from redis.exceptions import (
62
72
  SlotNotCoveredError,
63
73
  TimeoutError,
64
74
  TryAgainError,
75
+ WatchError,
65
76
  )
66
77
  from redis.typing import AnyKeyT, EncodableT, KeyT
67
78
  from redis.utils import (
@@ -134,6 +145,14 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand
134
145
  | Enable read from replicas in READONLY mode and defines the load balancing
135
146
  strategy that will be used for cluster node selection.
136
147
  The data read from replicas is eventually consistent with the data in primary nodes.
148
+ :param dynamic_startup_nodes:
149
+ | Set the RedisCluster's startup nodes to all the discovered nodes.
150
+ If true (default value), the cluster's discovered nodes will be used to
151
+ determine the cluster nodes-slots mapping in the next topology refresh.
152
+ It will remove the initial passed startup nodes if their endpoints aren't
153
+ listed in the CLUSTER SLOTS output.
154
+ If you use dynamic DNS endpoints for startup nodes but CLUSTER SLOTS lists
155
+ specific IP addresses, it is best to set it to false.
137
156
  :param reinitialize_steps:
138
157
  | Specifies the number of MOVED errors that need to occur before reinitializing
139
158
  the whole cluster topology. If a MOVED error occurs and the cluster does not
@@ -250,6 +269,7 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand
250
269
  require_full_coverage: bool = True,
251
270
  read_from_replicas: bool = False,
252
271
  load_balancing_strategy: Optional[LoadBalancingStrategy] = None,
272
+ dynamic_startup_nodes: bool = True,
253
273
  reinitialize_steps: int = 5,
254
274
  cluster_error_retry_attempts: int = 3,
255
275
  max_connections: int = 2**31,
@@ -388,6 +408,7 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand
388
408
  startup_nodes,
389
409
  require_full_coverage,
390
410
  kwargs,
411
+ dynamic_startup_nodes=dynamic_startup_nodes,
391
412
  address_remap=address_remap,
392
413
  event_dispatcher=self._event_dispatcher,
393
414
  )
@@ -860,10 +881,7 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand
860
881
  if shard_hint:
861
882
  raise RedisClusterException("shard_hint is deprecated in cluster mode")
862
883
 
863
- if transaction:
864
- raise RedisClusterException("transaction is deprecated in cluster mode")
865
-
866
- return ClusterPipeline(self)
884
+ return ClusterPipeline(self, transaction)
867
885
 
868
886
  def lock(
869
887
  self,
@@ -946,6 +964,30 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand
946
964
  raise_on_release_error=raise_on_release_error,
947
965
  )
948
966
 
967
+ async def transaction(
968
+ self, func: Coroutine[None, "ClusterPipeline", Any], *watches, **kwargs
969
+ ):
970
+ """
971
+ Convenience method for executing the callable `func` as a transaction
972
+ while watching all keys specified in `watches`. The 'func' callable
973
+ should expect a single argument which is a Pipeline object.
974
+ """
975
+ shard_hint = kwargs.pop("shard_hint", None)
976
+ value_from_callable = kwargs.pop("value_from_callable", False)
977
+ watch_delay = kwargs.pop("watch_delay", None)
978
+ async with self.pipeline(True, shard_hint) as pipe:
979
+ while True:
980
+ try:
981
+ if watches:
982
+ await pipe.watch(*watches)
983
+ func_value = await func(pipe)
984
+ exec_value = await pipe.execute()
985
+ return func_value if value_from_callable else exec_value
986
+ except WatchError:
987
+ if watch_delay is not None and watch_delay > 0:
988
+ time.sleep(watch_delay)
989
+ continue
990
+
949
991
 
950
992
  class ClusterNode:
951
993
  """
@@ -1067,6 +1109,12 @@ class ClusterNode:
1067
1109
 
1068
1110
  raise MaxConnectionsError()
1069
1111
 
1112
+ def release(self, connection: Connection) -> None:
1113
+ """
1114
+ Release connection back to free queue.
1115
+ """
1116
+ self._free.append(connection)
1117
+
1070
1118
  async def parse_response(
1071
1119
  self, connection: Connection, command: str, **kwargs: Any
1072
1120
  ) -> Any:
@@ -1162,6 +1210,7 @@ class ClusterNode:
1162
1210
 
1163
1211
  class NodesManager:
1164
1212
  __slots__ = (
1213
+ "_dynamic_startup_nodes",
1165
1214
  "_moved_exception",
1166
1215
  "_event_dispatcher",
1167
1216
  "connection_kwargs",
@@ -1179,6 +1228,7 @@ class NodesManager:
1179
1228
  startup_nodes: List["ClusterNode"],
1180
1229
  require_full_coverage: bool,
1181
1230
  connection_kwargs: Dict[str, Any],
1231
+ dynamic_startup_nodes: bool = True,
1182
1232
  address_remap: Optional[Callable[[Tuple[str, int]], Tuple[str, int]]] = None,
1183
1233
  event_dispatcher: Optional[EventDispatcher] = None,
1184
1234
  ) -> None:
@@ -1191,6 +1241,8 @@ class NodesManager:
1191
1241
  self.nodes_cache: Dict[str, "ClusterNode"] = {}
1192
1242
  self.slots_cache: Dict[int, List["ClusterNode"]] = {}
1193
1243
  self.read_load_balancer = LoadBalancer()
1244
+
1245
+ self._dynamic_startup_nodes: bool = dynamic_startup_nodes
1194
1246
  self._moved_exception: MovedError = None
1195
1247
  if event_dispatcher is None:
1196
1248
  self._event_dispatcher = EventDispatcher()
@@ -1233,6 +1285,9 @@ class NodesManager:
1233
1285
  task = asyncio.create_task(old[name].disconnect()) # noqa
1234
1286
  old[name] = node
1235
1287
 
1288
+ def update_moved_exception(self, exception):
1289
+ self._moved_exception = exception
1290
+
1236
1291
  def _update_moved_slots(self) -> None:
1237
1292
  e = self._moved_exception
1238
1293
  redirected_node = self.get_node(host=e.host, port=e.port)
@@ -1433,8 +1488,10 @@ class NodesManager:
1433
1488
  # Set the tmp variables to the real variables
1434
1489
  self.slots_cache = tmp_slots
1435
1490
  self.set_nodes(self.nodes_cache, tmp_nodes_cache, remove_old=True)
1436
- # Populate the startup nodes with all discovered nodes
1437
- self.set_nodes(self.startup_nodes, self.nodes_cache, remove_old=True)
1491
+
1492
+ if self._dynamic_startup_nodes:
1493
+ # Populate the startup nodes with all discovered nodes
1494
+ self.set_nodes(self.startup_nodes, self.nodes_cache, remove_old=True)
1438
1495
 
1439
1496
  # Set the default node
1440
1497
  self.default_node = self.get_nodes_by_server_type(PRIMARY)[0]
@@ -1498,41 +1555,47 @@ class ClusterPipeline(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterComm
1498
1555
  | Existing :class:`~.RedisCluster` client
1499
1556
  """
1500
1557
 
1501
- __slots__ = ("_command_stack", "_client")
1502
-
1503
- def __init__(self, client: RedisCluster) -> None:
1504
- self._client = client
1558
+ __slots__ = ("cluster_client", "_transaction", "_execution_strategy")
1505
1559
 
1506
- self._command_stack: List["PipelineCommand"] = []
1560
+ def __init__(
1561
+ self, client: RedisCluster, transaction: Optional[bool] = None
1562
+ ) -> None:
1563
+ self.cluster_client = client
1564
+ self._transaction = transaction
1565
+ self._execution_strategy: ExecutionStrategy = (
1566
+ PipelineStrategy(self)
1567
+ if not self._transaction
1568
+ else TransactionStrategy(self)
1569
+ )
1507
1570
 
1508
1571
  async def initialize(self) -> "ClusterPipeline":
1509
- if self._client._initialize:
1510
- await self._client.initialize()
1511
- self._command_stack = []
1572
+ await self._execution_strategy.initialize()
1512
1573
  return self
1513
1574
 
1514
1575
  async def __aenter__(self) -> "ClusterPipeline":
1515
1576
  return await self.initialize()
1516
1577
 
1517
1578
  async def __aexit__(self, exc_type: None, exc_value: None, traceback: None) -> None:
1518
- self._command_stack = []
1579
+ await self.reset()
1519
1580
 
1520
1581
  def __await__(self) -> Generator[Any, None, "ClusterPipeline"]:
1521
1582
  return self.initialize().__await__()
1522
1583
 
1523
1584
  def __enter__(self) -> "ClusterPipeline":
1524
- self._command_stack = []
1585
+ # TODO: Remove this method before 7.0.0
1586
+ self._execution_strategy._command_queue = []
1525
1587
  return self
1526
1588
 
1527
1589
  def __exit__(self, exc_type: None, exc_value: None, traceback: None) -> None:
1528
- self._command_stack = []
1590
+ # TODO: Remove this method before 7.0.0
1591
+ self._execution_strategy._command_queue = []
1529
1592
 
1530
1593
  def __bool__(self) -> bool:
1531
1594
  "Pipeline instances should always evaluate to True on Python 3+"
1532
1595
  return True
1533
1596
 
1534
1597
  def __len__(self) -> int:
1535
- return len(self._command_stack)
1598
+ return len(self._execution_strategy)
1536
1599
 
1537
1600
  def execute_command(
1538
1601
  self, *args: Union[KeyT, EncodableT], **kwargs: Any
@@ -1548,10 +1611,7 @@ class ClusterPipeline(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterComm
1548
1611
  or List[:class:`~.ClusterNode`] or Dict[Any, :class:`~.ClusterNode`]
1549
1612
  - Rest of the kwargs are passed to the Redis connection
1550
1613
  """
1551
- self._command_stack.append(
1552
- PipelineCommand(len(self._command_stack), *args, **kwargs)
1553
- )
1554
- return self
1614
+ return self._execution_strategy.execute_command(*args, **kwargs)
1555
1615
 
1556
1616
  async def execute(
1557
1617
  self, raise_on_error: bool = True, allow_redirections: bool = True
@@ -1571,34 +1631,294 @@ class ClusterPipeline(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterComm
1571
1631
  :raises RedisClusterException: if target_nodes is not provided & the command
1572
1632
  can't be mapped to a slot
1573
1633
  """
1574
- if not self._command_stack:
1634
+ try:
1635
+ return await self._execution_strategy.execute(
1636
+ raise_on_error, allow_redirections
1637
+ )
1638
+ finally:
1639
+ await self.reset()
1640
+
1641
+ def _split_command_across_slots(
1642
+ self, command: str, *keys: KeyT
1643
+ ) -> "ClusterPipeline":
1644
+ for slot_keys in self.cluster_client._partition_keys_by_slot(keys).values():
1645
+ self.execute_command(command, *slot_keys)
1646
+
1647
+ return self
1648
+
1649
+ async def reset(self):
1650
+ """
1651
+ Reset back to empty pipeline.
1652
+ """
1653
+ await self._execution_strategy.reset()
1654
+
1655
+ def multi(self):
1656
+ """
1657
+ Start a transactional block of the pipeline after WATCH commands
1658
+ are issued. End the transactional block with `execute`.
1659
+ """
1660
+ self._execution_strategy.multi()
1661
+
1662
+ async def discard(self):
1663
+ """ """
1664
+ await self._execution_strategy.discard()
1665
+
1666
+ async def watch(self, *names):
1667
+ """Watches the values at keys ``names``"""
1668
+ await self._execution_strategy.watch(*names)
1669
+
1670
+ async def unwatch(self):
1671
+ """Unwatches all previously specified keys"""
1672
+ await self._execution_strategy.unwatch()
1673
+
1674
+ async def unlink(self, *names):
1675
+ await self._execution_strategy.unlink(*names)
1676
+
1677
+ def mset_nonatomic(
1678
+ self, mapping: Mapping[AnyKeyT, EncodableT]
1679
+ ) -> "ClusterPipeline":
1680
+ return self._execution_strategy.mset_nonatomic(mapping)
1681
+
1682
+
# Replace each command listed in PIPELINE_BLOCKED_COMMANDS with the stub built by
# block_pipeline_command(); mset_nonatomic is exempt because ClusterPipeline
# defines its own implementation.
for command in PIPELINE_BLOCKED_COMMANDS:
    command = command.replace(" ", "_").lower()
    if command != "mset_nonatomic":
        setattr(ClusterPipeline, command, block_pipeline_command(command))
1689
+
1690
+
class PipelineCommand:
    """
    A single command queued in a cluster pipeline.

    Stores the positional ``args`` and keyword ``kwargs`` of the command, its
    ``position`` in the queue, and a ``result`` slot (``None`` until the
    execution code assigns it).
    """

    def __init__(self, position: int, *args: Any, **kwargs: Any) -> None:
        self.position = position
        self.args = args
        self.kwargs = kwargs
        self.result: Union[Any, Exception] = None

    def __repr__(self) -> str:
        return "[{}] {} ({})".format(self.position, self.args, self.kwargs)
1700
+
1701
+
class ExecutionStrategy(ABC):
    """
    Interface implemented by the objects that actually run a
    :class:`ClusterPipeline`'s commands (plain pipelining or MULTI/EXEC).
    Every public pipeline operation is forwarded to one of these methods.
    """

    @abstractmethod
    async def initialize(self) -> "ClusterPipeline":
        """
        Prepare the strategy for use and return the owning pipeline.

        Mirrors ClusterPipeline.initialize().
        """

    @abstractmethod
    def execute_command(
        self, *args: Union[KeyT, EncodableT], **kwargs: Any
    ) -> "ClusterPipeline":
        """
        Add a raw command to the pipeline.

        Mirrors ClusterPipeline.execute_command().
        """

    @abstractmethod
    async def execute(
        self, raise_on_error: bool = True, allow_redirections: bool = True
    ) -> List[Any]:
        """
        Run the queued commands and return their results.

        Retries according to the configured :attr:`retry` policy before an
        exception is raised. Mirrors ClusterPipeline.execute().
        """

    @abstractmethod
    def mset_nonatomic(
        self, mapping: Mapping[AnyKeyT, EncodableT]
    ) -> "ClusterPipeline":
        """
        Queue one MSET per slot for the key/value pairs in ``mapping``.

        Mirrors ClusterPipeline.mset_nonatomic().
        """

    @abstractmethod
    async def reset(self):
        """
        Return the strategy to its initial, empty state.

        Mirrors ClusterPipeline.reset().
        """

    @abstractmethod
    def multi(self):
        """
        Begin the transactional block.

        Mirrors ClusterPipeline.multi().
        """

    @abstractmethod
    async def watch(self, *names):
        """
        Watch the keys ``names``.

        Mirrors ClusterPipeline.watch().
        """

    @abstractmethod
    async def unwatch(self):
        """
        Stop watching every previously watched key.

        Mirrors ClusterPipeline.unwatch().
        """

    @abstractmethod
    async def discard(self):
        """
        Abandon the pending transaction.

        Mirrors ClusterPipeline.discard().
        """

    @abstractmethod
    async def unlink(self, *names):
        """
        Unlink the keys ``names``.

        Mirrors ClusterPipeline.unlink().
        """

    @abstractmethod
    def __len__(self) -> int:
        """Number of commands currently held by the strategy."""
1800
+
1801
+
class AbstractStrategy(ExecutionStrategy):
    """
    Common base for the concrete execution strategies.

    Keeps the owning :class:`ClusterPipeline` and the queue of
    :class:`PipelineCommand` objects, and implements initialization, queuing,
    exception annotation and ``len()``. Everything else is left to subclasses.
    """

    def __init__(self, pipe: ClusterPipeline) -> None:
        self._pipe: ClusterPipeline = pipe
        self._command_queue: List["PipelineCommand"] = []

    async def initialize(self) -> "ClusterPipeline":
        # Initialize the cluster client lazily, then start from an empty queue.
        client = self._pipe.cluster_client
        if client._initialize:
            await client.initialize()
        self._command_queue = []
        return self._pipe

    def execute_command(
        self, *args: Union[KeyT, EncodableT], **kwargs: Any
    ) -> "ClusterPipeline":
        # Each command remembers its index in the queue.
        queued = PipelineCommand(len(self._command_queue), *args, **kwargs)
        self._command_queue.append(queued)
        return self._pipe

    def _annotate_exception(self, exception, number, command):
        """
        Rewrite the first argument of ``exception`` so the message names the
        failing command's number and its (truncated) text.
        """
        rendered = " ".join(safe_str(part) for part in command)
        original_message = exception.args[0]
        exception.args = (
            f"Command # {number} ({truncate_text(rendered)}) of pipeline "
            f"caused error: {original_message}",
        ) + exception.args[1:]

    @abstractmethod
    def mset_nonatomic(
        self, mapping: Mapping[AnyKeyT, EncodableT]
    ) -> "ClusterPipeline":
        ...

    @abstractmethod
    async def execute(
        self, raise_on_error: bool = True, allow_redirections: bool = True
    ) -> List[Any]:
        ...

    @abstractmethod
    async def reset(self):
        ...

    @abstractmethod
    def multi(self):
        ...

    @abstractmethod
    async def watch(self, *names):
        ...

    @abstractmethod
    async def unwatch(self):
        ...

    @abstractmethod
    async def discard(self):
        ...

    @abstractmethod
    async def unlink(self, *names):
        ...

    def __len__(self) -> int:
        return len(self._command_queue)
1870
+
1871
+
1872
+ class PipelineStrategy(AbstractStrategy):
1873
+ def __init__(self, pipe: ClusterPipeline) -> None:
1874
+ super().__init__(pipe)
1875
+
1876
+ def mset_nonatomic(
1877
+ self, mapping: Mapping[AnyKeyT, EncodableT]
1878
+ ) -> "ClusterPipeline":
1879
+ encoder = self._pipe.cluster_client.encoder
1880
+
1881
+ slots_pairs = {}
1882
+ for pair in mapping.items():
1883
+ slot = key_slot(encoder.encode(pair[0]))
1884
+ slots_pairs.setdefault(slot, []).extend(pair)
1885
+
1886
+ for pairs in slots_pairs.values():
1887
+ self.execute_command("MSET", *pairs)
1888
+
1889
+ return self._pipe
1890
+
1891
+ async def execute(
1892
+ self, raise_on_error: bool = True, allow_redirections: bool = True
1893
+ ) -> List[Any]:
1894
+ if not self._command_queue:
1575
1895
  return []
1576
1896
 
1577
1897
  try:
1578
- retry_attempts = self._client.retry.get_retries()
1898
+ retry_attempts = self._pipe.cluster_client.retry.get_retries()
1579
1899
  while True:
1580
1900
  try:
1581
- if self._client._initialize:
1582
- await self._client.initialize()
1901
+ if self._pipe.cluster_client._initialize:
1902
+ await self._pipe.cluster_client.initialize()
1583
1903
  return await self._execute(
1584
- self._client,
1585
- self._command_stack,
1904
+ self._pipe.cluster_client,
1905
+ self._command_queue,
1586
1906
  raise_on_error=raise_on_error,
1587
1907
  allow_redirections=allow_redirections,
1588
1908
  )
1589
1909
 
1590
- except self.__class__.ERRORS_ALLOW_RETRY as e:
1910
+ except RedisCluster.ERRORS_ALLOW_RETRY as e:
1591
1911
  if retry_attempts > 0:
1592
1912
  # Try again with the new cluster setup. All other errors
1593
1913
  # should be raised.
1594
1914
  retry_attempts -= 1
1595
- await self._client.aclose()
1915
+ await self._pipe.cluster_client.aclose()
1596
1916
  await asyncio.sleep(0.25)
1597
1917
  else:
1598
1918
  # All other errors should be raised.
1599
1919
  raise e
1600
1920
  finally:
1601
- self._command_stack = []
1921
+ await self.reset()
1602
1922
 
1603
1923
  async def _execute(
1604
1924
  self,
@@ -1678,50 +1998,401 @@ class ClusterPipeline(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterComm
1678
1998
  for cmd in default_node[1]:
1679
1999
  # Check if it has a command that failed with a relevant
1680
2000
  # exception
1681
- if type(cmd.result) in self.__class__.ERRORS_ALLOW_RETRY:
2001
+ if type(cmd.result) in RedisCluster.ERRORS_ALLOW_RETRY:
1682
2002
  client.replace_default_node()
1683
2003
  break
1684
2004
 
1685
2005
  return [cmd.result for cmd in stack]
1686
2006
 
1687
- def _split_command_across_slots(
1688
- self, command: str, *keys: KeyT
1689
- ) -> "ClusterPipeline":
1690
- for slot_keys in self._client._partition_keys_by_slot(keys).values():
1691
- self.execute_command(command, *slot_keys)
2007
+ async def reset(self):
2008
+ """
2009
+ Reset back to empty pipeline.
2010
+ """
2011
+ self._command_queue = []
1692
2012
 
1693
- return self
2013
+ def multi(self):
2014
+ raise RedisClusterException(
2015
+ "method multi() is not supported outside of transactional context"
2016
+ )
2017
+
2018
+ async def watch(self, *names):
2019
+ raise RedisClusterException(
2020
+ "method watch() is not supported outside of transactional context"
2021
+ )
2022
+
2023
+ async def unwatch(self):
2024
+ raise RedisClusterException(
2025
+ "method unwatch() is not supported outside of transactional context"
2026
+ )
2027
+
2028
+ async def discard(self):
2029
+ raise RedisClusterException(
2030
+ "method discard() is not supported outside of transactional context"
2031
+ )
2032
+
2033
+ async def unlink(self, *names):
2034
+ if len(names) != 1:
2035
+ raise RedisClusterException(
2036
+ "unlinking multiple keys is not implemented in pipeline command"
2037
+ )
2038
+
2039
+ return self.execute_command("UNLINK", names[0])
2040
+
2041
+
class TransactionStrategy(AbstractStrategy):
    """
    Execution strategy for transactional cluster pipelines (MULTI/EXEC).

    Every key used by the transaction, including WATCHed keys, must hash to the
    same slot. One connection to the node owning that slot is pinned and used for
    immediate commands (WATCH/UNWATCH, or commands issued while watching before
    ``multi()``) and for the final ``MULTI ... EXEC`` batch.
    """

    # Commands that carry no key, hence no slot.
    NO_SLOTS_COMMANDS = {"UNWATCH"}
    # Commands sent right away instead of being queued.
    IMMEDIATE_EXECUTE_COMMANDS = {"WATCH", "UNWATCH"}
    # Commands after which the connection no longer watches anything.
    UNWATCH_COMMANDS = {"DISCARD", "EXEC", "UNWATCH"}
    SLOT_REDIRECT_ERRORS = (AskError, MovedError)
    CONNECTION_ERRORS = (
        ConnectionError,
        OSError,
        ClusterDownError,
        SlotNotCoveredError,
    )

    def __init__(self, pipe: ClusterPipeline) -> None:
        super().__init__(pipe)
        self._explicit_transaction = False
        self._watching = False
        self._pipeline_slots: Set[int] = set()
        self._transaction_node: Optional[ClusterNode] = None
        self._transaction_connection: Optional[Connection] = None
        self._executing = False
        # Work on a copy of the client's retry policy so slot redirects can be
        # made retryable for this pipeline.
        self._retry = copy(self._pipe.cluster_client.retry)
        self._retry.update_supported_errors(
            RedisCluster.ERRORS_ALLOW_RETRY + self.SLOT_REDIRECT_ERRORS
        )

    def _get_client_and_connection_for_transaction(
        self,
    ) -> Tuple[ClusterNode, Connection]:
        """
        Find a connection for a pipeline transaction.

        For running an atomic transaction, watch keys ensure that contents have not
        been altered as long as the watch commands for those keys were sent over
        the same connection. So once we start watching a key, we fetch a
        connection to the node that owns that slot and reuse it.

        The node is resolved only when a new connection is acquired, so the pinned
        connection is always released back to the node it came from.

        :raises RedisClusterException: if no keyed command has been issued yet.
        """
        if not self._pipeline_slots:
            raise RedisClusterException(
                "At least a command with a key is needed to identify a node"
            )

        if not self._transaction_connection:
            node: ClusterNode = (
                self._pipe.cluster_client.nodes_manager.get_node_from_slot(
                    list(self._pipeline_slots)[0], False
                )
            )
            self._transaction_node = node
            self._transaction_connection = node.acquire_connection()

        return self._transaction_node, self._transaction_connection

    async def _release_transaction_connection(self, disconnect: bool = False) -> None:
        """
        Hand the pinned connection (if any) back to its node and forget it.

        :param disconnect: close the connection first; used when its state
            (pending replies, WATCHes) can no longer be trusted.
        """
        connection = self._transaction_connection
        if connection is None:
            return
        self._transaction_connection = None
        if disconnect:
            await connection.disconnect()
        if self._transaction_node is not None:
            self._transaction_node.release(connection)

    def execute_command(self, *args: Union[KeyT, EncodableT], **kwargs: Any) -> "Any":
        """
        Queue or immediately run a command.

        Returns the owning pipeline when the command is queued, or an
        *un-awaited* coroutine when it must be executed immediately (see
        ``_execute_command``); callers must await the latter.
        """
        # Given the limitation of ClusterPipeline sync API, we have to run it in thread.
        # NOTE(review): the coroutine below runs on a fresh event loop in a helper
        # thread, so anything it awaits (cluster initialization, slot lookup) is
        # not bound to the caller's loop - confirm this is safe for connections.
        response = None
        error = None

        def runner():
            nonlocal response
            nonlocal error
            try:
                response = asyncio.run(self._execute_command(*args, **kwargs))
            except Exception as e:
                error = e

        thread = threading.Thread(target=runner)
        thread.start()
        thread.join()

        if error:
            raise error

        return response

    async def _execute_command(
        self, *args: Union[KeyT, EncodableT], **kwargs: Any
    ) -> Any:
        if self._pipe.cluster_client._initialize:
            await self._pipe.cluster_client.initialize()

        slot_number: Optional[int] = None
        if args[0] not in self.NO_SLOTS_COMMANDS:
            slot_number = await self._pipe.cluster_client._determine_slot(*args)

        if (
            self._watching or args[0] in self.IMMEDIATE_EXECUTE_COMMANDS
        ) and not self._explicit_transaction:
            if args[0] == "WATCH":
                self._validate_watch()

            if slot_number is not None:
                if self._pipeline_slots and slot_number not in self._pipeline_slots:
                    raise CrossSlotTransactionError(
                        "Cannot watch or send commands on different slots"
                    )

                self._pipeline_slots.add(slot_number)
            elif args[0] not in self.NO_SLOTS_COMMANDS:
                raise RedisClusterException(
                    f"Cannot identify slot number for command: {args[0]}, "
                    "it cannot be triggered in a transaction"
                )

            # Deliberately not awaited: the coroutine is returned to the caller.
            return self._immediate_execute_command(*args, **kwargs)
        else:
            if slot_number is not None:
                self._pipeline_slots.add(slot_number)

            return super().execute_command(*args, **kwargs)

    @staticmethod
    async def _await_if_needed(result):
        """
        Await ``result`` when ``execute_command`` returned a coroutine (immediate
        execution); a queued command's return value (the pipeline) is passed
        through untouched, because awaiting the pipeline would re-initialize it.
        """
        if asyncio.iscoroutine(result):
            return await result
        return result

    def _validate_watch(self):
        if self._explicit_transaction:
            raise RedisError("Cannot issue a WATCH after a MULTI")

        self._watching = True

    async def _immediate_execute_command(self, *args, **options):
        return await self._retry.call_with_retry(
            lambda: self._get_connection_and_send_command(*args, **options),
            self._reinitialize_on_error,
        )

    async def _get_connection_and_send_command(self, *args, **options):
        redis_node, connection = self._get_client_and_connection_for_transaction()
        return await self._send_command_parse_response(
            connection, redis_node, args[0], *args, **options
        )

    async def _send_command_parse_response(
        self,
        connection: Connection,
        redis_node: ClusterNode,
        command_name,
        *args,
        **options,
    ):
        """
        Send a command and parse the response
        """

        await connection.send_command(*args)
        output = await redis_node.parse_response(connection, command_name, **options)

        if command_name in self.UNWATCH_COMMANDS:
            self._watching = False
        return output

    async def _reinitialize_on_error(self, error):
        """
        Retry callback: react to ``error`` before the next attempt.

        A slot redirect while executing a watched transaction becomes a
        :class:`WatchError`. Redirect and connection errors release the pinned
        connection and either refresh the whole topology (every
        ``reinitialize_steps`` errors) or, for redirects, record the new slot
        owner.
        """
        if self._watching:
            if type(error) in self.SLOT_REDIRECT_ERRORS and self._executing:
                raise WatchError("Slot rebalancing occurred while watching keys")

        is_redirect = type(error) in self.SLOT_REDIRECT_ERRORS
        if is_redirect or type(error) in self.CONNECTION_ERRORS:
            # Return the connection to its pool instead of leaking it; it is
            # closed first because its state is unknown after the error.
            await self._release_transaction_connection(disconnect=True)

            client = self._pipe.cluster_client
            client.reinitialize_counter += 1
            if (
                client.reinitialize_steps
                and client.reinitialize_counter % client.reinitialize_steps == 0
            ):
                await client.nodes_manager.initialize()
                client.reinitialize_counter = 0
            elif is_redirect:
                # Only redirect errors carry the host/port that the slot-map
                # update reads; connection errors must not be recorded here.
                client.nodes_manager.update_moved_exception(error)

        self._executing = False

    def _raise_first_error(self, responses, stack):
        """
        Raise the first exception on the stack
        """
        for r, cmd in zip(responses, stack):
            if isinstance(r, Exception):
                self._annotate_exception(r, cmd.position + 1, cmd.args)
                raise r

    def mset_nonatomic(
        self, mapping: Mapping[AnyKeyT, EncodableT]
    ) -> "ClusterPipeline":
        raise NotImplementedError("Method is not supported in transactional context.")

    async def execute(
        self, raise_on_error: bool = True, allow_redirections: bool = True
    ) -> List[Any]:
        stack = self._command_queue
        if not stack and (not self._watching or not self._pipeline_slots):
            return []

        return await self._execute_transaction_with_retries(stack, raise_on_error)

    async def _execute_transaction_with_retries(
        self, stack: List["PipelineCommand"], raise_on_error: bool
    ):
        return await self._retry.call_with_retry(
            lambda: self._execute_transaction(stack, raise_on_error),
            self._reinitialize_on_error,
        )

    async def _execute_transaction(
        self, stack: List["PipelineCommand"], raise_on_error: bool
    ):
        """
        Send ``MULTI``, the queued commands and ``EXEC`` in one batch over the
        pinned connection, then parse every reply.

        :raises CrossSlotTransactionError: if the keys span more than one slot.
        :raises WatchError: if EXEC reports that a watched key changed.
        :raises InvalidPipelineStack: if EXEC returns an unexpected number of replies.
        """
        if len(self._pipeline_slots) > 1:
            raise CrossSlotTransactionError(
                "All keys involved in a cluster transaction must map to the same slot"
            )

        self._executing = True

        redis_node, connection = self._get_client_and_connection_for_transaction()

        wrapped = chain(
            [PipelineCommand(0, "MULTI")],
            stack,
            [PipelineCommand(0, "EXEC")],
        )
        commands = [c.args for c in wrapped if EMPTY_RESPONSE not in c.kwargs]
        packed_commands = connection.pack_commands(commands)
        await connection.send_packed_command(packed_commands)

        # Exceptions raised while reading MULTI / queued-command replies.
        errors = []
        # (index, placeholder) for commands flagged with EMPTY_RESPONSE; these
        # were never sent, so their placeholder is spliced into EXEC's reply.
        empty_responses = []

        # parse off the response for MULTI
        # NOTE: we need to handle ResponseErrors here and continue
        # so that we read all the additional command messages from
        # the socket
        try:
            await redis_node.parse_response(connection, "MULTI")
        except ResponseError as e:
            self._annotate_exception(e, 0, ("MULTI",))
            errors.append(e)
        except self.CONNECTION_ERRORS as cluster_error:
            self._annotate_exception(cluster_error, 0, ("MULTI",))
            raise

        # and all the other commands
        for i, command in enumerate(self._command_queue):
            if EMPTY_RESPONSE in command.kwargs:
                empty_responses.append((i, command.kwargs[EMPTY_RESPONSE]))
            else:
                try:
                    _ = await redis_node.parse_response(connection, "_")
                except self.SLOT_REDIRECT_ERRORS as slot_error:
                    self._annotate_exception(slot_error, i + 1, command.args)
                    errors.append(slot_error)
                except self.CONNECTION_ERRORS as cluster_error:
                    self._annotate_exception(cluster_error, i + 1, command.args)
                    raise
                except ResponseError as e:
                    self._annotate_exception(e, i + 1, command.args)
                    errors.append(e)

        response = None
        # parse the EXEC.
        try:
            response = await redis_node.parse_response(connection, "EXEC")
        except ExecAbortError:
            # Surface the queue-time error that caused the abort, if any.
            if errors:
                raise errors[0]
            raise

        self._executing = False

        # EXEC clears any watched keys
        self._watching = False

        if response is None:
            raise WatchError("Watched variable changed.")

        # put the placeholders of unsent commands back at their positions
        for i, placeholder in empty_responses:
            response.insert(i, placeholder)

        if len(response) != len(self._command_queue):
            raise InvalidPipelineStack(
                "Unexpected response length for cluster pipeline EXEC."
                " Command stack was {} but response had length {}".format(
                    [c.args[0] for c in self._command_queue], len(response)
                )
            )

        # find any errors in the response and raise if necessary
        # NOTE(review): this also raises when raise_on_error is False but some
        # command had an EMPTY_RESPONSE placeholder; kept as-is, confirm intent.
        if raise_on_error or errors or empty_responses:
            self._raise_first_error(
                response,
                self._command_queue,
            )

        # We have to run response callbacks manually
        callbacks = self._pipe.cluster_client.response_callbacks
        data = []
        for r, cmd in zip(response, self._command_queue):
            if not isinstance(r, Exception):
                command_name = cmd.args[0]
                if command_name in callbacks:
                    r = callbacks[command_name](r, **cmd.kwargs)
            # Keep results aligned with the command queue, errors included.
            data.append(r)
        return data

    async def reset(self):
        """
        Clear queued commands, UNWATCH and release the pinned connection, and
        reset the transaction state.
        """
        self._command_queue = []

        # make sure to reset the connection state in the event that we were
        # watching something
        if self._transaction_connection:
            try:
                # call this manually since our unwatch or
                # immediate_execute_command methods can call reset()
                await self._transaction_connection.send_command("UNWATCH")
                await self._transaction_connection.read_response()
            except self.CONNECTION_ERRORS:
                # disconnect will also remove any previous WATCHes; the
                # connection still goes back to its node's pool afterwards.
                await self._release_transaction_connection(disconnect=True)
            else:
                # we can safely return the connection to the pool here since
                # we're sure we're no longer WATCHing anything
                await self._release_transaction_connection()

        # clean up the other instance attributes
        self._transaction_node = None
        self._watching = False
        self._explicit_transaction = False
        self._pipeline_slots = set()
        self._executing = False

    def multi(self):
        if self._explicit_transaction:
            raise RedisError("Cannot issue nested calls to MULTI")
        if self._command_queue:
            raise RedisError(
                "Commands without an initial WATCH have already been issued"
            )
        self._explicit_transaction = True

    async def watch(self, *names):
        if self._explicit_transaction:
            raise RedisError("Cannot issue a WATCH after a MULTI")

        return await self._await_if_needed(self.execute_command("WATCH", *names))

    async def unwatch(self):
        if self._watching:
            return await self._await_if_needed(self.execute_command("UNWATCH"))

        return True

    async def discard(self):
        await self.reset()

    async def unlink(self, *names):
        return await self._await_if_needed(self.execute_command("UNLINK", *names))