wreq-rb 0.3.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (167)
  1. checksums.yaml +7 -0
  2. data/Cargo.lock +2688 -0
  3. data/Cargo.toml +6 -0
  4. data/README.md +179 -0
  5. data/ext/wreq_rb/Cargo.toml +39 -0
  6. data/ext/wreq_rb/extconf.rb +22 -0
  7. data/ext/wreq_rb/src/client.rs +565 -0
  8. data/ext/wreq_rb/src/error.rs +25 -0
  9. data/ext/wreq_rb/src/lib.rs +20 -0
  10. data/ext/wreq_rb/src/response.rs +132 -0
  11. data/lib/wreq-rb/version.rb +5 -0
  12. data/lib/wreq-rb.rb +17 -0
  13. data/patches/0001-add-transfer-size-tracking.patch +292 -0
  14. data/vendor/wreq/Cargo.toml +306 -0
  15. data/vendor/wreq/LICENSE +202 -0
  16. data/vendor/wreq/README.md +122 -0
  17. data/vendor/wreq/examples/cert_store.rs +77 -0
  18. data/vendor/wreq/examples/connect_via_lower_priority_tokio_runtime.rs +258 -0
  19. data/vendor/wreq/examples/emulation.rs +118 -0
  20. data/vendor/wreq/examples/form.rs +14 -0
  21. data/vendor/wreq/examples/http1_websocket.rs +37 -0
  22. data/vendor/wreq/examples/http2_websocket.rs +45 -0
  23. data/vendor/wreq/examples/json_dynamic.rs +41 -0
  24. data/vendor/wreq/examples/json_typed.rs +47 -0
  25. data/vendor/wreq/examples/keylog.rs +16 -0
  26. data/vendor/wreq/examples/request_with_emulation.rs +115 -0
  27. data/vendor/wreq/examples/request_with_interface.rs +37 -0
  28. data/vendor/wreq/examples/request_with_local_address.rs +16 -0
  29. data/vendor/wreq/examples/request_with_proxy.rs +13 -0
  30. data/vendor/wreq/examples/request_with_redirect.rs +22 -0
  31. data/vendor/wreq/examples/request_with_version.rs +15 -0
  32. data/vendor/wreq/examples/tor_socks.rs +24 -0
  33. data/vendor/wreq/examples/unix_socket.rs +33 -0
  34. data/vendor/wreq/src/client/body.rs +304 -0
  35. data/vendor/wreq/src/client/conn/conn.rs +231 -0
  36. data/vendor/wreq/src/client/conn/connector.rs +549 -0
  37. data/vendor/wreq/src/client/conn/http.rs +1023 -0
  38. data/vendor/wreq/src/client/conn/proxy/socks.rs +233 -0
  39. data/vendor/wreq/src/client/conn/proxy/tunnel.rs +260 -0
  40. data/vendor/wreq/src/client/conn/proxy.rs +39 -0
  41. data/vendor/wreq/src/client/conn/tls_info.rs +98 -0
  42. data/vendor/wreq/src/client/conn/uds.rs +44 -0
  43. data/vendor/wreq/src/client/conn/verbose.rs +149 -0
  44. data/vendor/wreq/src/client/conn.rs +323 -0
  45. data/vendor/wreq/src/client/core/body/incoming.rs +485 -0
  46. data/vendor/wreq/src/client/core/body/length.rs +118 -0
  47. data/vendor/wreq/src/client/core/body.rs +34 -0
  48. data/vendor/wreq/src/client/core/common/buf.rs +149 -0
  49. data/vendor/wreq/src/client/core/common/rewind.rs +141 -0
  50. data/vendor/wreq/src/client/core/common/watch.rs +76 -0
  51. data/vendor/wreq/src/client/core/common.rs +3 -0
  52. data/vendor/wreq/src/client/core/conn/http1.rs +342 -0
  53. data/vendor/wreq/src/client/core/conn/http2.rs +307 -0
  54. data/vendor/wreq/src/client/core/conn.rs +11 -0
  55. data/vendor/wreq/src/client/core/dispatch.rs +299 -0
  56. data/vendor/wreq/src/client/core/error.rs +435 -0
  57. data/vendor/wreq/src/client/core/ext.rs +201 -0
  58. data/vendor/wreq/src/client/core/http1.rs +178 -0
  59. data/vendor/wreq/src/client/core/http2.rs +483 -0
  60. data/vendor/wreq/src/client/core/proto/h1/conn.rs +988 -0
  61. data/vendor/wreq/src/client/core/proto/h1/decode.rs +1170 -0
  62. data/vendor/wreq/src/client/core/proto/h1/dispatch.rs +684 -0
  63. data/vendor/wreq/src/client/core/proto/h1/encode.rs +580 -0
  64. data/vendor/wreq/src/client/core/proto/h1/io.rs +879 -0
  65. data/vendor/wreq/src/client/core/proto/h1/role.rs +694 -0
  66. data/vendor/wreq/src/client/core/proto/h1.rs +104 -0
  67. data/vendor/wreq/src/client/core/proto/h2/client.rs +650 -0
  68. data/vendor/wreq/src/client/core/proto/h2/ping.rs +539 -0
  69. data/vendor/wreq/src/client/core/proto/h2.rs +379 -0
  70. data/vendor/wreq/src/client/core/proto/headers.rs +138 -0
  71. data/vendor/wreq/src/client/core/proto.rs +58 -0
  72. data/vendor/wreq/src/client/core/rt/bounds.rs +57 -0
  73. data/vendor/wreq/src/client/core/rt/timer.rs +150 -0
  74. data/vendor/wreq/src/client/core/rt/tokio.rs +99 -0
  75. data/vendor/wreq/src/client/core/rt.rs +25 -0
  76. data/vendor/wreq/src/client/core/upgrade.rs +267 -0
  77. data/vendor/wreq/src/client/core.rs +16 -0
  78. data/vendor/wreq/src/client/emulation.rs +161 -0
  79. data/vendor/wreq/src/client/http/client/error.rs +142 -0
  80. data/vendor/wreq/src/client/http/client/exec.rs +29 -0
  81. data/vendor/wreq/src/client/http/client/extra.rs +77 -0
  82. data/vendor/wreq/src/client/http/client/lazy.rs +79 -0
  83. data/vendor/wreq/src/client/http/client/pool.rs +1105 -0
  84. data/vendor/wreq/src/client/http/client/util.rs +104 -0
  85. data/vendor/wreq/src/client/http/client.rs +1003 -0
  86. data/vendor/wreq/src/client/http/future.rs +99 -0
  87. data/vendor/wreq/src/client/http.rs +1629 -0
  88. data/vendor/wreq/src/client/layer/config/options.rs +156 -0
  89. data/vendor/wreq/src/client/layer/config.rs +116 -0
  90. data/vendor/wreq/src/client/layer/cookie.rs +161 -0
  91. data/vendor/wreq/src/client/layer/decoder.rs +139 -0
  92. data/vendor/wreq/src/client/layer/redirect/future.rs +270 -0
  93. data/vendor/wreq/src/client/layer/redirect/policy.rs +63 -0
  94. data/vendor/wreq/src/client/layer/redirect.rs +145 -0
  95. data/vendor/wreq/src/client/layer/retry/classify.rs +105 -0
  96. data/vendor/wreq/src/client/layer/retry/scope.rs +51 -0
  97. data/vendor/wreq/src/client/layer/retry.rs +151 -0
  98. data/vendor/wreq/src/client/layer/timeout/body.rs +233 -0
  99. data/vendor/wreq/src/client/layer/timeout/future.rs +90 -0
  100. data/vendor/wreq/src/client/layer/timeout.rs +177 -0
  101. data/vendor/wreq/src/client/layer.rs +15 -0
  102. data/vendor/wreq/src/client/multipart.rs +717 -0
  103. data/vendor/wreq/src/client/request.rs +818 -0
  104. data/vendor/wreq/src/client/response.rs +534 -0
  105. data/vendor/wreq/src/client/ws/json.rs +99 -0
  106. data/vendor/wreq/src/client/ws/message.rs +453 -0
  107. data/vendor/wreq/src/client/ws.rs +714 -0
  108. data/vendor/wreq/src/client.rs +27 -0
  109. data/vendor/wreq/src/config.rs +140 -0
  110. data/vendor/wreq/src/cookie.rs +579 -0
  111. data/vendor/wreq/src/dns/gai.rs +249 -0
  112. data/vendor/wreq/src/dns/hickory.rs +78 -0
  113. data/vendor/wreq/src/dns/resolve.rs +180 -0
  114. data/vendor/wreq/src/dns.rs +69 -0
  115. data/vendor/wreq/src/error.rs +502 -0
  116. data/vendor/wreq/src/ext.rs +398 -0
  117. data/vendor/wreq/src/hash.rs +143 -0
  118. data/vendor/wreq/src/header.rs +506 -0
  119. data/vendor/wreq/src/into_uri.rs +187 -0
  120. data/vendor/wreq/src/lib.rs +586 -0
  121. data/vendor/wreq/src/proxy/mac.rs +82 -0
  122. data/vendor/wreq/src/proxy/matcher.rs +806 -0
  123. data/vendor/wreq/src/proxy/uds.rs +66 -0
  124. data/vendor/wreq/src/proxy/win.rs +31 -0
  125. data/vendor/wreq/src/proxy.rs +569 -0
  126. data/vendor/wreq/src/redirect.rs +575 -0
  127. data/vendor/wreq/src/retry.rs +198 -0
  128. data/vendor/wreq/src/sync.rs +129 -0
  129. data/vendor/wreq/src/tls/conn/cache.rs +123 -0
  130. data/vendor/wreq/src/tls/conn/cert_compression.rs +125 -0
  131. data/vendor/wreq/src/tls/conn/ext.rs +82 -0
  132. data/vendor/wreq/src/tls/conn/macros.rs +34 -0
  133. data/vendor/wreq/src/tls/conn/service.rs +138 -0
  134. data/vendor/wreq/src/tls/conn.rs +681 -0
  135. data/vendor/wreq/src/tls/keylog/handle.rs +64 -0
  136. data/vendor/wreq/src/tls/keylog.rs +99 -0
  137. data/vendor/wreq/src/tls/options.rs +464 -0
  138. data/vendor/wreq/src/tls/x509/identity.rs +122 -0
  139. data/vendor/wreq/src/tls/x509/parser.rs +71 -0
  140. data/vendor/wreq/src/tls/x509/store.rs +228 -0
  141. data/vendor/wreq/src/tls/x509.rs +68 -0
  142. data/vendor/wreq/src/tls.rs +154 -0
  143. data/vendor/wreq/src/trace.rs +55 -0
  144. data/vendor/wreq/src/util.rs +122 -0
  145. data/vendor/wreq/tests/badssl.rs +228 -0
  146. data/vendor/wreq/tests/brotli.rs +350 -0
  147. data/vendor/wreq/tests/client.rs +1098 -0
  148. data/vendor/wreq/tests/connector_layers.rs +227 -0
  149. data/vendor/wreq/tests/cookie.rs +306 -0
  150. data/vendor/wreq/tests/deflate.rs +347 -0
  151. data/vendor/wreq/tests/emulation.rs +260 -0
  152. data/vendor/wreq/tests/gzip.rs +347 -0
  153. data/vendor/wreq/tests/layers.rs +261 -0
  154. data/vendor/wreq/tests/multipart.rs +165 -0
  155. data/vendor/wreq/tests/proxy.rs +438 -0
  156. data/vendor/wreq/tests/redirect.rs +629 -0
  157. data/vendor/wreq/tests/retry.rs +135 -0
  158. data/vendor/wreq/tests/support/delay_server.rs +117 -0
  159. data/vendor/wreq/tests/support/error.rs +16 -0
  160. data/vendor/wreq/tests/support/layer.rs +183 -0
  161. data/vendor/wreq/tests/support/mod.rs +9 -0
  162. data/vendor/wreq/tests/support/server.rs +232 -0
  163. data/vendor/wreq/tests/timeouts.rs +281 -0
  164. data/vendor/wreq/tests/unix_socket.rs +135 -0
  165. data/vendor/wreq/tests/upgrade.rs +98 -0
  166. data/vendor/wreq/tests/zstd.rs +559 -0
  167. metadata +225 -0
