web3 7.6.1__py3-none-any.whl → 7.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. ens/async_ens.py +1 -1
  2. ens/ens.py +1 -1
  3. web3/_utils/caching/caching_utils.py +64 -0
  4. web3/_utils/caching/request_caching_validation.py +1 -0
  5. web3/_utils/events.py +1 -1
  6. web3/_utils/http_session_manager.py +32 -3
  7. web3/_utils/module_testing/eth_module.py +5 -18
  8. web3/_utils/module_testing/module_testing_utils.py +1 -43
  9. web3/_utils/module_testing/persistent_connection_provider.py +696 -207
  10. web3/_utils/module_testing/utils.py +99 -33
  11. web3/beacon/api_endpoints.py +10 -0
  12. web3/beacon/async_beacon.py +47 -0
  13. web3/beacon/beacon.py +45 -0
  14. web3/contract/async_contract.py +2 -206
  15. web3/contract/base_contract.py +217 -13
  16. web3/contract/contract.py +2 -205
  17. web3/datastructures.py +15 -16
  18. web3/eth/async_eth.py +23 -5
  19. web3/exceptions.py +7 -0
  20. web3/main.py +24 -3
  21. web3/manager.py +140 -48
  22. web3/method.py +1 -1
  23. web3/middleware/attrdict.py +12 -22
  24. web3/middleware/base.py +14 -6
  25. web3/module.py +17 -21
  26. web3/providers/async_base.py +23 -14
  27. web3/providers/base.py +6 -8
  28. web3/providers/ipc.py +7 -6
  29. web3/providers/legacy_websocket.py +1 -1
  30. web3/providers/persistent/async_ipc.py +5 -3
  31. web3/providers/persistent/persistent.py +121 -17
  32. web3/providers/persistent/persistent_connection.py +11 -4
  33. web3/providers/persistent/request_processor.py +49 -41
  34. web3/providers/persistent/subscription_container.py +56 -0
  35. web3/providers/persistent/subscription_manager.py +298 -0
  36. web3/providers/persistent/websocket.py +4 -4
  37. web3/providers/rpc/async_rpc.py +16 -3
  38. web3/providers/rpc/rpc.py +9 -5
  39. web3/types.py +28 -14
  40. web3/utils/__init__.py +4 -0
  41. web3/utils/subscriptions.py +289 -0
  42. {web3-7.6.1.dist-info → web3-7.8.0.dist-info}/LICENSE +1 -1
  43. {web3-7.6.1.dist-info → web3-7.8.0.dist-info}/METADATA +68 -56
  44. {web3-7.6.1.dist-info → web3-7.8.0.dist-info}/RECORD +46 -43
  45. {web3-7.6.1.dist-info → web3-7.8.0.dist-info}/WHEEL +1 -1
  46. {web3-7.6.1.dist-info → web3-7.8.0.dist-info}/top_level.txt +0 -0
web3/providers/persistent/request_processor.py CHANGED
@@ -1,7 +1,4 @@
1
1
  import asyncio
2
- from copy import (
3
- copy,
4
- )
5
2
  import sys
