web3 7.0.0b5__py3-none-any.whl → 7.0.0b7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (82)
  1. ens/__init__.py +13 -2
  2. web3/__init__.py +21 -5
  3. web3/_utils/batching.py +217 -0
  4. web3/_utils/caching.py +26 -2
  5. web3/_utils/compat/__init__.py +1 -0
  6. web3/_utils/contract_sources/contract_data/arrays_contract.py +3 -3
  7. web3/_utils/contract_sources/contract_data/bytes_contracts.py +5 -5
  8. web3/_utils/contract_sources/contract_data/constructor_contracts.py +7 -7
  9. web3/_utils/contract_sources/contract_data/contract_caller_tester.py +3 -3
  10. web3/_utils/contract_sources/contract_data/emitter_contract.py +3 -3
  11. web3/_utils/contract_sources/contract_data/event_contracts.py +5 -5
  12. web3/_utils/contract_sources/contract_data/extended_resolver.py +3 -3
  13. web3/_utils/contract_sources/contract_data/fallback_function_contract.py +3 -3
  14. web3/_utils/contract_sources/contract_data/function_name_tester_contract.py +3 -3
  15. web3/_utils/contract_sources/contract_data/math_contract.py +3 -3
  16. web3/_utils/contract_sources/contract_data/offchain_lookup.py +3 -3
  17. web3/_utils/contract_sources/contract_data/offchain_resolver.py +3 -3
  18. web3/_utils/contract_sources/contract_data/panic_errors_contract.py +3 -3
  19. web3/_utils/contract_sources/contract_data/payable_tester.py +3 -3
  20. web3/_utils/contract_sources/contract_data/receive_function_contracts.py +5 -5
  21. web3/_utils/contract_sources/contract_data/reflector_contracts.py +3 -3
  22. web3/_utils/contract_sources/contract_data/revert_contract.py +3 -3
  23. web3/_utils/contract_sources/contract_data/simple_resolver.py +3 -3
  24. web3/_utils/contract_sources/contract_data/storage_contract.py +3 -3
  25. web3/_utils/contract_sources/contract_data/string_contract.py +3 -3
  26. web3/_utils/contract_sources/contract_data/tuple_contracts.py +5 -5
  27. web3/_utils/events.py +2 -2
  28. web3/_utils/http.py +3 -0
  29. web3/_utils/http_session_manager.py +280 -0
  30. web3/_utils/method_formatters.py +0 -2
  31. web3/_utils/module_testing/eth_module.py +92 -119
  32. web3/_utils/module_testing/module_testing_utils.py +27 -9
  33. web3/_utils/module_testing/persistent_connection_provider.py +1 -0
  34. web3/_utils/module_testing/web3_module.py +438 -17
  35. web3/_utils/rpc_abi.py +0 -3
  36. web3/beacon/__init__.py +5 -0
  37. web3/beacon/async_beacon.py +9 -5
  38. web3/beacon/beacon.py +7 -5
  39. web3/contract/__init__.py +7 -0
  40. web3/contract/base_contract.py +10 -1
  41. web3/contract/utils.py +112 -4
  42. web3/eth/__init__.py +7 -0
  43. web3/eth/async_eth.py +5 -37
  44. web3/eth/eth.py +7 -57
  45. web3/exceptions.py +20 -0
  46. web3/gas_strategies/time_based.py +2 -2
  47. web3/main.py +21 -9
  48. web3/manager.py +113 -8
  49. web3/method.py +29 -9
  50. web3/middleware/__init__.py +17 -0
  51. web3/middleware/base.py +43 -0
  52. web3/module.py +47 -7
  53. web3/providers/__init__.py +21 -0
  54. web3/providers/async_base.py +55 -23
  55. web3/providers/base.py +59 -26
  56. web3/providers/eth_tester/__init__.py +5 -0
  57. web3/providers/eth_tester/defaults.py +0 -6
  58. web3/providers/eth_tester/middleware.py +3 -8
  59. web3/providers/ipc.py +23 -8
  60. web3/providers/legacy_websocket.py +26 -1
  61. web3/providers/persistent/__init__.py +7 -0
  62. web3/providers/persistent/async_ipc.py +60 -76
  63. web3/providers/persistent/persistent.py +134 -10
  64. web3/providers/persistent/request_processor.py +98 -14
  65. web3/providers/persistent/websocket.py +43 -66
  66. web3/providers/rpc/__init__.py +5 -0
  67. web3/providers/rpc/async_rpc.py +34 -12
  68. web3/providers/rpc/rpc.py +34 -12
  69. web3/providers/rpc/utils.py +0 -3
  70. web3/tools/benchmark/main.py +7 -6
  71. web3/tools/benchmark/node.py +1 -1
  72. web3/types.py +7 -1
  73. web3/utils/__init__.py +14 -5
  74. web3/utils/async_exception_handling.py +19 -7
  75. web3/utils/exception_handling.py +7 -5
  76. {web3-7.0.0b5.dist-info → web3-7.0.0b7.dist-info}/LICENSE +1 -1
  77. {web3-7.0.0b5.dist-info → web3-7.0.0b7.dist-info}/METADATA +33 -20
  78. {web3-7.0.0b5.dist-info → web3-7.0.0b7.dist-info}/RECORD +80 -80
  79. {web3-7.0.0b5.dist-info → web3-7.0.0b7.dist-info}/WHEEL +1 -1
  80. web3/_utils/contract_sources/contract_data/address_reflector.py +0 -29
  81. web3/_utils/request.py +0 -265
  82. {web3-7.0.0b5.dist-info → web3-7.0.0b7.dist-info}/top_level.txt +0 -0