@@ -0,0 +1,714 @@
1
+ //! WebSocket Upgrade
2
+
3
+ #[cfg(feature = "json")]
4
+ mod json;
5
+ pub mod message;
6
+
7
+ use std::{
8
+ borrow::Cow,
9
+ fmt,
10
+ net::{IpAddr, Ipv4Addr, Ipv6Addr},
11
+ ops::{Deref, DerefMut},
12
+ pin::Pin,
13
+ task::{Context, Poll, ready},
14
+ };
15
+
16
+ use futures_util::{Sink, SinkExt, Stream, StreamExt, stream::FusedStream};
17
+ use http::{
18
+ HeaderMap, HeaderName, HeaderValue, Method, StatusCode, Uri, Version, header, uri::Scheme,
19
+ };
20
+ use http2::ext::Protocol;
21
+ use pin_project_lite::pin_project;
22
+ use tokio_tungstenite::tungstenite::{
23
+ self,
24
+ protocol::{self, CloseFrame, WebSocketConfig},
25
+ };
26
+
27
+ use self::message::{CloseCode, Message, Utf8Bytes};
28
+ use crate::{
29
+ EmulationFactory, Error, RequestBuilder, Response, Upgraded, header::OrigHeaderMap,
30
+ proxy::Proxy,
31
+ };
32
+
33
+ /// A WebSocket stream.
34
+ type WebSocketStream = tokio_tungstenite::WebSocketStream<Upgraded>;
35
+
36
+ /// Wrapper for [`RequestBuilder`] that performs the
37
+ /// websocket handshake when sent.
38
+ pub struct WebSocketRequestBuilder {
39
+ inner: RequestBuilder,
40
+ accept_key: Option<Cow<'static, str>>,
41
+ protocols: Option<Vec<Cow<'static, str>>>,
42
+ config: WebSocketConfig,
43
+ }
44
+
45
+ impl WebSocketRequestBuilder {
46
+ /// Creates a new WebSocket request builder.
47
+ pub fn new(inner: RequestBuilder) -> Self {
48
+ Self {
49
+ inner: inner.version(Version::HTTP_11),
50
+ accept_key: None,
51
+ protocols: None,
52
+ config: WebSocketConfig::default(),
53
+ }
54
+ }
55
+
56
+ /// Sets a custom WebSocket accept key.
57
+ ///
58
+ /// This method allows you to set a custom WebSocket accept key for the connection.
59
+ ///
60
+ /// # Arguments
61
+ ///
62
+ /// * `key` - The custom WebSocket accept key to set.
63
+ ///
64
+ /// # Returns
65
+ ///
66
+ /// * `Self` - The modified instance with the custom WebSocket accept key.
67
+ #[inline]
68
+ pub fn accept_key<K>(mut self, key: K) -> Self
69
+ where
70
+ K: Into<Cow<'static, str>>,
71
+ {
72
+ self.accept_key = Some(key.into());
73
+ self
74
+ }
75
+
76
+ /// Forces the WebSocket connection to use HTTP/2 protocol.
77
+ ///
78
+ /// This method configures the WebSocket connection to use HTTP/2's Extended
79
+ /// CONNECT Protocol (RFC 8441) for the handshake instead of the traditional
80
+ /// HTTP/1.1 upgrade mechanism.
81
+ ///
82
+ /// # Behavior
83
+ ///
84
+ /// - Uses `CONNECT` method with `:protocol: websocket` pseudo-header
85
+ /// - Requires server support for HTTP/2 WebSocket connections
86
+ /// - Will fail if server doesn't support HTTP/2 WebSocket upgrade
87
+ #[inline]
88
+ pub fn force_http2(mut self) -> Self {
89
+ self.inner = self.inner.version(Version::HTTP_2);
90
+ self
91
+ }
92
+
93
+ /// Sets the websocket subprotocols to request.
94
+ ///
95
+ /// This method allows you to specify the subprotocols that the websocket client
96
+ /// should request during the handshake. Subprotocols are used to define the type
97
+ /// of communication expected over the websocket connection.
98
+ #[inline]
99
+ pub fn protocols<P>(mut self, protocols: P) -> Self
100
+ where
101
+ P: IntoIterator,
102
+ P::Item: Into<Cow<'static, str>>,
103
+ {
104
+ let protocols = protocols.into_iter().map(Into::into).collect();
105
+ self.protocols = Some(protocols);
106
+ self
107
+ }
108
+
109
+ /// Sets the websocket max_frame_size configuration.
110
+ #[inline]
111
+ pub fn max_frame_size(mut self, max_frame_size: usize) -> Self {
112
+ self.config.max_frame_size = Some(max_frame_size);
113
+ self
114
+ }
115
+
116
+ /// Sets the websocket read_buffer_size configuration.
117
+ #[inline]
118
+ pub fn read_buffer_size(mut self, read_buffer_size: usize) -> Self {
119
+ self.config.read_buffer_size = read_buffer_size;
120
+ self
121
+ }
122
+
123
+ /// Sets the websocket write_buffer_size configuration.
124
+ #[inline]
125
+ pub fn write_buffer_size(mut self, write_buffer_size: usize) -> Self {
126
+ self.config.write_buffer_size = write_buffer_size;
127
+ self
128
+ }
129
+
130
+ /// Sets the websocket max_write_buffer_size configuration.
131
+ #[inline]
132
+ pub fn max_write_buffer_size(mut self, max_write_buffer_size: usize) -> Self {
133
+ self.config.max_write_buffer_size = max_write_buffer_size;
134
+ self
135
+ }
136
+
137
+ /// Sets the websocket max_message_size configuration.
138
+ #[inline]
139
+ pub fn max_message_size(mut self, max_message_size: usize) -> Self {
140
+ self.config.max_message_size = Some(max_message_size);
141
+ self
142
+ }
143
+
144
+ /// Sets the websocket accept_unmasked_frames configuration.
145
+ #[inline]
146
+ pub fn accept_unmasked_frames(mut self, accept_unmasked_frames: bool) -> Self {
147
+ self.config.accept_unmasked_frames = accept_unmasked_frames;
148
+ self
149
+ }
150
+
151
+ /// Add a `Header` to this Request.
152
+ ///
153
+ /// If the header is already present, the value will be replaced.
154
+ #[inline]
155
+ pub fn header<K, V>(mut self, key: K, value: V) -> Self
156
+ where
157
+ HeaderName: TryFrom<K>,
158
+ <HeaderName as TryFrom<K>>::Error: Into<http::Error>,
159
+ HeaderValue: TryFrom<V>,
160
+ <HeaderValue as TryFrom<V>>::Error: Into<http::Error>,
161
+ {
162
+ self.inner = self.inner.header(key, value);
163
+ self
164
+ }
165
+
166
+ /// Add a set of Headers to the existing ones on this Request.
167
+ ///
168
+ /// The headers will be merged in to any already set.
169
+ #[inline]
170
+ pub fn headers(mut self, headers: HeaderMap) -> Self {
171
+ self.inner = self.inner.headers(headers);
172
+ self
173
+ }
174
+
175
+ /// Set the original headers for this request.
176
+ #[inline]
177
+ pub fn orig_headers(mut self, orig_headers: OrigHeaderMap) -> Self {
178
+ self.inner = self.inner.orig_headers(orig_headers);
179
+ self
180
+ }
181
+
182
+ /// Enable or disable client default headers for this request.
183
+ ///
184
+ /// By default, client default headers are included. Set to `false` to skip them.
185
+ pub fn default_headers(mut self, enable: bool) -> Self {
186
+ self.inner = self.inner.default_headers(enable);
187
+ self
188
+ }
189
+
190
+ /// Enable HTTP authentication.
191
+ ///
192
+ /// ```rust
193
+ /// # use wreq::Error;
194
+ /// #
195
+ /// # async fn run() -> Result<(), Error> {
196
+ /// let client = wreq::Client::new();
197
+ /// let resp = client
198
+ /// .websocket("http://httpbin.org/get")
199
+ /// .auth("your_token_here")
200
+ /// .send()
201
+ /// .await?;
202
+ /// # Ok(())
203
+ /// # }
204
+ /// ```
205
+ #[inline]
206
+ pub fn auth<V>(mut self, value: V) -> Self
207
+ where
208
+ HeaderValue: TryFrom<V>,
209
+ <HeaderValue as TryFrom<V>>::Error: Into<http::Error>,
210
+ {
211
+ self.inner = self.inner.auth(value);
212
+ self
213
+ }
214
+
215
+ /// Enable HTTP basic authentication.
216
+ ///
217
+ /// ```rust
218
+ /// # use wreq::Error;
219
+ ///
220
+ /// # async fn run() -> Result<(), Error> {
221
+ /// let client = wreq::Client::new();
222
+ /// let resp = client
223
+ /// .websocket("http://httpbin.org/delete")
224
+ /// .basic_auth("admin", Some("good password"))
225
+ /// .send()
226
+ /// .await?;
227
+ /// # Ok(())
228
+ /// # }
229
+ /// ```
230
+ #[inline]
231
+ pub fn basic_auth<U, P>(mut self, username: U, password: Option<P>) -> Self
232
+ where
233
+ U: fmt::Display,
234
+ P: fmt::Display,
235
+ {
236
+ self.inner = self.inner.basic_auth(username, password);
237
+ self
238
+ }
239
+
240
+ /// Enable HTTP bearer authentication.
241
+ ///
242
+ /// ```rust
243
+ /// # use wreq::Error;
244
+ /// #
245
+ /// # async fn run() -> Result<(), Error> {
246
+ /// let client = wreq::Client::new();
247
+ /// let resp = client
248
+ /// .websocket("http://httpbin.org/get")
249
+ /// .bearer_auth("your_token_here")
250
+ /// .send()
251
+ /// .await?;
252
+ /// # Ok(())
253
+ /// # }
254
+ /// ```
255
+ #[inline]
256
+ pub fn bearer_auth<T>(mut self, token: T) -> Self
257
+ where
258
+ T: fmt::Display,
259
+ {
260
+ self.inner = self.inner.bearer_auth(token);
261
+ self
262
+ }
263
+
264
+ /// Modify the query string of the URI.
265
+ ///
266
+ /// Modifies the URI of this request, adding the parameters provided.
267
+ /// This method appends and does not overwrite. This means that it can
268
+ /// be called multiple times and that existing query parameters are not
269
+ /// overwritten if the same key is used. The key will simply show up
270
+ /// twice in the query string.
271
+ /// Calling `.query(&[("foo", "a"), ("foo", "b")])` gives `"foo=a&foo=b"`.
272
+ ///
273
+ /// # Note
274
+ /// This method does not support serializing a single key-value
275
+ /// pair. Instead of using `.query(("key", "val"))`, use a sequence, such
276
+ /// as `.query(&[("key", "val")])`. It's also possible to serialize structs
277
+ /// and maps into a key-value pair.
278
+ ///
279
+ /// # Errors
280
+ /// This method will fail if the object you provide cannot be serialized
281
+ /// into a query string.
282
+ #[inline]
283
+ #[cfg(feature = "query")]
284
+ #[cfg_attr(docsrs, doc(cfg(feature = "query")))]
285
+ pub fn query<T: serde::Serialize + ?Sized>(mut self, query: &T) -> Self {
286
+ self.inner = self.inner.query(query);
287
+ self
288
+ }
289
+
290
+ /// Set the proxy for this request.
291
+ #[inline]
292
+ pub fn proxy(mut self, proxy: Proxy) -> Self {
293
+ self.inner = self.inner.proxy(proxy);
294
+ self
295
+ }
296
+
297
+ /// Set the local address for this request.
298
+ #[inline]
299
+ pub fn local_address<V>(mut self, local_address: V) -> Self
300
+ where
301
+ V: Into<Option<IpAddr>>,
302
+ {
303
+ self.inner = self.inner.local_address(local_address);
304
+ self
305
+ }
306
+
307
+ /// Set the local addresses for this request.
308
+ #[inline]
309
+ pub fn local_addresses<V4, V6>(mut self, ipv4: V4, ipv6: V6) -> Self
310
+ where
311
+ V4: Into<Option<Ipv4Addr>>,
312
+ V6: Into<Option<Ipv6Addr>>,
313
+ {
314
+ self.inner = self.inner.local_addresses(ipv4, ipv6);
315
+ self
316
+ }
317
+
318
+ /// Set the interface for this request.
319
+ #[inline]
320
+ #[cfg(any(
321
+ target_os = "android",
322
+ target_os = "fuchsia",
323
+ target_os = "illumos",
324
+ target_os = "ios",
325
+ target_os = "linux",
326
+ target_os = "macos",
327
+ target_os = "solaris",
328
+ target_os = "tvos",
329
+ target_os = "visionos",
330
+ target_os = "watchos",
331
+ ))]
332
+ #[cfg_attr(
333
+ docsrs,
334
+ doc(cfg(any(
335
+ target_os = "android",
336
+ target_os = "fuchsia",
337
+ target_os = "illumos",
338
+ target_os = "ios",
339
+ target_os = "linux",
340
+ target_os = "macos",
341
+ target_os = "solaris",
342
+ target_os = "tvos",
343
+ target_os = "visionos",
344
+ target_os = "watchos",
345
+ )))
346
+ )]
347
+ pub fn interface<I>(mut self, interface: I) -> Self
348
+ where
349
+ I: Into<std::borrow::Cow<'static, str>>,
350
+ {
351
+ self.inner = self.inner.interface(interface);
352
+ self
353
+ }
354
+
355
+ /// Set the emulation for this request.
356
+ #[inline]
357
+ pub fn emulation<P>(mut self, factory: P) -> Self
358
+ where
359
+ P: EmulationFactory,
360
+ {
361
+ self.inner = self.inner.emulation(factory);
362
+ self
363
+ }
364
+
365
+ /// Sends the request and returns and [`WebSocketResponse`].
366
+ pub async fn send(self) -> Result<WebSocketResponse, Error> {
367
+ let (client, request) = self.inner.build_split();
368
+ let mut request = request?;
369
+
370
+ // Ensure the scheme is http or https
371
+ let uri = request.uri_mut();
372
+ let scheme = match uri.scheme_str() {
373
+ Some("ws") => Some(Scheme::HTTP),
374
+ Some("wss") => Some(Scheme::HTTPS),
375
+ _ => None,
376
+ };
377
+ if scheme.is_some() {
378
+ let mut parts = uri.clone().into_parts();
379
+ parts.scheme = scheme;
380
+ *uri = Uri::from_parts(parts).map_err(Error::builder)?;
381
+ }
382
+
383
+ // Get the version of the request
384
+ let version = request.version();
385
+
386
+ // Set the headers for the websocket handshake
387
+ let headers = request.headers_mut();
388
+ headers.insert(
389
+ header::SEC_WEBSOCKET_VERSION,
390
+ HeaderValue::from_static("13"),
391
+ );
392
+
393
+ // Ensure the request is HTTP 1.1/HTTP 2
394
+ let accept_key = match version {
395
+ Some(Version::HTTP_10 | Version::HTTP_11) => {
396
+ // Generate a nonce if one wasn't provided
397
+ let nonce = self
398
+ .accept_key
399
+ .unwrap_or_else(|| Cow::Owned(tungstenite::handshake::client::generate_key()));
400
+
401
+ headers.insert(header::UPGRADE, HeaderValue::from_static("websocket"));
402
+ headers.insert(header::CONNECTION, HeaderValue::from_static("upgrade"));
403
+ headers.insert(
404
+ header::SEC_WEBSOCKET_KEY,
405
+ HeaderValue::from_str(&nonce).map_err(Error::builder)?,
406
+ );
407
+
408
+ *request.method_mut() = Method::GET;
409
+ *request.version_mut() = Some(Version::HTTP_11);
410
+ Some(nonce)
411
+ }
412
+ Some(Version::HTTP_2) => {
413
+ *request.method_mut() = Method::CONNECT;
414
+ *request.version_mut() = Some(Version::HTTP_2);
415
+ request
416
+ .extensions_mut()
417
+ .insert(Protocol::from_static("websocket"));
418
+ None
419
+ }
420
+ unsupported => {
421
+ return Err(Error::upgrade(format!(
422
+ "unsupported version: {unsupported:?}"
423
+ )));
424
+ }
425
+ };
426
+
427
+ // Set websocket subprotocols
428
+ if let Some(ref protocols) = self.protocols {
429
+ // Sets subprotocols
430
+ if !protocols.is_empty() {
431
+ let subprotocols = protocols
432
+ .iter()
433
+ .map(|s| s.as_ref())
434
+ .collect::<Vec<&str>>()
435
+ .join(", ");
436
+
437
+ request.headers_mut().insert(
438
+ header::SEC_WEBSOCKET_PROTOCOL,
439
+ subprotocols.parse().map_err(Error::builder)?,
440
+ );
441
+ }
442
+ }
443
+
444
+ client
445
+ .execute(request)
446
+ .await
447
+ .map(|inner| WebSocketResponse {
448
+ inner,
449
+ accept_key,
450
+ protocols: self.protocols,
451
+ config: self.config,
452
+ })
453
+ }
454
+ }
455
+
456
+ /// The server's response to the websocket upgrade request.
457
+ ///
458
+ /// This implements `Deref<Target = Response>`, so you can access all the usual
459
+ /// information from the [`Response`].
460
+ #[derive(Debug)]
461
+ pub struct WebSocketResponse {
462
+ inner: Response,
463
+ accept_key: Option<Cow<'static, str>>,
464
+ protocols: Option<Vec<Cow<'static, str>>>,
465
+ config: WebSocketConfig,
466
+ }
467
+
468
+ impl Deref for WebSocketResponse {
469
+ type Target = Response;
470
+
471
+ fn deref(&self) -> &Self::Target {
472
+ &self.inner
473
+ }
474
+ }
475
+
476
+ impl DerefMut for WebSocketResponse {
477
+ fn deref_mut(&mut self) -> &mut Self::Target {
478
+ &mut self.inner
479
+ }
480
+ }
481
+
482
+ impl WebSocketResponse {
483
+ /// Turns the response into a websocket. This checks if the websocket
484
+ /// handshake was successful.
485
+ pub async fn into_websocket(self) -> Result<WebSocket, Error> {
486
+ let (inner, protocol) = {
487
+ let status = self.inner.status();
488
+ let headers = self.inner.headers();
489
+
490
+ match self.inner.version() {
491
+ // HTTP/1.0 and HTTP/1.1 use the traditional upgrade mechanism
492
+ Version::HTTP_10 | Version::HTTP_11 => {
493
+ if status != StatusCode::SWITCHING_PROTOCOLS {
494
+ return Err(Error::upgrade(format!("unexpected status code: {status}")));
495
+ }
496
+
497
+ if !header_contains(self.inner.headers(), header::CONNECTION, "upgrade") {
498
+ return Err(Error::upgrade("missing connection header"));
499
+ }
500
+
501
+ if !header_eq(self.inner.headers(), header::UPGRADE, "websocket") {
502
+ return Err(Error::upgrade("invalid upgrade header"));
503
+ }
504
+
505
+ match self
506
+ .accept_key
507
+ .zip(headers.get(header::SEC_WEBSOCKET_ACCEPT))
508
+ {
509
+ Some((nonce, header)) => {
510
+ if !header.to_str().is_ok_and(|s| {
511
+ s == tungstenite::handshake::derive_accept_key(nonce.as_bytes())
512
+ }) {
513
+ return Err(Error::upgrade(format!(
514
+ "invalid accept key: {header:?}"
515
+ )));
516
+ }
517
+ }
518
+ None => {
519
+ return Err(Error::upgrade("missing accept key"));
520
+ }
521
+ }
522
+ }
523
+ // HTTP/2 uses the Extended CONNECT Protocol (RFC 8441)
524
+ // See: https://datatracker.ietf.org/doc/html/rfc8441
525
+ Version::HTTP_2 => {
526
+ if status != StatusCode::OK {
527
+ return Err(Error::upgrade(format!("unexpected status code: {status}")));
528
+ }
529
+ }
530
+ _ => {
531
+ return Err(Error::upgrade(format!(
532
+ "unsupported version: {:?}",
533
+ self.inner.version()
534
+ )));
535
+ }
536
+ }
537
+
538
+ let protocol = headers.get(header::SEC_WEBSOCKET_PROTOCOL).cloned();
539
+ let requested = self.protocols.as_ref().filter(|p| !p.is_empty());
540
+ let replied = protocol.as_ref().and_then(|v| v.to_str().ok());
541
+
542
+ match (requested, replied) {
543
+ // okay, we requested protocols and got one back
544
+ (Some(req), Some(rep)) => {
545
+ if !req.contains(&Cow::Borrowed(rep)) {
546
+ return Err(Error::upgrade(format!("invalid protocol: {rep}")));
547
+ }
548
+ }
549
+ // server didn't reply with a protocol
550
+ (Some(_), None) => {
551
+ return Err(Error::upgrade(format!(
552
+ "missing protocol: {:?}",
553
+ self.protocols
554
+ )));
555
+ }
556
+ // we didn't request any protocols, but got one anyway
557
+ (None, Some(_)) => {
558
+ return Err(Error::upgrade(format!("invalid protocol: {protocol:?}")));
559
+ }
560
+ // we didn't request any protocols, so we don't expect one
561
+ (None, None) => {}
562
+ };
563
+
564
+ let upgraded = self.inner.upgrade().await?;
565
+ let inner = WebSocketStream::from_raw_socket(
566
+ upgraded,
567
+ protocol::Role::Client,
568
+ Some(self.config),
569
+ )
570
+ .await;
571
+
572
+ (inner, protocol)
573
+ };
574
+
575
+ Ok(WebSocket { inner, protocol })
576
+ }
577
+ }
578
+
579
+ /// Checks if the header value is equal to the given value.
580
+ fn header_eq(headers: &HeaderMap, key: HeaderName, value: &'static str) -> bool {
581
+ if let Some(header) = headers.get(&key) {
582
+ header.as_bytes().eq_ignore_ascii_case(value.as_bytes())
583
+ } else {
584
+ false
585
+ }
586
+ }
587
+
588
+ /// Checks if the header value contains the given value.
589
+ fn header_contains(headers: &HeaderMap, key: HeaderName, value: &'static str) -> bool {
590
+ let header = if let Some(header) = headers.get(&key) {
591
+ header
592
+ } else {
593
+ return false;
594
+ };
595
+
596
+ if let Ok(header) = std::str::from_utf8(header.as_bytes()) {
597
+ header.to_ascii_lowercase().contains(value)
598
+ } else {
599
+ false
600
+ }
601
+ }
602
+
603
+ pin_project! {
604
+ /// A websocket connection
605
+ #[derive(Debug)]
606
+ pub struct WebSocket {
607
+ #[pin]
608
+ inner: WebSocketStream,
609
+ protocol: Option<HeaderValue>,
610
+ }
611
+ }
612
+
613
+ impl WebSocket {
614
+ /// Return the selected WebSocket subprotocol, if one has been chosen.
615
+ #[inline]
616
+ pub fn protocol(&self) -> Option<&HeaderValue> {
617
+ self.protocol.as_ref()
618
+ }
619
+
620
+ /// Receive another message.
621
+ ///
622
+ /// Returns `None` if the stream has closed.
623
+ #[inline]
624
+ pub async fn recv(&mut self) -> Option<Result<Message, Error>> {
625
+ self.next().await
626
+ }
627
+
628
+ /// Send a message.
629
+ #[inline]
630
+ pub async fn send(&mut self, msg: Message) -> Result<(), Error> {
631
+ self.inner
632
+ .send(msg.into_tungstenite())
633
+ .await
634
+ .map_err(Error::websocket)
635
+ }
636
+
637
+ /// Closes the connection with a given code and (optional) reason.
638
+ pub async fn close<C, R>(mut self, code: C, reason: R) -> Result<(), Error>
639
+ where
640
+ C: Into<CloseCode>,
641
+ R: Into<Utf8Bytes>,
642
+ {
643
+ let close_frame = CloseFrame {
644
+ code: code.into().0.into(),
645
+ reason: reason.into().0,
646
+ };
647
+
648
+ self.inner
649
+ .close(Some(close_frame))
650
+ .await
651
+ .map_err(Error::websocket)
652
+ }
653
+ }
654
+
655
+ impl Sink<Message> for WebSocket {
656
+ type Error = Error;
657
+
658
+ #[inline]
659
+ fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
660
+ self.project()
661
+ .inner
662
+ .poll_ready(cx)
663
+ .map_err(Error::websocket)
664
+ }
665
+
666
+ #[inline]
667
+ fn start_send(self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
668
+ self.project()
669
+ .inner
670
+ .start_send(item.into_tungstenite())
671
+ .map_err(Error::websocket)
672
+ }
673
+
674
+ #[inline]
675
+ fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
676
+ self.project()
677
+ .inner
678
+ .poll_flush(cx)
679
+ .map_err(Error::websocket)
680
+ }
681
+
682
+ #[inline]
683
+ fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
684
+ self.project()
685
+ .inner
686
+ .poll_close(cx)
687
+ .map_err(Error::websocket)
688
+ }
689
+ }
690
+
691
+ impl FusedStream for WebSocket {
692
+ #[inline]
693
+ fn is_terminated(&self) -> bool {
694
+ self.inner.is_terminated()
695
+ }
696
+ }
697
+
698
+ impl Stream for WebSocket {
699
+ type Item = Result<Message, Error>;
700
+
701
+ fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
702
+ loop {
703
+ match ready!(self.inner.poll_next_unpin(cx)) {
704
+ Some(Ok(msg)) => {
705
+ if let Some(msg) = Message::from_tungstenite(msg) {
706
+ return Poll::Ready(Some(Ok(msg)));
707
+ }
708
+ }
709
+ Some(Err(err)) => return Poll::Ready(Some(Err(Error::body(err)))),
710
+ None => return Poll::Ready(None),
711
+ }
712
+ }
713
+ }
714
+ }