web3 7.6.1__py3-none-any.whl → 7.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. ens/async_ens.py +1 -1
  2. ens/ens.py +1 -1
  3. web3/_utils/caching/caching_utils.py +64 -0
  4. web3/_utils/caching/request_caching_validation.py +1 -0
  5. web3/_utils/events.py +1 -1
  6. web3/_utils/http_session_manager.py +32 -3
  7. web3/_utils/module_testing/eth_module.py +5 -18
  8. web3/_utils/module_testing/module_testing_utils.py +1 -43
  9. web3/_utils/module_testing/persistent_connection_provider.py +696 -207
  10. web3/_utils/module_testing/utils.py +99 -33
  11. web3/beacon/api_endpoints.py +10 -0
  12. web3/beacon/async_beacon.py +47 -0
  13. web3/beacon/beacon.py +45 -0
  14. web3/contract/async_contract.py +2 -206
  15. web3/contract/base_contract.py +217 -13
  16. web3/contract/contract.py +2 -205
  17. web3/datastructures.py +15 -16
  18. web3/eth/async_eth.py +23 -5
  19. web3/exceptions.py +7 -0
  20. web3/main.py +24 -3
  21. web3/manager.py +140 -48
  22. web3/method.py +1 -1
  23. web3/middleware/attrdict.py +12 -22
  24. web3/middleware/base.py +14 -6
  25. web3/module.py +17 -21
  26. web3/providers/async_base.py +23 -14
  27. web3/providers/base.py +6 -8
  28. web3/providers/ipc.py +7 -6
  29. web3/providers/legacy_websocket.py +1 -1
  30. web3/providers/persistent/async_ipc.py +5 -3
  31. web3/providers/persistent/persistent.py +121 -17
  32. web3/providers/persistent/persistent_connection.py +11 -4
  33. web3/providers/persistent/request_processor.py +49 -41
  34. web3/providers/persistent/subscription_container.py +56 -0
  35. web3/providers/persistent/subscription_manager.py +298 -0
  36. web3/providers/persistent/websocket.py +4 -4
  37. web3/providers/rpc/async_rpc.py +16 -3
  38. web3/providers/rpc/rpc.py +9 -5
  39. web3/types.py +28 -14
  40. web3/utils/__init__.py +4 -0
  41. web3/utils/subscriptions.py +289 -0
  42. {web3-7.6.1.dist-info → web3-7.8.0.dist-info}/LICENSE +1 -1
  43. {web3-7.6.1.dist-info → web3-7.8.0.dist-info}/METADATA +68 -56
  44. {web3-7.6.1.dist-info → web3-7.8.0.dist-info}/RECORD +46 -43
  45. {web3-7.6.1.dist-info → web3-7.8.0.dist-info}/WHEEL +1 -1
  46. {web3-7.6.1.dist-info → web3-7.8.0.dist-info}/top_level.txt +0 -0