web3/contract/utils.py CHANGED
@@ -9,6 +9,7 @@ from typing import (
9
9
  Tuple,
10
10
  Type,
11
11
  Union,
12
+ cast,
12
13
  )
13
14
 
14
15
  from eth_abi.exceptions import (
@@ -16,6 +17,11 @@ from eth_abi.exceptions import (
16
17
  )
17
18
  from eth_typing import (
18
19
  ChecksumAddress,
20
+ TypeStr,
21
+ )
22
+ from eth_utils.toolz import (
23
+ compose,
24
+ curry,
19
25
  )
20
26
  from hexbytes import (
21
27
  HexBytes,
@@ -60,10 +66,50 @@ if TYPE_CHECKING:
60
66
  AsyncWeb3,
61
67
  Web3,
62
68
  )
69
+ from web3.providers.persistent import ( # noqa: F401
70
+ PersistentConnectionProvider,
71
+ )
63
72
 
64
73
  ACCEPTABLE_EMPTY_STRINGS = ["0x", b"0x", "", b""]
65
74
 
66
75
 
76
+ @curry
77
+ def format_contract_call_return_data_curried(
78
+ async_w3: Union["AsyncWeb3", "Web3"],
79
+ decode_tuples: bool,
80
+ fn_abi: ABIFunction,
81
+ function_identifier: FunctionIdentifier,
82
+ normalizers: Tuple[Callable[..., Any], ...],
83
+ output_types: Sequence[TypeStr],
84
+ return_data: Any,
85
+ ) -> Any:
86
+ """
87
+ Helper function for formatting contract call return data for batch requests. Curry
88
+ with all arguments except `return_data` and process `return_data` once it is
89
+ available.
90
+ """
91
+ try:
92
+ output_data = async_w3.codec.decode(output_types, return_data)
93
+ except DecodingError as e:
94
+ msg = (
95
+ f"Could not decode contract function call to {function_identifier} "
96
+ f"with return data: {str(return_data)}, output_types: {output_types}"
97
+ )
98
+ raise BadFunctionCallOutput(msg) from e
99
+
100
+ _normalizers = itertools.chain(
101
+ BASE_RETURN_NORMALIZERS,
102
+ normalizers,
103
+ )
104
+ normalized_data = map_abi_data(_normalizers, output_types, output_data)
105
+
106
+ if decode_tuples:
107
+ decoded = named_tree(fn_abi["outputs"], normalized_data)
108
+ normalized_data = recursive_dict_to_namedtuple(decoded)
109
+
110
+ return normalized_data[0] if len(normalized_data) == 1 else normalized_data
111
+
112
+
67
113
  def call_contract_function(
68
114
  w3: "Web3",
69
115
  address: ChecksumAddress,
@@ -108,6 +154,34 @@ def call_contract_function(
108
154
 
109
155
  output_types = get_abi_output_types(fn_abi)
110
156
 
157
+ provider = w3.provider
158
+ if hasattr(provider, "_is_batching") and provider._is_batching:
159
+ # request_information == ((method, params), response_formatters)
160
+ request_information = tuple(return_data)
161
+ method_and_params = request_information[0]
162
+
163
+ # append return data formatting to result formatters
164
+ current_response_formatters = request_information[1]
165
+ current_result_formatters = current_response_formatters[0]
166
+ updated_result_formatters = compose(
167
+ # contract call return data formatter
168
+ format_contract_call_return_data_curried(
169
+ w3,
170
+ decode_tuples,
171
+ fn_abi,
172
+ function_identifier,
173
+ normalizers,
174
+ output_types,
175
+ ),
176
+ current_result_formatters,
177
+ )
178
+ response_formatters = (
179
+ updated_result_formatters, # result formatters
180
+ current_response_formatters[1], # error formatters
181
+ current_response_formatters[2], # null result formatters
182
+ )
183
+ return (method_and_params, response_formatters)
184
+
111
185
  try:
112
186
  output_data = w3.codec.decode(output_types, return_data)
113
187
  except DecodingError as e:
@@ -319,6 +393,43 @@ async def async_call_contract_function(
319
393
 
320
394
  output_types = get_abi_output_types(fn_abi)
321
395
 
396
+ if async_w3.provider._is_batching:
397
+ contract_call_return_data_formatter = format_contract_call_return_data_curried(
398
+ async_w3,
399
+ decode_tuples,
400
+ fn_abi,
401
+ function_identifier,
402
+ normalizers,
403
+ output_types,
404
+ )
405
+ if async_w3.provider.has_persistent_connection:
406
+ # get the current request id
407
+ provider = cast("PersistentConnectionProvider", async_w3.provider)
408
+ current_request_id = provider._batch_request_counter - 1
409
+ provider._request_processor.append_result_formatter_for_request(
410
+ current_request_id, contract_call_return_data_formatter
411
+ )
412
+ else:
413
+ # request_information == ((method, params), response_formatters)
414
+ request_information = tuple(return_data)
415
+ method_and_params = request_information[0]
416
+
417
+ # append return data formatter to result formatters
418
+ current_response_formatters = request_information[1]
419
+ current_result_formatters = current_response_formatters[0]
420
+ updated_result_formatters = compose(
421
+ contract_call_return_data_formatter,
422
+ current_result_formatters,
423
+ )
424
+ response_formatters = (
425
+ updated_result_formatters, # result formatters
426
+ current_response_formatters[1], # error formatters
427
+ current_response_formatters[2], # null result formatters
428
+ )
429
+ return (method_and_params, response_formatters)
430
+
431
+ return return_data
432
+
322
433
  try:
323
434
  output_data = async_w3.codec.decode(output_types, return_data)
324
435
  except DecodingError as e:
@@ -350,10 +461,7 @@ async def async_call_contract_function(
350
461
  decoded = named_tree(fn_abi["outputs"], normalized_data)
351
462
  normalized_data = recursive_dict_to_namedtuple(decoded)
352
463
 
353
- if len(normalized_data) == 1:
354
- return normalized_data[0]
355
- else:
356
- return normalized_data
464
+ return normalized_data[0] if len(normalized_data) == 1 else normalized_data
357
465
 
358
466
 
359
467
  async def async_transact_with_contract_function(
web3/eth/__init__.py CHANGED
@@ -8,3 +8,10 @@ from .eth import (
8
8
  Contract,
9
9
  Eth,
10
10
  )
11
+
12
+ __all__ = [
13
+ "AsyncEth",
14
+ "BaseEth",
15
+ "Contract",
16
+ "Eth",
17
+ ]
web3/eth/async_eth.py CHANGED
@@ -35,6 +35,9 @@ from web3._utils.async_transactions import (
35
35
  from web3._utils.blocks import (
36
36
  select_method_for_block_identifier,
37
37
  )
38
+ from web3._utils.compat import (
39
+ Unpack,
40
+ )
38
41
  from web3._utils.fee_utils import (
39
42
  async_fee_history_priority_fee,
40
43
  )
@@ -124,17 +127,6 @@ class AsyncEth(BaseEth):
124
127
  async def accounts(self) -> Tuple[ChecksumAddress]:
125
128
  return await self._accounts()
126
129
 
127
- # eth_hashrate
128
-
129
- _hashrate: Method[Callable[[], Awaitable[int]]] = Method(
130
- RPC.eth_hashrate,
131
- is_property=True,
132
- )
133
-
134
- @property
135
- async def hashrate(self) -> int:
136
- return await self._hashrate()
137
-
138
130
  # eth_blockNumber
139
131
 
140
132
  get_block_number: Method[Callable[[], Awaitable[BlockNumber]]] = Method(
@@ -157,17 +149,6 @@ class AsyncEth(BaseEth):
157
149
  async def chain_id(self) -> int:
158
150
  return await self._chain_id()
159
151
 
160
- # eth_coinbase
161
-
162
- _coinbase: Method[Callable[[], Awaitable[ChecksumAddress]]] = Method(
163
- RPC.eth_coinbase,
164
- is_property=True,
165
- )
166
-
167
- @property
168
- async def coinbase(self) -> ChecksumAddress:
169
- return await self._coinbase()
170
-
171
152
  # eth_gasPrice
172
153
 
173
154
  _gas_price: Method[Callable[[], Awaitable[Wei]]] = Method(
@@ -203,17 +184,6 @@ class AsyncEth(BaseEth):
203
184
  )
204
185
  return await async_fee_history_priority_fee(self)
205
186
 
206
- # eth_mining
207
-
208
- _mining: Method[Callable[[], Awaitable[bool]]] = Method(
209
- RPC.eth_mining,
210
- is_property=True,
211
- )
212
-
213
- @property
214
- async def mining(self) -> bool:
215
- return await self._mining()
216
-
217
187
  # eth_syncing
218
188
 
219
189
  _syncing: Method[Callable[[], Awaitable[Union[SyncStatus, bool]]]] = Method(
@@ -594,10 +564,8 @@ class AsyncEth(BaseEth):
594
564
  self.w3, current_transaction, new_transaction
595
565
  )
596
566
 
597
- # todo: Update Any to stricter kwarg checking with TxParams
598
- # https://github.com/python/mypy/issues/4441
599
567
  async def modify_transaction(
600
- self, transaction_hash: _Hash32, **transaction_params: Any
568
+ self, transaction_hash: _Hash32, **transaction_params: Unpack[TxParams]
601
569
  ) -> HexBytes:
602
570
  assert_valid_transaction_params(cast(TxParams, transaction_params))
603
571
 
@@ -765,7 +733,7 @@ class AsyncEth(BaseEth):
765
733
 
766
734
  @overload
767
735
  # mypy error: Overloaded function signatures 1 and 2 overlap with incompatible return types # noqa: E501
768
- def contract(self, address: None = None, **kwargs: Any) -> Type[AsyncContract]: # type: ignore[misc] # noqa: E501
736
+ def contract(self, address: None = None, **kwargs: Any) -> Type[AsyncContract]: # type: ignore[overload-overlap] # noqa: E501
769
737
  ...
770
738
 
771
739
  @overload
web3/eth/eth.py CHANGED
@@ -30,6 +30,9 @@ from hexbytes import (
30
30
  from web3._utils.blocks import (
31
31
  select_method_for_block_identifier,
32
32
  )
33
+ from web3._utils.compat import (
34
+ Unpack,
35
+ )
33
36
  from web3._utils.fee_utils import (
34
37
  fee_history_priority_fee,
35
38
  )
@@ -116,17 +119,6 @@ class Eth(BaseEth):
116
119
  def accounts(self) -> Tuple[ChecksumAddress]:
117
120
  return self._accounts()
118
121
 
119
- # eth_hashrate
120
-
121
- _hashrate: Method[Callable[[], int]] = Method(
122
- RPC.eth_hashrate,
123
- is_property=True,
124
- )
125
-
126
- @property
127
- def hashrate(self) -> int:
128
- return self._hashrate()
129
-
130
122
  # eth_blockNumber
131
123
 
132
124
  get_block_number: Method[Callable[[], BlockNumber]] = Method(
@@ -149,17 +141,6 @@ class Eth(BaseEth):
149
141
  def chain_id(self) -> int:
150
142
  return self._chain_id()
151
143
 
152
- # eth_coinbase
153
-
154
- _coinbase: Method[Callable[[], ChecksumAddress]] = Method(
155
- RPC.eth_coinbase,
156
- is_property=True,
157
- )
158
-
159
- @property
160
- def coinbase(self) -> ChecksumAddress:
161
- return self._coinbase()
162
-
163
144
  # eth_gasPrice
164
145
 
165
146
  _gas_price: Method[Callable[[], Wei]] = Method(
@@ -195,17 +176,6 @@ class Eth(BaseEth):
195
176
  )
196
177
  return fee_history_priority_fee(self)
197
178
 
198
- # eth_mining
199
-
200
- _mining: Method[Callable[[], bool]] = Method(
201
- RPC.eth_mining,
202
- is_property=True,
203
- )
204
-
205
- @property
206
- def mining(self) -> bool:
207
- return self._mining()
208
-
209
179
  # eth_syncing
210
180
 
211
181
  _syncing: Method[Callable[[], Union[SyncStatus, bool]]] = Method(
@@ -282,7 +252,8 @@ class Eth(BaseEth):
282
252
  return self._call(transaction, block_identifier, state_override)
283
253
  except OffchainLookup as offchain_lookup:
284
254
  durin_calldata = handle_offchain_lookup(
285
- offchain_lookup.payload, transaction
255
+ offchain_lookup.payload,
256
+ transaction,
286
257
  )
287
258
  transaction["data"] = durin_calldata
288
259
 
@@ -596,10 +567,8 @@ class Eth(BaseEth):
596
567
  current_transaction = get_required_transaction(self.w3, transaction_hash)
597
568
  return replace_transaction(self.w3, current_transaction, new_transaction)
598
569
 
599
- # todo: Update Any to stricter kwarg checking with TxParams
600
- # https://github.com/python/mypy/issues/4441
601
570
  def modify_transaction(
602
- self, transaction_hash: _Hash32, **transaction_params: Any
571
+ self, transaction_hash: _Hash32, **transaction_params: Unpack[TxParams]
603
572
  ) -> HexBytes:
604
573
  assert_valid_transaction_params(cast(TxParams, transaction_params))
605
574
  current_transaction = get_required_transaction(self.w3, transaction_hash)
@@ -659,28 +628,9 @@ class Eth(BaseEth):
659
628
  mungers=[default_root_munger],
660
629
  )
661
630
 
662
- # eth_submitHashrate
663
-
664
- submit_hashrate: Method[Callable[[int, _Hash32], bool]] = Method(
665
- RPC.eth_submitHashrate,
666
- mungers=[default_root_munger],
667
- )
668
-
669
- # eth_getWork, eth_submitWork
670
-
671
- get_work: Method[Callable[[], List[HexBytes]]] = Method(
672
- RPC.eth_getWork,
673
- is_property=True,
674
- )
675
-
676
- submit_work: Method[Callable[[int, _Hash32, _Hash32], bool]] = Method(
677
- RPC.eth_submitWork,
678
- mungers=[default_root_munger],
679
- )
680
-
681
631
  @overload
682
632
  # type error: Overloaded function signatures 1 and 2 overlap with incompatible return types # noqa: E501
683
- def contract(self, address: None = None, **kwargs: Any) -> Type[Contract]: # type: ignore[misc] # noqa: E501
633
+ def contract(self, address: None = None, **kwargs: Any) -> Type[Contract]: # type: ignore[overload-overlap] # noqa: E501
684
634
  ...
685
635
 
686
636
  @overload
web3/exceptions.py CHANGED
@@ -1,6 +1,7 @@
1
1
  import datetime
2
2
  import time
3
3
  from typing import (
4
+ TYPE_CHECKING,
4
5
  Any,
5
6
  Dict,
6
7
  Optional,
@@ -16,6 +17,9 @@ from web3.types import (
16
17
  RPCResponse,
17
18
  )
18
19
 
20
+ if TYPE_CHECKING:
21
+ import asyncio
22
+
19
23
 
20
24
  class Web3Exception(Exception):
21
25
  """
@@ -307,6 +311,22 @@ class BadResponseFormat(Web3Exception):
307
311
  """
308
312
 
309
313
 
314
+ class TaskNotRunning(Web3Exception):
315
+ """
316
+ Used to signal between asyncio contexts that a task that is being awaited
317
+ is not currently running.
318
+ """
319
+
320
+ def __init__(
321
+ self, task: "asyncio.Task[Any]", message: Optional[str] = None
322
+ ) -> None:
323
+ self.task = task
324
+ if message is None:
325
+ message = f"Task {task} is not running."
326
+ self.message = message
327
+ super().__init__(message)
328
+
329
+
310
330
  class Web3RPCError(Web3Exception):
311
331
  """
312
332
  Raised when a JSON-RPC response contains an error field.
web3/gas_strategies/time_based.py CHANGED
@@ -110,11 +110,11 @@ def _aggregate_miner_data(
110
110
  # types ignored b/c mypy has trouble inferring gas_prices: Sequence[Wei]
111
111
  price_percentile = percentile(gas_prices, percentile=20) # type: ignore
112
112
  except InsufficientData:
113
- price_percentile = min(gas_prices) # type: ignore
113
+ price_percentile = min(gas_prices)
114
114
  yield MinerData(
115
115
  miner,
116
116
  len(set(block_hashes)),
117
- min(gas_prices), # type: ignore
117
+ min(gas_prices),
118
118
  price_percentile,
119
119
  )
120
120
 
web3/main.py CHANGED
@@ -34,6 +34,7 @@ from typing import (
34
34
  TYPE_CHECKING,
35
35
  Any,
36
36
  AsyncIterator,
37
+ Callable,
37
38
  Dict,
38
39
  Generator,
39
40
  List,
@@ -100,6 +101,9 @@ from web3.manager import (
100
101
  RequestManager as DefaultRequestManager,
101
102
  )
102
103
  from web3.middleware.base import MiddlewareOnion
104
+ from web3.method import (
105
+ Method,
106
+ )
103
107
  from web3.module import (
104
108
  Module,
105
109
  )
@@ -143,7 +147,9 @@ from web3.types import (
143
147
  )
144
148
 
145
149
  if TYPE_CHECKING:
150
+ from web3._utils.batching import RequestBatcher # noqa: F401
146
151
  from web3._utils.empty import Empty # noqa: F401
152
+ from web3.providers.persistent import PersistentConnectionProvider # noqa: F401
147
153
 
148
154
 
149
155
  def get_async_default_modules() -> Dict[str, Union[Type[Module], Sequence[Any]]]:
@@ -336,6 +342,13 @@ class BaseWeb3:
336
342
  def is_encodable(self, _type: TypeStr, value: Any) -> bool:
337
343
  return self.codec.is_encodable(_type, value)
338
344
 
345
+ # -- APIs for high-level requests -- #
346
+
347
+ def batch_requests(
348
+ self,
349
+ ) -> "RequestBatcher[Method[Callable[..., Any]]]":
350
+ return self.manager._batch_requests()
351
+
339
352
 
340
353
  class Web3(BaseWeb3):
341
354
  # mypy types
@@ -468,7 +481,7 @@ class AsyncWeb3(BaseWeb3):
468
481
  new_ens.w3 = self # set self object reference for ``AsyncENS.w3``
469
482
  self._ens = new_ens
470
483
 
471
- # -- persistent connection methods -- #
484
+ # -- persistent connection methods -- #
472
485
 
473
486
  @property
474
487
  @persistent_connection_provider_method()
@@ -511,12 +524,11 @@ class AsyncWeb3(BaseWeb3):
511
524
  "when instantiating via ``async for``."
512
525
  )
513
526
  async def __aiter__(self) -> AsyncIterator[Self]:
514
- if not await self.provider.is_connected():
515
- await self.provider.connect()
516
-
527
+ provider = self.provider
517
528
  while True:
518
- try:
519
- yield self
520
- except Exception:
521
- # provider should handle connection / reconnection
522
- continue
529
+ await provider.connect()
530
+ yield self
531
+ cast("PersistentConnectionProvider", provider).logger.error(
532
+ "Connection interrupted, attempting to reconnect..."
533
+ )
534
+ await provider.disconnect()
web3/manager.py CHANGED
@@ -1,9 +1,11 @@
1
+ import asyncio
1
2
  import logging
2
3
  from typing import (
3
4
  TYPE_CHECKING,
4
5
  Any,
5
6
  AsyncGenerator,
6
7
  Callable,
8
+ Coroutine,
7
9
  List,
8
10
  Optional,
9
11
  Sequence,
@@ -22,6 +24,9 @@ from websockets.exceptions import (
22
24
  ConnectionClosedOK,
23
25
  )
24
26
 
27
+ from web3._utils.batching import (
28
+ RequestBatcher,
29
+ )
25
30
  from web3._utils.caching import (
26
31
  generate_cache_key,
27
32
  )
@@ -35,9 +40,13 @@ from web3.exceptions import (
35
40
  BadResponseFormat,
36
41
  MethodUnavailable,
37
42
  ProviderConnectionError,
43
+ TaskNotRunning,
38
44
  Web3RPCError,
39
45
  Web3TypeError,
40
46
  )
47
+ from web3.method import (
48
+ Method,
49
+ )
41
50
  from web3.middleware import (
42
51
  AttributeDictMiddleware,
43
52
  BufferedGasEstimateMiddleware,
@@ -54,8 +63,12 @@ from web3.module import (
54
63
  )
55
64
  from web3.providers import (
56
65
  AutoProvider,
66
+ JSONBaseProvider,
57
67
  PersistentConnectionProvider,
58
68
  )
69
+ from web3.providers.async_base import (
70
+ AsyncJSONBaseProvider,
71
+ )
59
72
  from web3.types import (
60
73
  RPCEndpoint,
61
74
  RPCResponse,
@@ -375,6 +388,88 @@ class RequestManager:
375
388
  response, params, error_formatters, null_result_formatters
376
389
  )
377
390
 
391
+ # -- batch requests management -- #
392
+
393
+ def _batch_requests(self) -> RequestBatcher[Method[Callable[..., Any]]]:
394
+ """
395
+ Context manager for making batch requests
396
+ """
397
+ if not isinstance(self.provider, (AsyncJSONBaseProvider, JSONBaseProvider)):
398
+ raise Web3TypeError("Batch requests are not supported by this provider.")
399
+ return RequestBatcher(self.w3)
400
+
401
+ def _make_batch_request(
402
+ self, requests_info: List[Tuple[Tuple["RPCEndpoint", Any], Sequence[Any]]]
403
+ ) -> List[RPCResponse]:
404
+ """
405
+ Make a batch request using the provider
406
+ """
407
+ provider = cast(JSONBaseProvider, self.provider)
408
+ request_func = provider.batch_request_func(
409
+ cast("Web3", self.w3), cast("MiddlewareOnion", self.middleware_onion)
410
+ )
411
+ responses = request_func(
412
+ [
413
+ (method, params)
414
+ for (method, params), _response_formatters in requests_info
415
+ ]
416
+ )
417
+ formatted_responses = [
418
+ self._format_batched_response(info, resp)
419
+ for info, resp in zip(requests_info, responses)
420
+ ]
421
+ return list(formatted_responses)
422
+
423
+ async def _async_make_batch_request(
424
+ self,
425
+ requests_info: List[
426
+ Coroutine[Any, Any, Tuple[Tuple["RPCEndpoint", Any], Sequence[Any]]]
427
+ ],
428
+ ) -> List[RPCResponse]:
429
+ """
430
+ Make an asynchronous batch request using the provider
431
+ """
432
+ provider = cast(AsyncJSONBaseProvider, self.provider)
433
+ request_func = await provider.batch_request_func(
434
+ cast("AsyncWeb3", self.w3),
435
+ cast("MiddlewareOnion", self.middleware_onion),
436
+ )
437
+ # since we add items to the batch without awaiting, we unpack the coroutines
438
+ # and await them all here
439
+ unpacked_requests_info = await asyncio.gather(*requests_info)
440
+ responses = await request_func(
441
+ [
442
+ (method, params)
443
+ for (method, params), _response_formatters in unpacked_requests_info
444
+ ]
445
+ )
446
+
447
+ if isinstance(self.provider, PersistentConnectionProvider):
448
+ # call _process_response for each response in the batch
449
+ return [await self._process_response(resp) for resp in responses]
450
+
451
+ formatted_responses = [
452
+ self._format_batched_response(info, resp)
453
+ for info, resp in zip(unpacked_requests_info, responses)
454
+ ]
455
+ return list(formatted_responses)
456
+
457
+ def _format_batched_response(
458
+ self,
459
+ requests_info: Tuple[Tuple[RPCEndpoint, Any], Sequence[Any]],
460
+ response: RPCResponse,
461
+ ) -> RPCResponse:
462
+ result_formatters, error_formatters, null_result_formatters = requests_info[1]
463
+ return apply_result_formatters(
464
+ result_formatters,
465
+ self.formatted_response(
466
+ response,
467
+ requests_info[0][1],
468
+ error_formatters,
469
+ null_result_formatters,
470
+ ),
471
+ )
472
+
378
473
  # -- persistent connection -- #
379
474
 
380
475
  async def send(self, method: RPCEndpoint, params: Any) -> RPCResponse:
@@ -408,14 +503,24 @@ class RequestManager:
408
503
  )
409
504
 
410
505
  while True:
411
- response = await self._request_processor.pop_raw_response(subscription=True)
412
- if (
413
- response is not None
414
- and response.get("params", {}).get("subscription")
415
- in self._request_processor.active_subscriptions
416
- ):
417
- # if response is an active subscription response, process it
418
- yield await self._process_response(response)
506
+ try:
507
+ response = await self._request_processor.pop_raw_response(
508
+ subscription=True
509
+ )
510
+ if (
511
+ response is not None
512
+ and response.get("params", {}).get("subscription")
513
+ in self._request_processor.active_subscriptions
514
+ ):
515
+ # if response is an active subscription response, process it
516
+ yield await self._process_response(response)
517
+ except TaskNotRunning:
518
+ self._provider._handle_listener_task_exceptions()
519
+ self.logger.error(
520
+ "Message listener background task has stopped unexpectedly. "
521
+ "Stopping message stream."
522
+ )
523
+ raise StopAsyncIteration
419
524
 
420
525
  async def _process_response(self, response: RPCResponse) -> RPCResponse:
421
526
  provider = cast(PersistentConnectionProvider, self._provider)