6
3
  from typing import (
7
4
  TYPE_CHECKING,
@@ -9,6 +6,7 @@ from typing import (
9
6
  Callable,
10
7
  Dict,
11
8
  Generic,
9
+ List,
12
10
  Optional,
13
11
  Tuple,
14
12
  TypeVar,
@@ -27,11 +25,16 @@ from web3._utils.caching import (
27
25
  generate_cache_key,
28
26
  )
29
27
  from web3.exceptions import (
28
+ SubscriptionProcessingFinished,
30
29
  TaskNotRunning,
31
30
  Web3ValueError,
32
31
  )
32
+ from web3.providers.persistent.subscription_manager import (
33
+ SubscriptionContainer,
34
+ )
33
35
  from web3.types import (
34
36
  RPCEndpoint,
37
+ RPCId,
35
38
  RPCResponse,
36
39
  )
37
40
  from web3.utils import (
@@ -75,18 +78,23 @@ class TaskReliantQueue(_TaskReliantQueue[T]):
75
78
  class RequestProcessor:
76
79
  _subscription_queue_synced_with_ws_stream: bool = False
77
80
 
81
+ # set by the subscription manager when it is initialized
82
+ _subscription_container: Optional[SubscriptionContainer] = None
83
+
78
84
  def __init__(
79
85
  self,
80
86
  provider: "PersistentConnectionProvider",
81
87
  subscription_response_queue_size: int = 500,
82
88
  ) -> None:
83
89
  self._provider = provider
84
-
85
90
  self._request_information_cache: SimpleCache = SimpleCache(500)
86
91
  self._request_response_cache: SimpleCache = SimpleCache(500)
87
92
  self._subscription_response_queue: TaskReliantQueue[
88
93
  Union[RPCResponse, TaskNotRunning]
89
94
  ] = TaskReliantQueue(maxsize=subscription_response_queue_size)
95
+ self._handler_subscription_queue: TaskReliantQueue[
96
+ Union[RPCResponse, TaskNotRunning, SubscriptionProcessingFinished]
97
+ ] = TaskReliantQueue(maxsize=subscription_response_queue_size)
90
98
 
91
99
  @property
92
100
  def active_subscriptions(self) -> Dict[str, Any]:
@@ -100,6 +108,7 @@ class RequestProcessor:
100
108
 
101
109
  def cache_request_information(
102
110
  self,
111
+ request_id: Optional[RPCId],
103
112
  method: RPCEndpoint,
104
113
  params: Any,
105
114
  response_formatters: Tuple[
@@ -120,25 +129,23 @@ class RequestProcessor:
120
129
  )
121
130
  return None
122
131
 
123
- if self._provider._is_batching:
132
+ if request_id is None:
133
+ if not self._provider._is_batching:
134
+ raise Web3ValueError(
135
+ "Request id must be provided when not batching requests."
136
+ )
124
137
  # the _batch_request_counter is set when entering the context manager
125
- current_request_id = self._provider._batch_request_counter
138
+ request_id = self._provider._batch_request_counter
126
139
  self._provider._batch_request_counter += 1
127
- else:
128
- # copy the request counter and find the next request id without incrementing
129
- # since this is done when / if the request is successfully sent
130
- current_request_id = next(copy(self._provider.request_counter))
131
- cache_key = generate_cache_key(current_request_id)
132
-
133
- self._bump_cache_if_key_present(cache_key, current_request_id)
134
140
 
141
+ cache_key = generate_cache_key(request_id)
135
142
  request_info = RequestInformation(
136
143
  method,
137
144
  params,
138
145
  response_formatters,
139
146
  )
140
147
  self._provider.logger.debug(
141
- f"Caching request info:\n request_id={current_request_id},\n"
148
+ f"Caching request info:\n request_id={request_id},\n"
142
149
  f" cache_key={cache_key},\n request_info={request_info.__dict__}"
143
150
  )
144
151
  self._request_information_cache.cache(
@@ -147,30 +154,6 @@ class RequestProcessor:
147
154
  )
148
155
  return cache_key
149
156
 
150
- def _bump_cache_if_key_present(self, cache_key: str, request_id: int) -> None:
151
- """
152
- If the cache key is present in the cache, bump the cache key and request id
153
- by one to make room for the new request. This behavior is necessary when a
154
- request is made but inner requests, say to `eth_estimateGas` if the `gas` is
155
- missing, are made before the original request is sent.
156
- """
157
- if cache_key in self._request_information_cache:
158
- original_request_info = self._request_information_cache.get_cache_entry(
159
- cache_key
160
- )
161
- bump = generate_cache_key(request_id + 1)
162
-
163
- # recursively bump the cache if the new key is also present
164
- self._bump_cache_if_key_present(bump, request_id + 1)
165
-
166
- self._provider.logger.debug(
167
- "Caching internal request. Bumping original request in cache:\n"
168
- f" request_id=[{request_id}] -> [{request_id + 1}],\n"
169
- f" cache_key=[{cache_key}] -> [{bump}],\n"
170
- f" request_info={original_request_info.__dict__}"
171
- )
172
- self._request_information_cache.cache(bump, original_request_info)
173
-
174
157
  def pop_cached_request_information(
175
158
  self, cache_key: str
176
159
  ) -> Optional[RequestInformation]:
@@ -288,6 +271,15 @@ class RequestProcessor:
288
271
 
289
272
  # raw response cache
290
273
 
274
+ def _is_batch_response(
275
+ self, raw_response: Union[List[RPCResponse], RPCResponse]
276
+ ) -> bool:
277
+ return isinstance(raw_response, list) or (
278
+ isinstance(raw_response, dict)
279
+ and raw_response.get("id") is None
280
+ and self._provider._is_batching
281
+ )
282
+
291
283
  async def cache_raw_response(
292
284
  self, raw_response: Any, subscription: bool = False
293
285
  ) -> None:
@@ -303,8 +295,18 @@ class RequestProcessor:
303
295
  self._provider.logger.debug(
304
296
  f"Caching subscription response:\n response={raw_response}"
305
297
  )
306
- await self._subscription_response_queue.put(raw_response)
307
- elif isinstance(raw_response, list):
298
+ subscription_id = raw_response.get("params", {}).get("subscription")
299
+ sub_container = self._subscription_container
300
+ if sub_container and sub_container.get_handler_subscription_by_id(
301
+ subscription_id
302
+ ):
303
+ # if the subscription has a handler, put it in the handler queue
304
+ await self._handler_subscription_queue.put(raw_response)
305
+ else:
306
+ # otherwise, put it in the subscription response queue so a response
307
+ # can be yielded by the message stream
308
+ await self._subscription_response_queue.put(raw_response)
309
+ elif self._is_batch_response(raw_response):
308
310
  # Since only one batch should be in the cache at all times, we use a
309
311
  # constant cache key for the batch response.
310
312
  cache_key = generate_cache_key(BATCH_REQUEST_ID)
@@ -367,7 +369,12 @@ class RequestProcessor:
367
369
 
368
370
  return raw_response
369
371
 
370
- # request processor class methods
372
+ # cache methods
373
+
374
+ def _reset_handler_subscription_queue(self) -> None:
375
+ self._handler_subscription_queue = TaskReliantQueue(
376
+ maxsize=self._handler_subscription_queue.maxsize
377
+ )
371
378
 
372
379
  def clear_caches(self) -> None:
373
380
  """Clear the request processor caches."""
@@ -376,3 +383,4 @@ class RequestProcessor:
376
383
  self._subscription_response_queue = TaskReliantQueue(
377
384
  maxsize=self._subscription_response_queue.maxsize
378
385
  )
386
+ self._reset_handler_subscription_queue()
web3/providers/persistent/subscription_container.py ADDED
@@ -0,0 +1,56 @@
1
+ from typing import (
2
+ Any,
3
+ Dict,
4
+ Iterator,
5
+ List,
6
+ Optional,
7
+ )
8
+
9
+ from eth_typing import (
10
+ HexStr,
11
+ )
12
+
13
+ from web3.utils import (
14
+ EthSubscription,
15
+ )
16
+
17
+
18
+ class SubscriptionContainer:
19
+ def __init__(self) -> None:
20
+ self.subscriptions: List[EthSubscription[Any]] = []
21
+ self.subscriptions_by_id: Dict[HexStr, EthSubscription[Any]] = {}
22
+ self.subscriptions_by_label: Dict[str, EthSubscription[Any]] = {}
23
+
24
+ def __len__(self) -> int:
25
+ return len(self.subscriptions)
26
+
27
+ def __iter__(self) -> Iterator[EthSubscription[Any]]:
28
+ return iter(self.subscriptions)
29
+
30
+ def add_subscription(self, subscription: EthSubscription[Any]) -> None:
31
+ self.subscriptions.append(subscription)
32
+ self.subscriptions_by_id[subscription.id] = subscription
33
+ self.subscriptions_by_label[subscription.label] = subscription
34
+
35
+ def remove_subscription(self, subscription: EthSubscription[Any]) -> None:
36
+ self.subscriptions.remove(subscription)
37
+ self.subscriptions_by_id.pop(subscription.id)
38
+ self.subscriptions_by_label.pop(subscription.label)
39
+
40
+ def get_by_id(self, sub_id: HexStr) -> EthSubscription[Any]:
41
+ return self.subscriptions_by_id.get(sub_id)
42
+
43
+ def get_by_label(self, label: str) -> EthSubscription[Any]:
44
+ return self.subscriptions_by_label.get(label)
45
+
46
+ @property
47
+ def handler_subscriptions(self) -> List[EthSubscription[Any]]:
48
+ return [sub for sub in self.subscriptions if sub._handler is not None]
49
+
50
+ def get_handler_subscription_by_id(
51
+ self, sub_id: HexStr
52
+ ) -> Optional[EthSubscription[Any]]:
53
+ sub = self.get_by_id(sub_id)
54
+ if sub and sub._handler:
55
+ return sub
56
+ return None
web3/providers/persistent/subscription_manager.py ADDED
@@ -0,0 +1,298 @@
1
+ import asyncio
2
+ import logging
3
+ from typing import (
4
+ TYPE_CHECKING,
5
+ Any,
6
+ List,
7
+ Sequence,
8
+ Union,
9
+ cast,
10
+ overload,
11
+ )
12
+
13
+ from eth_typing import (
14
+ HexStr,
15
+ )
16
+
17
+ from web3.exceptions import (
18
+ SubscriptionProcessingFinished,
19
+ TaskNotRunning,
20
+ Web3TypeError,
21
+ Web3ValueError,
22
+ )
23
+ from web3.providers.persistent.subscription_container import (
24
+ SubscriptionContainer,
25
+ )
26
+ from web3.types import (
27
+ FormattedEthSubscriptionResponse,
28
+ RPCResponse,
29
+ )
30
+ from web3.utils.subscriptions import (
31
+ EthSubscription,
32
+ EthSubscriptionContext,
33
+ )
34
+
35
+ if TYPE_CHECKING:
36
+ from web3 import AsyncWeb3 # noqa: F401
37
+ from web3.providers.persistent import ( # noqa: F401
38
+ PersistentConnectionProvider,
39
+ RequestProcessor,
40
+ )
41
+
42
+
43
+ class SubscriptionManager:
44
+ """
45
+ The ``SubscriptionManager`` is responsible for subscribing, unsubscribing, and
46
+ managing all active subscriptions for an ``AsyncWeb3`` instance. It is also
47
+ used for processing all subscriptions that have handler functions.
48
+ """
49
+
50
+ logger: logging.Logger = logging.getLogger(
51
+ "web3.providers.persistent.subscription_manager"
52
+ )
53
+ total_handler_calls: int = 0
54
+
55
+ def __init__(self, w3: "AsyncWeb3") -> None:
56
+ self._w3 = w3
57
+ self._provider = cast("PersistentConnectionProvider", w3.provider)
58
+ self._subscription_container = SubscriptionContainer()
59
+
60
+ # share the subscription container with the request processor so it can separate
61
+ # subscriptions into different queues based on ``sub._handler`` presence
62
+ self._provider._request_processor._subscription_container = (
63
+ self._subscription_container
64
+ )
65
+
66
+ def _add_subscription(self, subscription: EthSubscription[Any]) -> None:
67
+ self._subscription_container.add_subscription(subscription)
68
+
69
+ def _remove_subscription(self, subscription: EthSubscription[Any]) -> None:
70
+ self._subscription_container.remove_subscription(subscription)
71
+
72
+ def _validate_and_normalize_label(self, subscription: EthSubscription[Any]) -> None:
73
+ if subscription.label == subscription._default_label:
74
+ # if no custom label was provided, generate a unique label
75
+ i = 2
76
+ while self.get_by_label(subscription._label) is not None:
77
+ subscription._label = f"{subscription._default_label}#{i}"
78
+ i += 1
79
+ else:
80
+ if (
81
+ subscription._label
82
+ in self._subscription_container.subscriptions_by_label
83
+ ):
84
+ raise Web3ValueError(
85
+ "Subscription label already exists. Subscriptions must have unique "
86
+ f"labels.\n label: {subscription._label}"
87
+ )
88
+
89
+ @property
90
+ def subscriptions(self) -> List[EthSubscription[Any]]:
91
+ return self._subscription_container.subscriptions
92
+
93
+ def get_by_id(self, sub_id: HexStr) -> EthSubscription[Any]:
94
+ return self._subscription_container.get_by_id(sub_id)
95
+
96
+ def get_by_label(self, label: str) -> EthSubscription[Any]:
97
+ return self._subscription_container.get_by_label(label)
98
+
99
+ @overload
100
+ async def subscribe(self, subscriptions: EthSubscription[Any]) -> HexStr:
101
+ ...
102
+
103
+ @overload
104
+ async def subscribe(
105
+ self, subscriptions: Sequence[EthSubscription[Any]]
106
+ ) -> List[HexStr]:
107
+ ...
108
+
109
+ async def subscribe(
110
+ self, subscriptions: Union[EthSubscription[Any], Sequence[EthSubscription[Any]]]
111
+ ) -> Union[HexStr, List[HexStr]]:
112
+ """
113
+ Used to subscribe to a single or multiple subscriptions.
114
+
115
+ :param subscriptions: A single subscription or a sequence of subscriptions.
116
+ :type subscriptions: Union[EthSubscription, Sequence[EthSubscription]]
117
+ :return:
118
+ """
119
+ if isinstance(subscriptions, EthSubscription):
120
+ subscriptions.manager = self
121
+ self._validate_and_normalize_label(subscriptions)
122
+ sub_id = await self._w3.eth._subscribe(*subscriptions.subscription_params)
123
+ subscriptions._id = sub_id
124
+ self._add_subscription(subscriptions)
125
+ self.logger.info(
126
+ "Successfully subscribed to subscription:\n "
127
+ f"label: {subscriptions.label}\n id: {sub_id}"
128
+ )
129
+ return sub_id
130
+ elif isinstance(subscriptions, Sequence):
131
+ if len(subscriptions) == 0:
132
+ raise Web3ValueError("No subscriptions provided.")
133
+
134
+ sub_ids: List[HexStr] = []
135
+ for sub in subscriptions:
136
+ sub_ids.append(await self.subscribe(sub))
137
+ return sub_ids
138
+ raise Web3TypeError("Expected a Subscription or a sequence of Subscriptions.")
139
+
140
+ @overload
141
+ async def unsubscribe(self, subscriptions: EthSubscription[Any]) -> bool:
142
+ ...
143
+
144
+ @overload
145
+ async def unsubscribe(self, subscriptions: HexStr) -> bool:
146
+ ...
147
+
148
+ @overload
149
+ async def unsubscribe(
150
+ self,
151
+ subscriptions: Sequence[Union[EthSubscription[Any], HexStr]],
152
+ ) -> bool:
153
+ ...
154
+
155
+ async def unsubscribe(
156
+ self,
157
+ subscriptions: Union[
158
+ EthSubscription[Any],
159
+ HexStr,
160
+ Sequence[Union[EthSubscription[Any], HexStr]],
161
+ ],
162
+ ) -> bool:
163
+ """
164
+ Used to unsubscribe from one or multiple subscriptions.
165
+
166
+ :param subscriptions: The subscription(s) to unsubscribe from.
167
+ :type subscriptions: Union[EthSubscription, Sequence[EthSubscription], HexStr,
168
+ Sequence[HexStr]]
169
+ :return: ``True`` if unsubscribing to all was successful, ``False`` otherwise
170
+ with a warning.
171
+ :rtype: bool
172
+ """
173
+ if isinstance(subscriptions, EthSubscription) or isinstance(subscriptions, str):
174
+ if isinstance(subscriptions, str):
175
+ subscription_id = subscriptions
176
+ subscriptions = self.get_by_id(subscription_id)
177
+ if subscriptions is None:
178
+ raise Web3ValueError(
179
+ "Subscription not found or is not being managed by the "
180
+ f"subscription manager.\n id: {subscription_id}"
181
+ )
182
+
183
+ if subscriptions not in self.subscriptions:
184
+ raise Web3ValueError(
185
+ "Subscription not found or is not being managed by the "
186
+ "subscription manager.\n "
187
+ f"label: {subscriptions.label}\n id: {subscriptions._id}"
188
+ )
189
+
190
+ if await self._w3.eth._unsubscribe(subscriptions.id):
191
+ self._remove_subscription(subscriptions)
192
+ self.logger.info(
193
+ "Successfully unsubscribed from subscription:\n "
194
+ f"label: {subscriptions.label}\n id: {subscriptions.id}"
195
+ )
196
+
197
+ if len(self._subscription_container.handler_subscriptions) == 0:
198
+ queue = (
199
+ self._provider._request_processor._handler_subscription_queue
200
+ )
201
+ await queue.put(SubscriptionProcessingFinished())
202
+ return True
203
+
204
+ elif isinstance(subscriptions, Sequence):
205
+ if len(subscriptions) == 0:
206
+ raise Web3ValueError("No subscriptions provided.")
207
+
208
+ unsubscribed: List[bool] = []
209
+ for sub in subscriptions:
210
+ if isinstance(sub, str):
211
+ sub = HexStr(sub)
212
+ unsubscribed.append(await self.unsubscribe(sub))
213
+ return all(unsubscribed)
214
+
215
+ self.logger.warning(
216
+ f"Failed to unsubscribe from subscription\n subscription={subscriptions}"
217
+ )
218
+ return False
219
+
220
+ async def unsubscribe_all(self) -> bool:
221
+ """
222
+ Used to unsubscribe from all subscriptions that are being managed by the
223
+ subscription manager.
224
+
225
+ :return: ``True`` if unsubscribing was successful, ``False`` otherwise.
226
+ :rtype: bool
227
+ """
228
+ unsubscribed = [
229
+ await self.unsubscribe(sub) for sub in self.subscriptions.copy()
230
+ ]
231
+ if all(unsubscribed):
232
+ self.logger.info("Successfully unsubscribed from all subscriptions.")
233
+ return True
234
+ else:
235
+ if len(self.subscriptions) > 0:
236
+ self.logger.warning(
237
+ "Failed to unsubscribe from all subscriptions. Some subscriptions "
238
+ f"are still active.\n subscriptions={self.subscriptions}"
239
+ )
240
+ return False
241
+
242
+ async def handle_subscriptions(self, run_forever: bool = False) -> None:
243
+ """
244
+ Used to handle all subscriptions that have handlers. The method will run until
245
+ all subscriptions that have handler functions are unsubscribed from or, if
246
+ ``run_forever`` is set to ``True``, it will run indefinitely.
247
+
248
+ :param run_forever: If ``True``, the method will run indefinitely.
249
+ :type run_forever: bool
250
+ :return: None
251
+ """
252
+ if not self._subscription_container.handler_subscriptions and not run_forever:
253
+ self.logger.warning(
254
+ "No handler subscriptions found. Subscription handler did not run."
255
+ )
256
+ return
257
+
258
+ queue = self._provider._request_processor._handler_subscription_queue
259
+ while run_forever or self._subscription_container.handler_subscriptions:
260
+ try:
261
+ response = cast(RPCResponse, await queue.get())
262
+ formatted_sub_response = cast(
263
+ FormattedEthSubscriptionResponse,
264
+ await self._w3.manager._process_response(response),
265
+ )
266
+
267
+ # if the subscription was unsubscribed from, the response won't be
268
+ # formatted because we lost the request information
269
+ sub_id = formatted_sub_response.get("subscription")
270
+ sub = self._subscription_container.get_handler_subscription_by_id(
271
+ sub_id
272
+ )
273
+ if sub:
274
+ await sub._handler(
275
+ EthSubscriptionContext(
276
+ self._w3,
277
+ sub,
278
+ formatted_sub_response["result"],
279
+ **sub._handler_context,
280
+ )
281
+ )
282
+ except SubscriptionProcessingFinished:
283
+ if not run_forever:
284
+ self.logger.info(
285
+ "All handler subscriptions have been unsubscribed from. "
286
+ "Stopping subscription handling."
287
+ )
288
+ break
289
+ except TaskNotRunning:
290
+ await asyncio.sleep(0)
291
+ self._provider._handle_listener_task_exceptions()
292
+ self.logger.error(
293
+ "Message listener background task for the provider has stopped "
294
+ "unexpectedly. Stopping subscription handling."
295
+ )
296
+
297
+ # no active handler subscriptions, clear the handler subscription queue
298
+ self._provider._request_processor._reset_handler_subscription_queue()
web3/providers/persistent/websocket.py CHANGED
@@ -59,8 +59,6 @@ class WebSocketProvider(PersistentConnectionProvider):
59
59
  logger = logging.getLogger("web3.providers.WebSocketProvider")
60
60
  is_async: bool = True
61
61
 
62
- _ws: Optional[WebSocketClientProtocol] = None
63
-
64
62
  def __init__(
65
63
  self,
66
64
  endpoint_uri: Optional[Union[URI, str]] = None,
@@ -68,9 +66,12 @@ class WebSocketProvider(PersistentConnectionProvider):
68
66
  # `PersistentConnectionProvider` kwargs can be passed through
69
67
  **kwargs: Any,
70
68
  ) -> None:
69
+ # initialize the endpoint_uri before calling the super constructor
71
70
  self.endpoint_uri = (
72
71
  URI(endpoint_uri) if endpoint_uri is not None else get_default_endpoint()
73
72
  )
73
+ super().__init__(**kwargs)
74
+ self._ws: Optional[WebSocketClientProtocol] = None
74
75
 
75
76
  if not any(
76
77
  self.endpoint_uri.startswith(prefix)
@@ -93,8 +94,6 @@ class WebSocketProvider(PersistentConnectionProvider):
93
94
 
94
95
  self.websocket_kwargs = merge(DEFAULT_WEBSOCKET_KWARGS, websocket_kwargs or {})
95
96
 
96
- super().__init__(**kwargs)
97
-
98
97
  def __str__(self) -> str:
99
98
  return f"WebSocket connection: {self.endpoint_uri}"
100
99
 
@@ -133,6 +132,7 @@ class WebSocketProvider(PersistentConnectionProvider):
133
132
  self._ws = await connect(self.endpoint_uri, **self.websocket_kwargs)
134
133
 
135
134
  async def _provider_specific_disconnect(self) -> None:
135
+ # this should remain idempotent
136
136
  if self._ws is not None and not self._ws.closed:
137
137
  await self._ws.close()
138
138
  self._ws = None
web3/providers/rpc/async_rpc.py CHANGED
@@ -168,12 +168,25 @@ class AsyncHTTPProvider(AsyncJSONBaseProvider):
168
168
 
169
169
  async def make_batch_request(
170
170
  self, batch_requests: List[Tuple[RPCEndpoint, Any]]
171
- ) -> List[RPCResponse]:
171
+ ) -> Union[List[RPCResponse], RPCResponse]:
172
172
  self.logger.debug(f"Making batch request HTTP - uri: `{self.endpoint_uri}`")
173
173
  request_data = self.encode_batch_rpc_request(batch_requests)
174
174
  raw_response = await self._request_session_manager.async_make_post_request(
175
175
  self.endpoint_uri, request_data, **self.get_request_kwargs()
176
176
  )
177
177
  self.logger.debug("Received batch response HTTP.")
178
- responses_list = cast(List[RPCResponse], self.decode_rpc_response(raw_response))
179
- return sort_batch_response_by_response_ids(responses_list)
178
+ response = self.decode_rpc_response(raw_response)
179
+ if not isinstance(response, list):
180
+ # RPC errors return only one response with the error object
181
+ return response
182
+ return sort_batch_response_by_response_ids(
183
+ cast(List[RPCResponse], sort_batch_response_by_response_ids(response))
184
+ )
185
+
186
+ async def disconnect(self) -> None:
187
+ cache = self._request_session_manager.session_cache
188
+ for _, session in cache.items():
189
+ await session.close()
190
+ cache.clear()
191
+
192
+ self.logger.info(f"Successfully disconnected from: {self.endpoint_uri}")
web3/providers/rpc/rpc.py CHANGED
@@ -71,6 +71,7 @@ class HTTPProvider(JSONBaseProvider):
71
71
  ] = empty,
72
72
  **kwargs: Any,
73
73
  ) -> None:
74
+ super().__init__(**kwargs)
74
75
  self._request_session_manager = HTTPSessionManager()
75
76
 
76
77
  if endpoint_uri is None:
@@ -88,8 +89,6 @@ class HTTPProvider(JSONBaseProvider):
88
89
  self.endpoint_uri, session
89
90
  )
90
91
 
91
- super().__init__(**kwargs)
92
-
93
92
  def __str__(self) -> str:
94
93
  return f"RPC connection {self.endpoint_uri}"
95
94
 
@@ -177,12 +176,17 @@ class HTTPProvider(JSONBaseProvider):
177
176
 
178
177
  def make_batch_request(
179
178
  self, batch_requests: List[Tuple[RPCEndpoint, Any]]
180
- ) -> List[RPCResponse]:
179
+ ) -> Union[List[RPCResponse], RPCResponse]:
181
180
  self.logger.debug(f"Making batch request HTTP, uri: `{self.endpoint_uri}`")
182
181
  request_data = self.encode_batch_rpc_request(batch_requests)
183
182
  raw_response = self._request_session_manager.make_post_request(
184
183
  self.endpoint_uri, request_data, **self.get_request_kwargs()
185
184
  )
186
185
  self.logger.debug("Received batch response HTTP.")
187
- responses_list = cast(List[RPCResponse], self.decode_rpc_response(raw_response))
188
- return sort_batch_response_by_response_ids(responses_list)
186
+ response = self.decode_rpc_response(raw_response)
187
+ if not isinstance(response, list):
188
+ # RPC errors return only one response with the error object
189
+ return response
190
+ return sort_batch_response_by_response_ids(
191
+ cast(List[RPCResponse], sort_batch_response_by_response_ids(response))
192
+ )