@@ -6,6 +6,7 @@ from typing import (
6
6
  Any,
7
7
  Callable,
8
8
  Coroutine,
9
+ Dict,
9
10
  List,
10
11
  Optional,
11
12
  Set,
@@ -43,6 +44,7 @@ from web3.middleware.base import (
43
44
  )
44
45
  from web3.types import (
45
46
  RPCEndpoint,
47
+ RPCRequest,
46
48
  RPCResponse,
47
49
  )
48
50
  from web3.utils import (
@@ -75,7 +77,8 @@ class AsyncBaseProvider:
75
77
 
76
78
  _is_batching: bool = False
77
79
  _batch_request_func_cache: Tuple[
78
- Tuple[Middleware, ...], Callable[..., Coroutine[Any, Any, List[RPCResponse]]]
80
+ Tuple[Middleware, ...],
81
+ Callable[..., Coroutine[Any, Any, Union[List[RPCResponse], RPCResponse]]],
79
82
  ] = (None, None)
80
83
 
81
84
  is_async = True
@@ -83,10 +86,6 @@ class AsyncBaseProvider:
83
86
  global_ccip_read_enabled: bool = True
84
87
  ccip_read_max_redirects: int = 4
85
88
 
86
- # request caching
87
- _request_cache: SimpleCache
88
- _request_cache_lock: asyncio.Lock = asyncio.Lock()
89
-
90
89
  def __init__(
91
90
  self,
92
91
  cache_allowed_requests: bool = False,
@@ -96,6 +95,8 @@ class AsyncBaseProvider:
96
95
  ] = empty,
97
96
  ) -> None:
98
97
  self._request_cache = SimpleCache(1000)
98
+ self._request_cache_lock: asyncio.Lock = asyncio.Lock()
99
+
99
100
  self.cache_allowed_requests = cache_allowed_requests
100
101
  self.cacheable_requests = cacheable_requests or CACHEABLE_REQUESTS
101
102
  self.request_cache_validation_threshold = request_cache_validation_threshold
@@ -119,7 +120,7 @@ class AsyncBaseProvider:
119
120
 
120
121
  async def batch_request_func(
121
122
  self, async_w3: "AsyncWeb3", middleware_onion: MiddlewareOnion
122
- ) -> Callable[..., Coroutine[Any, Any, List[RPCResponse]]]:
123
+ ) -> Callable[..., Coroutine[Any, Any, Union[List[RPCResponse], RPCResponse]]]:
123
124
  middleware: Tuple[Middleware, ...] = middleware_onion.as_tuple_of_middleware()
124
125
 
125
126
  cache_key = self._batch_request_func_cache[0]
@@ -141,8 +142,8 @@ class AsyncBaseProvider:
141
142
 
142
143
  async def make_batch_request(
143
144
  self, requests: List[Tuple[RPCEndpoint, Any]]
144
- ) -> List[RPCResponse]:
145
- raise NotImplementedError("Only AsyncHTTPProvider supports this method")
145
+ ) -> Union[List[RPCResponse], RPCResponse]:
146
+ raise NotImplementedError("Providers must implement this method")
146
147
 
147
148
  async def is_connected(self, show_traceback: bool = False) -> bool:
148
149
  raise NotImplementedError("Providers must implement this method")
@@ -172,23 +173,31 @@ class AsyncBaseProvider:
172
173
 
173
174
 
174
175
  class AsyncJSONBaseProvider(AsyncBaseProvider):
175
- logger = logging.getLogger("web3.providers.async_base.AsyncJSONBaseProvider")
176
-
177
176
  def __init__(self, **kwargs: Any) -> None:
178
- self.request_counter = itertools.count()
179
177
  super().__init__(**kwargs)
178
+ self.request_counter = itertools.count()
180
179
 
181
- def encode_rpc_request(self, method: RPCEndpoint, params: Any) -> bytes:
180
+ def form_request(self, method: RPCEndpoint, params: Any = None) -> RPCRequest:
182
181
  request_id = next(self.request_counter)
183
182
  rpc_dict = {
183
+ "id": request_id,
184
184
  "jsonrpc": "2.0",
185
185
  "method": method,
186
186
  "params": params or [],
187
- "id": request_id,
188
187
  }
189
- encoded = FriendlyJsonSerde().json_encode(rpc_dict, cls=Web3JsonEncoder)
188
+ return cast(RPCRequest, rpc_dict)
189
+
190
+ @staticmethod
191
+ def encode_rpc_dict(rpc_dict: RPCRequest) -> bytes:
192
+ encoded = FriendlyJsonSerde().json_encode(
193
+ cast(Dict[str, Any], rpc_dict), cls=Web3JsonEncoder
194
+ )
190
195
  return to_bytes(text=encoded)
191
196
 
197
+ def encode_rpc_request(self, method: RPCEndpoint, params: Any) -> bytes:
198
+ rpc_dict = self.form_request(method, params)
199
+ return self.encode_rpc_dict(rpc_dict)
200
+
192
201
  @staticmethod
193
202
  def decode_rpc_response(raw_response: bytes) -> RPCResponse:
194
203
  text_response = str(
web3/providers/base.py CHANGED
@@ -66,10 +66,6 @@ class BaseProvider:
66
66
  global_ccip_read_enabled: bool = True
67
67
  ccip_read_max_redirects: int = 4
68
68
 
69
- # request caching
70
- _request_cache: SimpleCache
71
- _request_cache_lock: threading.Lock = threading.Lock()
72
-
73
69
  def __init__(
74
70
  self,
75
71
  cache_allowed_requests: bool = False,
@@ -79,6 +75,8 @@ class BaseProvider:
79
75
  ] = empty,
80
76
  ) -> None:
81
77
  self._request_cache = SimpleCache(1000)
78
+ self._request_cache_lock: threading.Lock = threading.Lock()
79
+
82
80
  self.cache_allowed_requests = cache_allowed_requests
83
81
  self.cacheable_requests = cacheable_requests or CACHEABLE_REQUESTS
84
82
  self.request_cache_validation_threshold = request_cache_validation_threshold
@@ -120,12 +118,12 @@ class JSONBaseProvider(BaseProvider):
120
118
 
121
119
  _is_batching: bool = False
122
120
  _batch_request_func_cache: Tuple[
123
- Tuple[Middleware, ...], Callable[..., List[RPCResponse]]
121
+ Tuple[Middleware, ...], Callable[..., Union[List[RPCResponse], RPCResponse]]
124
122
  ] = (None, None)
125
123
 
126
124
  def __init__(self, **kwargs: Any) -> None:
127
- self.request_counter = itertools.count()
128
125
  super().__init__(**kwargs)
126
+ self.request_counter = itertools.count()
129
127
 
130
128
  def encode_rpc_request(self, method: RPCEndpoint, params: Any) -> bytes:
131
129
  rpc_dict = {
@@ -170,7 +168,7 @@ class JSONBaseProvider(BaseProvider):
170
168
 
171
169
  def batch_request_func(
172
170
  self, w3: "Web3", middleware_onion: MiddlewareOnion
173
- ) -> Callable[..., List[RPCResponse]]:
171
+ ) -> Callable[..., Union[List[RPCResponse], RPCResponse]]:
174
172
  middleware: Tuple[Middleware, ...] = middleware_onion.as_tuple_of_middleware()
175
173
 
176
174
  cache_key = self._batch_request_func_cache[0]
@@ -201,5 +199,5 @@ class JSONBaseProvider(BaseProvider):
201
199
 
202
200
  def make_batch_request(
203
201
  self, requests: List[Tuple[RPCEndpoint, Any]]
204
- ) -> List[RPCResponse]:
202
+ ) -> Union[List[RPCResponse], RPCResponse]:
205
203
  raise NotImplementedError("Providers must implement this method")
web3/providers/ipc.py CHANGED
@@ -116,14 +116,15 @@ def get_default_ipc_path() -> str:
116
116
 
117
117
 
118
118
  def get_dev_ipc_path() -> str:
119
- if os.environ.get("WEB3_PROVIDER_URI", ""):
120
- return os.environ.get("WEB3_PROVIDER_URI")
119
+ web3_provider_uri = os.environ.get("WEB3_PROVIDER_URI", "")
120
+ if web3_provider_uri and "geth.ipc" in web3_provider_uri:
121
+ return web3_provider_uri
121
122
 
122
- elif sys.platform == "darwin":
123
- tmpdir = os.environ.get("TMPDIR", "")
123
+ elif sys.platform == "darwin" or sys.platform.startswith("linux"):
124
+ tmpdir = os.environ.get("TMPDIR", "/tmp")
124
125
  return os.path.expanduser(os.path.join(tmpdir, "geth.ipc"))
125
126
 
126
- elif sys.platform.startswith("linux") or sys.platform.startswith("freebsd"):
127
+ elif sys.platform.endswith("freebsd"):
127
128
  return os.path.expanduser(os.path.join("/tmp", "geth.ipc"))
128
129
 
129
130
  elif sys.platform == "win32":
@@ -146,6 +147,7 @@ class IPCProvider(JSONBaseProvider):
146
147
  timeout: int = 30,
147
148
  **kwargs: Any,
148
149
  ) -> None:
150
+ super().__init__(**kwargs)
149
151
  if ipc_path is None:
150
152
  self.ipc_path = get_default_ipc_path()
151
153
  elif isinstance(ipc_path, str) or isinstance(ipc_path, Path):
@@ -156,7 +158,6 @@ class IPCProvider(JSONBaseProvider):
156
158
  self.timeout = timeout
157
159
  self._lock = threading.Lock()
158
160
  self._socket = PersistantSocket(self.ipc_path)
159
- super().__init__(**kwargs)
160
161
 
161
162
  def __str__(self) -> str:
162
163
  return f"<{self.__class__.__name__} {self.ipc_path}>"
@@ -102,6 +102,7 @@ class LegacyWebSocketProvider(JSONBaseProvider):
102
102
  websocket_timeout: int = DEFAULT_WEBSOCKET_TIMEOUT,
103
103
  **kwargs: Any,
104
104
  ) -> None:
105
+ super().__init__(**kwargs)
105
106
  self.endpoint_uri = URI(endpoint_uri)
106
107
  self.websocket_timeout = websocket_timeout
107
108
  if self.endpoint_uri is None:
@@ -120,7 +121,6 @@ class LegacyWebSocketProvider(JSONBaseProvider):
120
121
  f"in websocket_kwargs, found: {found_restricted_keys}"
121
122
  )
122
123
  self.conn = PersistentWebSocket(self.endpoint_uri, websocket_kwargs)
123
- super().__init__(**kwargs)
124
124
 
125
125
  def __str__(self) -> str:
126
126
  return f"WS connection {self.endpoint_uri}"
@@ -14,6 +14,7 @@ from typing import (
14
14
  )
15
15
 
16
16
  from web3.types import (
17
+ RPCEndpoint,
17
18
  RPCResponse,
18
19
  )
19
20
 
@@ -59,15 +60,15 @@ class AsyncIPCProvider(PersistentConnectionProvider):
59
60
  # `PersistentConnectionProvider` kwargs can be passed through
60
61
  **kwargs: Any,
61
62
  ) -> None:
63
+ # initialize the ipc_path before calling the super constructor
62
64
  if ipc_path is None:
63
65
  self.ipc_path = get_default_ipc_path()
64
66
  elif isinstance(ipc_path, str) or isinstance(ipc_path, Path):
65
67
  self.ipc_path = str(Path(ipc_path).expanduser().resolve())
66
68
  else:
67
69
  raise Web3TypeError("ipc_path must be of type string or pathlib.Path")
68
-
69
- self.read_buffer_limit = read_buffer_limit
70
70
  super().__init__(**kwargs)
71
+ self.read_buffer_limit = read_buffer_limit
71
72
 
72
73
  def __str__(self) -> str:
73
74
  return f"<{self.__class__.__name__} {self.ipc_path}>"
@@ -77,7 +78,7 @@ class AsyncIPCProvider(PersistentConnectionProvider):
77
78
  return False
78
79
 
79
80
  try:
80
- await self.make_request("web3_clientVersion", [])
81
+ await self.make_request(RPCEndpoint("web3_clientVersion"), [])
81
82
  return True
82
83
  except (OSError, ProviderConnectionError) as e:
83
84
  if show_traceback:
@@ -139,6 +140,7 @@ class AsyncIPCProvider(PersistentConnectionProvider):
139
140
  )
140
141
 
141
142
  async def _provider_specific_disconnect(self) -> None:
143
+ # this should remain idempotent
142
144
  if self._writer and not self._writer.is_closing():
143
145
  self._writer.close()
144
146
  await self._writer.wait_closed()
@@ -3,10 +3,13 @@ from abc import (
3
3
  abstractmethod,
4
4
  )
5
5
  import asyncio
6
- import json
7
6
  import logging
7
+ import signal
8
8
  from typing import (
9
+ TYPE_CHECKING,
9
10
  Any,
11
+ Callable,
12
+ Coroutine,
10
13
  List,
11
14
  Optional,
12
15
  Tuple,
@@ -24,9 +27,12 @@ from web3._utils.batching import (
24
27
  sort_batch_response_by_response_ids,
25
28
  )
26
29
  from web3._utils.caching import (
27
- async_handle_request_caching,
28
30
  generate_cache_key,
29
31
  )
32
+ from web3._utils.caching.caching_utils import (
33
+ async_handle_recv_caching,
34
+ async_handle_send_caching,
35
+ )
30
36
  from web3.exceptions import (
31
37
  PersistentConnectionClosedOK,
32
38
  ProviderConnectionError,
@@ -43,9 +49,15 @@ from web3.providers.persistent.request_processor import (
43
49
  from web3.types import (
44
50
  RPCEndpoint,
45
51
  RPCId,
52
+ RPCRequest,
46
53
  RPCResponse,
47
54
  )
48
55
 
56
+ if TYPE_CHECKING:
57
+ from web3 import AsyncWeb3 # noqa: F401
58
+ from web3.middleware.base import MiddlewareOnion # noqa: F401
59
+
60
+
49
61
  DEFAULT_PERSISTENT_CONNECTION_TIMEOUT = 30.0
50
62
 
51
63
 
@@ -53,11 +65,14 @@ class PersistentConnectionProvider(AsyncJSONBaseProvider, ABC):
53
65
  logger = logging.getLogger("web3.providers.PersistentConnectionProvider")
54
66
  has_persistent_connection = True
55
67
 
56
- _request_processor: RequestProcessor
57
- _message_listener_task: Optional["asyncio.Task[None]"] = None
58
- _listen_event: asyncio.Event = asyncio.Event()
59
-
60
- _batch_request_counter: Optional[int] = None
68
+ _send_func_cache: Tuple[int, Callable[..., Coroutine[Any, Any, RPCRequest]]] = (
69
+ None,
70
+ None,
71
+ )
72
+ _recv_func_cache: Tuple[int, Callable[..., Coroutine[Any, Any, RPCResponse]]] = (
73
+ None,
74
+ None,
75
+ )
61
76
 
62
77
  def __init__(
63
78
  self,
@@ -72,9 +87,63 @@ class PersistentConnectionProvider(AsyncJSONBaseProvider, ABC):
72
87
  self,
73
88
  subscription_response_queue_size=subscription_response_queue_size,
74
89
  )
90
+ self._message_listener_task: Optional["asyncio.Task[None]"] = None
91
+ self._batch_request_counter: Optional[int] = None
92
+ self._listen_event: asyncio.Event = asyncio.Event()
93
+ self._max_connection_retries = max_connection_retries
94
+
75
95
  self.request_timeout = request_timeout
76
96
  self.silence_listener_task_exceptions = silence_listener_task_exceptions
77
- self._max_connection_retries = max_connection_retries
97
+
98
+ async def send_func(
99
+ self, async_w3: "AsyncWeb3", middleware_onion: "MiddlewareOnion"
100
+ ) -> Callable[..., Coroutine[Any, Any, RPCRequest]]:
101
+ """
102
+ Cache the middleware chain for `send`.
103
+ """
104
+ middleware = middleware_onion.as_tuple_of_middleware()
105
+ cache_key = hash(tuple(id(mw) for mw in middleware))
106
+
107
+ if cache_key != self._send_func_cache[0]:
108
+
109
+ async def send_function(method: RPCEndpoint, params: Any) -> RPCRequest:
110
+ for mw in middleware:
111
+ initialized = mw(async_w3)
112
+ method, params = await initialized.async_request_processor(
113
+ method, params
114
+ )
115
+
116
+ return await self.send_request(method, params)
117
+
118
+ self._send_func_cache = (cache_key, send_function)
119
+
120
+ return self._send_func_cache[1]
121
+
122
+ async def recv_func(
123
+ self, async_w3: "AsyncWeb3", middleware_onion: "MiddlewareOnion"
124
+ ) -> Any:
125
+ """
126
+ Cache and compose the middleware stack for `recv`.
127
+ """
128
+ middleware = middleware_onion.as_tuple_of_middleware()
129
+ cache_key = hash(tuple(id(mw) for mw in middleware))
130
+
131
+ if cache_key != self._recv_func_cache[0]:
132
+
133
+ async def recv_function(rpc_request: RPCRequest) -> RPCResponse:
134
+ # first, retrieve the response
135
+ response = await self.recv_for_request(rpc_request)
136
+ method = rpc_request["method"]
137
+ for mw in reversed(middleware):
138
+ initialized = mw(async_w3)
139
+ response = await initialized.async_response_processor(
140
+ method, response
141
+ )
142
+ return response
143
+
144
+ self._recv_func_cache = (cache_key, recv_function)
145
+
146
+ return self._recv_func_cache[1]
78
147
 
79
148
  def get_endpoint_uri_or_ipc_path(self) -> str:
80
149
  if hasattr(self, "endpoint_uri"):
@@ -124,6 +193,7 @@ class PersistentConnectionProvider(AsyncJSONBaseProvider, ABC):
124
193
  _backoff_time *= _backoff_rate_change
125
194
 
126
195
  async def disconnect(self) -> None:
196
+ # this should remain idempotent
127
197
  try:
128
198
  if self._message_listener_task:
129
199
  self._message_listener_task.cancel()
@@ -140,15 +210,23 @@ class PersistentConnectionProvider(AsyncJSONBaseProvider, ABC):
140
210
  f"Successfully disconnected from: {self.get_endpoint_uri_or_ipc_path()}"
141
211
  )
142
212
 
143
- @async_handle_request_caching
144
- async def make_request(self, method: RPCEndpoint, params: Any) -> RPCResponse:
145
- request_data = self.encode_rpc_request(method, params)
146
- await self.socket_send(request_data)
213
+ @async_handle_send_caching
214
+ async def send_request(self, method: RPCEndpoint, params: Any) -> RPCRequest:
215
+ request_dict = self.form_request(method, params)
216
+ await self.socket_send(self.encode_rpc_dict(request_dict))
217
+ return request_dict
147
218
 
148
- current_request_id = json.loads(request_data)["id"]
149
- response = await self._get_response_for_request_id(current_request_id)
219
+ @async_handle_recv_caching
220
+ async def recv_for_request(self, rpc_request: RPCRequest) -> RPCResponse:
221
+ return await self._get_response_for_request_id(rpc_request["id"])
150
222
 
151
- return response
223
+ async def make_request(
224
+ self,
225
+ method: RPCEndpoint,
226
+ params: Any,
227
+ ) -> RPCResponse:
228
+ rpc_request = await self.send_request(method, params)
229
+ return await self.recv_for_request(rpc_request)
152
230
 
153
231
  async def make_batch_request(
154
232
  self, requests: List[Tuple[RPCEndpoint, Any]]
@@ -184,19 +262,45 @@ class PersistentConnectionProvider(AsyncJSONBaseProvider, ABC):
184
262
  raise NotImplementedError("Must be implemented by subclasses")
185
263
 
186
264
  async def _provider_specific_disconnect(self) -> None:
265
+ # this method should be idempotent
187
266
  raise NotImplementedError("Must be implemented by subclasses")
188
267
 
189
268
  async def _provider_specific_socket_reader(self) -> RPCResponse:
190
269
  raise NotImplementedError("Must be implemented by subclasses")
191
270
 
271
+ def _set_signal_handlers(self) -> None:
272
+ def extended_handler(sig: int, frame: Any, existing_handler: Any) -> None:
273
+ loop = asyncio.get_event_loop()
274
+
275
+ # invoke the existing handler, if callable
276
+ if callable(existing_handler):
277
+ existing_handler(sig, frame)
278
+ loop.create_task(self.disconnect())
279
+
280
+ existing_sigint_handler = signal.getsignal(signal.SIGINT)
281
+ existing_sigterm_handler = signal.getsignal(signal.SIGTERM)
282
+
283
+ # extend the existing signal handlers to include the disconnect method
284
+ signal.signal(
285
+ signal.SIGINT,
286
+ lambda sig, frame: extended_handler(sig, frame, existing_sigint_handler),
287
+ )
288
+ signal.signal(
289
+ signal.SIGTERM,
290
+ lambda sig, frame: extended_handler(sig, frame, existing_sigterm_handler),
291
+ )
292
+
192
293
  def _message_listener_callback(
193
294
  self, message_listener_task: "asyncio.Task[None]"
194
295
  ) -> None:
195
- # Puts a `TaskNotRunning` in the queue to signal the end of the listener task
196
- # to any running subscription streams that are awaiting a response.
296
+ # Puts a `TaskNotRunning` in appropriate queues to signal the end of the
297
+ # listener task to any listeners relying on the queues.
197
298
  self._request_processor._subscription_response_queue.put_nowait(
198
299
  TaskNotRunning(message_listener_task)
199
300
  )
301
+ self._request_processor._handler_subscription_queue.put_nowait(
302
+ TaskNotRunning(message_listener_task)
303
+ )
200
304
 
201
305
  async def _message_listener(self) -> None:
202
306
  self.logger.info(
@@ -2,9 +2,12 @@ from typing import (
2
2
  TYPE_CHECKING,
3
3
  Any,
4
4
  Dict,
5
+ Union,
6
+ cast,
5
7
  )
6
8
 
7
9
  from web3.types import (
10
+ FormattedEthSubscriptionResponse,
8
11
  RPCEndpoint,
9
12
  RPCResponse,
10
13
  )
@@ -16,6 +19,9 @@ if TYPE_CHECKING:
16
19
  from web3.manager import ( # noqa: F401
17
20
  _AsyncPersistentMessageStream,
18
21
  )
22
+ from web3.providers.persistent import ( # noqa: F401
23
+ PersistentConnectionProvider,
24
+ )
19
25
 
20
26
 
21
27
  class PersistentConnection:
@@ -26,6 +32,7 @@ class PersistentConnection:
26
32
 
27
33
  def __init__(self, w3: "AsyncWeb3"):
28
34
  self._manager = w3.manager
35
+ self.provider = cast("PersistentConnectionProvider", self._manager.provider)
29
36
 
30
37
  @property
31
38
  def subscriptions(self) -> Dict[str, Any]:
@@ -47,10 +54,10 @@ class PersistentConnection:
47
54
  :param method: The RPC method, e.g. `eth_getBlockByNumber`.
48
55
  :param params: The RPC method parameters, e.g. `["0x1337", False]`.
49
56
 
50
- :return: The processed response from the persistent connection.
57
+ :return: The unprocessed response from the persistent connection.
51
58
  :rtype: RPCResponse
52
59
  """
53
- return await self._manager.socket_request(method, params)
60
+ return await self.provider.make_request(method, params)
54
61
 
55
62
  async def send(self, method: RPCEndpoint, params: Any) -> None:
56
63
  """
@@ -63,14 +70,14 @@ class PersistentConnection:
63
70
  """
64
71
  await self._manager.send(method, params)
65
72
 
66
- async def recv(self) -> RPCResponse:
73
+ async def recv(self) -> Union[RPCResponse, FormattedEthSubscriptionResponse]:
67
74
  """
68
75
  Receive the next unprocessed response for a request from the persistent
69
76
  connection.
70
77
 
71
78
  :return: The next unprocessed response for a request from the persistent
72
79
  connection.
73
- :rtype: RPCResponse
80
+ :rtype: Union[RPCResponse, FormattedEthSubscriptionResponse]
74
81
  """
75
82
  return await self._manager.recv()
76
83