aws-crt 0.1.4 → 0.1.5

Sign up to get free protection for your applications and to get access to all the features.
Files changed (69)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +5 -0
  3. data/VERSION +1 -1
  4. data/aws-crt-ffi/crt/aws-c-cal/CMakeLists.txt +2 -0
  5. data/aws-crt-ffi/crt/aws-c-cal/bin/produce_x_platform_fuzz_corpus/CMakeLists.txt +30 -0
  6. data/aws-crt-ffi/crt/aws-c-cal/bin/produce_x_platform_fuzz_corpus/main.c +208 -0
  7. data/aws-crt-ffi/crt/aws-c-cal/bin/run_x_platform_fuzz_corpus/CMakeLists.txt +30 -0
  8. data/aws-crt-ffi/crt/aws-c-cal/bin/run_x_platform_fuzz_corpus/main.c +244 -0
  9. data/aws-crt-ffi/crt/aws-c-cal/ecdsa-fuzz-corpus/darwin/p256_sig_corpus.txt +10000 -0
  10. data/aws-crt-ffi/crt/aws-c-cal/ecdsa-fuzz-corpus/windows/p256_sig_corpus.txt +10000 -0
  11. data/aws-crt-ffi/crt/aws-c-cal/source/windows/bcrypt_ecc.c +8 -0
  12. data/aws-crt-ffi/crt/aws-c-http/tests/CMakeLists.txt +11 -10
  13. data/aws-crt-ffi/crt/aws-c-io/include/aws/io/tls_channel_handler.h +2 -0
  14. data/aws-crt-ffi/crt/aws-c-io/source/darwin/darwin_pki_utils.c +8 -0
  15. data/aws-crt-ffi/crt/aws-c-io/source/tls_channel_handler.c +2 -0
  16. data/aws-crt-ffi/crt/aws-c-io/source/windows/windows_pki_utils.c +65 -35
  17. data/aws-crt-ffi/crt/s2n/CMakeLists.txt +67 -21
  18. data/aws-crt-ffi/crt/s2n/Makefile +10 -0
  19. data/aws-crt-ffi/crt/s2n/bin/Makefile +9 -0
  20. data/aws-crt-ffi/crt/s2n/bindings/rust/Makefile +14 -0
  21. data/aws-crt-ffi/crt/s2n/bindings/rust/integration/Cargo.toml +2 -2
  22. data/aws-crt-ffi/crt/s2n/bindings/rust/s2n-tls/Cargo.toml +3 -2
  23. data/aws-crt-ffi/crt/s2n/bindings/rust/s2n-tls/src/raw/config.rs +265 -39
  24. data/aws-crt-ffi/crt/s2n/bindings/rust/s2n-tls/src/raw/connection.rs +170 -20
  25. data/aws-crt-ffi/crt/s2n/bindings/rust/s2n-tls/src/testing/s2n_tls.rs +120 -0
  26. data/aws-crt-ffi/crt/s2n/bindings/rust/s2n-tls/src/testing.rs +58 -23
  27. data/aws-crt-ffi/crt/s2n/bindings/rust/s2n-tls-sys/Cargo.toml +1 -1
  28. data/aws-crt-ffi/crt/s2n/bindings/rust/s2n-tls-sys/src/internal.rs +3 -0
  29. data/aws-crt-ffi/crt/s2n/crypto/s2n_composite_cipher_aes_sha.c +1 -1
  30. data/aws-crt-ffi/crt/s2n/crypto/s2n_drbg.c +8 -3
  31. data/aws-crt-ffi/crt/s2n/error/s2n_errno.c +3 -0
  32. data/aws-crt-ffi/crt/s2n/error/s2n_errno.h +2 -0
  33. data/aws-crt-ffi/crt/s2n/lib/Makefile +11 -0
  34. data/aws-crt-ffi/crt/s2n/pq-crypto/kyber_90s_r2/ntt.h +2 -2
  35. data/aws-crt-ffi/crt/s2n/pq-crypto/kyber_r2/ntt.h +2 -2
  36. data/aws-crt-ffi/crt/s2n/pq-crypto/kyber_r3/kyber512r3_poly_avx2.h +2 -2
  37. data/aws-crt-ffi/crt/s2n/pq-crypto/kyber_r3/kyber512r3_polyvec_avx2.h +2 -2
  38. data/aws-crt-ffi/crt/s2n/pq-crypto/sike_r1/P503_internal_r1.h +1 -1
  39. data/aws-crt-ffi/crt/s2n/pq-crypto/sike_r1/fips202_r1.h +1 -1
  40. data/aws-crt-ffi/crt/s2n/pq-crypto/sike_r3/sikep434r3_fp_x64_asm.S +4 -0
  41. data/aws-crt-ffi/crt/s2n/s2n.mk +25 -0
  42. data/aws-crt-ffi/crt/s2n/scripts/s2n_safety_macros.py +14 -0
  43. data/aws-crt-ffi/crt/s2n/tests/benchmark/Readme.md +23 -9
  44. data/aws-crt-ffi/crt/s2n/tests/features/clone.c +24 -0
  45. data/aws-crt-ffi/crt/s2n/tests/features/madvise.c +27 -0
  46. data/aws-crt-ffi/crt/s2n/tests/features/minherit.c +22 -0
  47. data/aws-crt-ffi/crt/s2n/tests/integrationv2/conftest.py +2 -2
  48. data/aws-crt-ffi/crt/s2n/tests/unit/s2n_connection_test.c +1 -1
  49. data/aws-crt-ffi/crt/s2n/tests/unit/s2n_fork_generation_number_test.c +335 -0
  50. data/aws-crt-ffi/crt/s2n/tests/unit/s2n_mem_usage_test.c +1 -1
  51. data/aws-crt-ffi/crt/s2n/tests/unit/s2n_self_talk_client_hello_cb_test.c +93 -11
  52. data/aws-crt-ffi/crt/s2n/tests/unit/s2n_server_hello_retry_test.c +123 -1
  53. data/aws-crt-ffi/crt/s2n/tests/unit/s2n_tls13_key_schedule_rfc8448_test.c +18 -3
  54. data/aws-crt-ffi/crt/s2n/tests/unit/s2n_tls13_key_schedule_test.c +0 -38
  55. data/aws-crt-ffi/crt/s2n/tests/unit/s2n_tls13_secrets_test.c +134 -15
  56. data/aws-crt-ffi/crt/s2n/tls/s2n_cipher_suites.c +1 -1
  57. data/aws-crt-ffi/crt/s2n/tls/s2n_client_hello.c +20 -9
  58. data/aws-crt-ffi/crt/s2n/tls/s2n_client_hello.h +8 -0
  59. data/aws-crt-ffi/crt/s2n/tls/s2n_config.c +13 -0
  60. data/aws-crt-ffi/crt/s2n/tls/s2n_config.h +6 -0
  61. data/aws-crt-ffi/crt/s2n/tls/s2n_handshake_io.c +2 -1
  62. data/aws-crt-ffi/crt/s2n/tls/s2n_internal.h +9 -0
  63. data/aws-crt-ffi/crt/s2n/tls/s2n_tls13_key_schedule.c +7 -7
  64. data/aws-crt-ffi/crt/s2n/tls/s2n_tls13_secrets.c +61 -8
  65. data/aws-crt-ffi/crt/s2n/tls/s2n_tls13_secrets.h +11 -5
  66. data/aws-crt-ffi/crt/s2n/utils/s2n_fork_detection.c +367 -0
  67. data/aws-crt-ffi/crt/s2n/utils/s2n_fork_detection.h +28 -0
  68. data/aws-crt-ffi/crt/s2n/utils/s2n_safety_macros.h +13 -22
  69. metadata +18 -3
@@ -2,70 +2,159 @@
2
2
  // SPDX-License-Identifier: Apache-2.0
3
3
 
4
4
  use crate::raw::{
5
+ connection::Connection,
5
6
  error::{Error, Fallible},
6
7
  security,
7
8
  };
8
- use alloc::sync::Arc;
9
- use core::{convert::TryInto, ptr::NonNull};
9
+ use core::{convert::TryInto, ptr::NonNull, task::Poll};
10
10
  use s2n_tls_sys::*;
11
- use std::ffi::CString;
11
+ use std::{
12
+ ffi::{c_void, CStr, CString},
13
+ mem::ManuallyDrop,
14
+ sync::atomic::{AtomicUsize, Ordering},
15
+ };
12
16
 
13
- struct Owned(NonNull<s2n_config>);
17
+ pub struct Config(NonNull<s2n_config>);
14
18
 
19
+ /// # Safety
20
+ ///
15
21
  /// Safety: s2n_config objects can be sent across threads
16
- unsafe impl Send for Owned {}
22
+ unsafe impl Send for Config {}
17
23
 
18
- impl Default for Owned {
19
- fn default() -> Self {
20
- Self::new()
24
+ impl Config {
25
+ /// Returns a Config object with pre-defined defaults.
26
+ ///
27
+ /// Use the [`Builder`] if custom configuration is desired.
28
+ pub fn new() -> Self {
29
+ Self::default()
21
30
  }
22
- }
23
31
 
24
- impl Owned {
25
- fn new() -> Self {
26
- crate::raw::init::init();
27
- let config = unsafe { s2n_config_new().into_result() }.unwrap();
32
+ /// Returns a Builder which can be used to configure the Config
33
+ pub fn builder() -> Builder {
34
+ Builder::default()
35
+ }
36
+
37
+ /// # Safety
38
+ ///
39
+ /// This config _MUST_ have been initialized with a [`Builder`].
40
+ pub(crate) unsafe fn from_raw(config: NonNull<s2n_config>) -> Self {
28
41
  Self(config)
29
42
  }
30
43
 
31
44
  pub(crate) fn as_mut_ptr(&mut self) -> *mut s2n_config {
32
45
  self.0.as_ptr()
33
46
  }
34
- }
35
47
 
36
- impl Drop for Owned {
37
- fn drop(&mut self) {
38
- let _ = unsafe { s2n_config_free(self.0.as_ptr()).into_result() };
48
+ /// Retrieve a reference to the [`Context`] stored on the config.
49
+ ///
50
+ /// # Safety
51
+ ///
52
+ /// This function assumes that this Config was created using a [`Builder`]
53
+ /// and has a [`Context`] stored as its context pointer (set via `s2n_config_set_ctx`).
54
+ unsafe fn context(&self) -> &Context {
55
+ let mut ctx = core::ptr::null_mut();
56
+ s2n_config_get_ctx(self.0.as_ptr(), &mut ctx)
57
+ .into_result()
58
+ .unwrap();
59
+ &*(ctx as *const Context)
39
60
  }
40
- }
41
61
 
42
- #[derive(Clone, Default)]
43
- pub struct Config(Arc<Owned>);
62
+ /// Retrieve a mutable reference to the [`Context`] stored on the config.
63
+ ///
64
+ /// # Safety
65
+ ///
66
+ /// This function assumes that this Config was created using a [`Builder`]
67
+ /// and has a [`Context`] stored as its context pointer (set via `s2n_config_set_ctx`).
68
+ unsafe fn context_mut(&mut self) -> &mut Context {
69
+ let mut ctx = core::ptr::null_mut();
70
+ s2n_config_get_ctx(self.as_mut_ptr(), &mut ctx)
71
+ .into_result()
72
+ .unwrap();
73
+ &mut *(ctx as *mut Context)
74
+ }
44
75
 
45
- /// Safety: s2n_config objects can be sent across threads
46
- #[allow(unknown_lints, clippy::non_send_fields_in_send_ty)]
47
- unsafe impl Send for Config {}
76
+ #[cfg(test)]
77
+ /// Get the refcount associated with the config
78
+ pub fn test_get_refcount(&self) -> Result<usize, Error> {
79
+ let context = unsafe { self.context() };
80
+ Ok(context.refcount.load(Ordering::SeqCst))
81
+ }
82
+ }
48
83
 
49
- impl Config {
50
- pub fn new() -> Self {
51
- Self::default()
84
+ impl Default for Config {
85
+ fn default() -> Self {
86
+ Builder::new().build().unwrap()
52
87
  }
88
+ }
53
89
 
54
- pub fn builder() -> Builder {
55
- Builder::default()
90
+ impl Clone for Config {
91
+ fn clone(&self) -> Self {
92
+ let context = unsafe { self.context() };
93
+
94
+ // Safety
95
+ //
96
+ // Using a relaxed ordering is alright here, as knowledge of the
97
+ // original reference prevents other threads from erroneously deleting
98
+ // the object.
99
+ // https://github.com/rust-lang/rust/blob/e012a191d768adeda1ee36a99ef8b92d51920154/library/alloc/src/sync.rs#L1329
100
+ let _count = context.refcount.fetch_add(1, Ordering::Relaxed);
101
+ Self(self.0)
56
102
  }
103
+ }
57
104
 
58
- pub(crate) fn as_mut_ptr(&mut self) -> *mut s2n_config {
59
- (self.0).0.as_ptr()
105
+ impl Drop for Config {
106
+ fn drop(&mut self) {
107
+ let context = unsafe { self.context_mut() };
108
+ let count = context.refcount.fetch_sub(1, Ordering::Release);
109
+ debug_assert!(count > 0, "refcount should not drop below 1 instance");
110
+
111
+ // only free the config if this is the last instance
112
+ if count != 1 {
113
+ return;
114
+ }
115
+
116
+ // Safety
117
+ //
118
+ // The use of Ordering and fence mirrors the `Arc` implementation in
119
+ // the standard library.
120
+ //
121
+ // This fence is needed to prevent reordering of use of the data and
122
+ // deletion of the data. Because it is marked `Release`, the decreasing
123
+ // of the reference count synchronizes with this `Acquire` fence. This
124
+ // means that use of the data happens before decreasing the reference
125
+ // count, which happens before this fence, which happens before the
126
+ // deletion of the data.
127
+ // https://github.com/rust-lang/rust/blob/e012a191d768adeda1ee36a99ef8b92d51920154/library/alloc/src/sync.rs#L1637
128
+ std::sync::atomic::fence(Ordering::Acquire);
129
+
130
+ unsafe {
131
+ // This is the last instance so free the context.
132
+ let context = Box::from_raw(context);
133
+ drop(context);
134
+
135
+ let _ = s2n_config_free(self.0.as_ptr()).into_result();
136
+ }
60
137
  }
61
138
  }
62
139
 
63
140
  #[derive(Default)]
64
- pub struct Builder(Owned);
141
+ pub struct Builder(Config);
65
142
 
66
143
  impl Builder {
67
144
  pub fn new() -> Self {
68
- Default::default()
145
+ crate::raw::init::init();
146
+ let config = unsafe { s2n_config_new().into_result() }.unwrap();
147
+
148
+ let context = Box::new(Context::default());
149
+ let context = Box::into_raw(context) as *mut c_void;
150
+
151
+ unsafe {
152
+ s2n_config_set_ctx(config.as_ptr(), context)
153
+ .into_result()
154
+ .unwrap();
155
+ }
156
+
157
+ Self(Config(config))
69
158
  }
70
159
 
71
160
  pub fn set_alert_behavior(
@@ -156,15 +245,57 @@ impl Builder {
156
245
  Ok(self)
157
246
  }
158
247
 
159
- /// # Safety
248
+ pub fn wipe_trust_store(&mut self) -> Result<&mut Self, Error> {
249
+ unsafe { s2n_config_wipe_trust_store(self.as_mut_ptr()).into_result()? };
250
+ Ok(self)
251
+ }
252
+
253
+ /// Sets whether or not a client certificate should be required to complete the TLS connection.
160
254
  ///
161
- /// The `context` pointer must live at least as long as the config
162
- pub unsafe fn set_verify_host_callback(
255
+ /// See the [Usage Guide](https://github.com/aws/s2n-tls/blob/main/docs/USAGE-GUIDE.md#client-auth-related-calls) for more details.
256
+ pub fn set_client_auth_type(
163
257
  &mut self,
164
- callback: s2n_verify_host_fn,
165
- context: *mut core::ffi::c_void,
258
+ auth_type: s2n_cert_auth_type::Type,
166
259
  ) -> Result<&mut Self, Error> {
167
- s2n_config_set_verify_host_callback(self.as_mut_ptr(), callback, context).into_result()?;
260
+ unsafe { s2n_config_set_client_auth_type(self.as_mut_ptr(), auth_type).into_result() }?;
261
+ Ok(self)
262
+ }
263
+
264
+ /// Set a custom callback function which is run during client certificate validation as part of
265
+ /// a mutual TLS handshake.
266
+ ///
267
+ /// The callback may be called more than once during certificate validation as each SAN on
268
+ /// the certificate will be checked.
269
+ pub fn set_verify_host_handler<T: 'static + VerifyClientCertificateHandler>(
270
+ &mut self,
271
+ handler: T,
272
+ ) -> Result<&mut Self, Error> {
273
+ unsafe extern "C" fn verify_host_cb(
274
+ host_name: *const ::libc::c_char,
275
+ host_name_len: usize,
276
+ context: *mut ::libc::c_void,
277
+ ) -> u8 {
278
+ let host_name = host_name as *const u8;
279
+ let host_name = core::slice::from_raw_parts(host_name, host_name_len);
280
+ if let Ok(host_name_str) = core::str::from_utf8(host_name) {
281
+ let context = &mut *(context as *mut Context);
282
+ let handler = context.verify_host_handler.as_mut().unwrap();
283
+ return handler.verify_host_name(host_name_str) as u8;
284
+ }
285
+ 0 // If the host name can't be parsed, fail closed.
286
+ }
287
+
288
+ let handler = Box::new(handler);
289
+ let context = unsafe { self.0.context_mut() };
290
+ context.verify_host_handler = Some(handler);
291
+ unsafe {
292
+ s2n_config_set_verify_host_callback(
293
+ self.as_mut_ptr(),
294
+ Some(verify_host_cb),
295
+ self.0.context_mut() as *mut _ as *mut c_void,
296
+ )
297
+ .into_result()?;
298
+ }
168
299
  Ok(self)
169
300
  }
170
301
 
@@ -185,8 +316,68 @@ impl Builder {
185
316
  Ok(self)
186
317
  }
187
318
 
319
+ /// Set a custom callback function which is run after parsing the client hello.
320
+ ///
321
+ /// The callback can be called more than once. The application is response for calling
322
+ /// [`Connection::server_name_extension_used()`] if connection properties are modified
323
+ /// based on the server name extension.
324
+ pub fn set_client_hello_handler<T: 'static + ClientHelloHandler>(
325
+ &mut self,
326
+ handler: T,
327
+ ) -> Result<&mut Self, Error> {
328
+ unsafe extern "C" fn client_hello_cb(
329
+ connection_ptr: *mut s2n_connection,
330
+ context: *mut core::ffi::c_void,
331
+ ) -> libc::c_int {
332
+ let context = &mut *(context as *mut Context);
333
+ let handler = context.client_hello_handler.as_mut().unwrap();
334
+
335
+ let connection_ptr =
336
+ NonNull::new(connection_ptr).expect("connection should not be null");
337
+
338
+ // Since this is a callback, which receives a pointer to the connection, do
339
+ // not call drop on the connection object.
340
+ let mut connection = ManuallyDrop::new(Connection::from_raw(connection_ptr));
341
+
342
+ match handler.poll_client_hello(&mut connection) {
343
+ Poll::Ready(Ok(())) => {
344
+ s2n_client_hello_cb_done(connection_ptr.as_ptr());
345
+ s2n_status_code::SUCCESS
346
+ }
347
+ Poll::Ready(Err(_)) => {
348
+ s2n_client_hello_cb_done(connection_ptr.as_ptr());
349
+ s2n_status_code::FAILURE
350
+ }
351
+ Poll::Pending => s2n_status_code::SUCCESS,
352
+ }
353
+ }
354
+
355
+ let handler = Box::new(handler);
356
+ let context = unsafe { self.0.context_mut() };
357
+ context.client_hello_handler = Some(handler);
358
+
359
+ unsafe {
360
+ // Enable client_hello_callback polling to model the polling behavior of Futures in
361
+ // Rust
362
+ s2n_config_client_hello_cb_enable_poll(self.as_mut_ptr()).into_result()?;
363
+ s2n_config_set_client_hello_cb_mode(
364
+ self.as_mut_ptr(),
365
+ s2n_client_hello_cb_mode::NONBLOCKING,
366
+ )
367
+ .into_result()?;
368
+ s2n_config_set_client_hello_cb(
369
+ self.as_mut_ptr(),
370
+ Some(client_hello_cb),
371
+ self.0.context_mut() as *mut _ as *mut c_void,
372
+ )
373
+ .into_result()?;
374
+ }
375
+
376
+ Ok(self)
377
+ }
378
+
188
379
  pub fn build(self) -> Result<Config, Error> {
189
- Ok(Config(Arc::new(self.0)))
380
+ Ok(self.0)
190
381
  }
191
382
 
192
383
  fn as_mut_ptr(&mut self) -> *mut s2n_config {
@@ -201,3 +392,38 @@ impl Builder {
201
392
  Ok(self)
202
393
  }
203
394
  }
395
+
396
+ pub(crate) struct Context {
397
+ refcount: AtomicUsize,
398
+ client_hello_handler: Option<Box<dyn ClientHelloHandler>>,
399
+ verify_host_handler: Option<Box<dyn VerifyClientCertificateHandler>>,
400
+ }
401
+
402
+ impl Default for Context {
403
+ fn default() -> Self {
404
+ // The AtomicUsize is used to manually track the reference count of the Config.
405
+ // This mechanism is used to track when the Config object should be freed.
406
+ let refcount = AtomicUsize::new(1);
407
+
408
+ Self {
409
+ refcount,
410
+ client_hello_handler: None,
411
+ verify_host_handler: None,
412
+ }
413
+ }
414
+ }
415
+
416
+ /// This trait represents the client_hello callback which is run after parsing the client_hello.
417
+ ///
418
+ /// Use in conjunction with [`Builder::set_client_hello_handler()`].
419
+ pub trait ClientHelloHandler: Send + Sync {
420
+ fn poll_client_hello(&self, connection: &mut Connection) -> core::task::Poll<Result<(), ()>>;
421
+ }
422
+
423
+ /// Trait which a user must implement to verify host name(s) during X509 verification when using
424
+ /// mutual TLS.
425
+ pub trait VerifyClientCertificateHandler: Send + Sync {
426
+ /// The implementation shall verify the host name by returning `true` if the certificate host name is valid,
427
+ /// and `false` otherwise.
428
+ fn verify_host_name(&self, _host_name: &str) -> bool;
429
+ }
@@ -8,17 +8,20 @@ use crate::raw::{
8
8
  error::{Error, Fallible},
9
9
  security,
10
10
  };
11
- use core::{convert::TryInto, fmt, ptr::NonNull, task::Poll};
11
+ use core::{
12
+ convert::TryInto,
13
+ fmt,
14
+ ptr::NonNull,
15
+ task::{Poll, Waker},
16
+ };
12
17
  use libc::c_void;
13
18
  use s2n_tls_sys::*;
19
+ use std::{ffi::CStr, mem};
14
20
 
15
21
  pub use s2n_tls_sys::s2n_mode;
16
22
 
17
23
  pub struct Connection {
18
24
  connection: NonNull<s2n_connection>,
19
- // The config needs to be stored so the reference count is accurate
20
- #[allow(dead_code)]
21
- config: Option<Config>,
22
25
  }
23
26
 
24
27
  impl fmt::Debug for Connection {
@@ -29,6 +32,8 @@ impl fmt::Debug for Connection {
29
32
  }
30
33
  }
31
34
 
35
+ /// # Safety
36
+ ///
32
37
  /// Safety: s2n_connection objects can be sent across threads
33
38
  unsafe impl Send for Connection {}
34
39
 
@@ -38,16 +43,24 @@ impl Connection {
38
43
  let connection = unsafe { s2n_connection_new(mode).into_result() }.unwrap();
39
44
 
40
45
  unsafe {
41
- let mut config = core::ptr::null_mut();
42
46
  debug_assert! {
43
- s2n_connection_get_config(connection.as_ptr(), &mut config).into_result().is_err()
44
- };
47
+ s2n_connection_get_config(connection.as_ptr(), &mut core::ptr::null_mut())
48
+ .into_result()
49
+ .is_err()
50
+ }
45
51
  }
46
- Self {
47
- connection,
48
- config: None,
52
+ let context = Box::new(Context::default());
53
+ let context = Box::into_raw(context) as *mut c_void;
54
+ // allocate a new context object
55
+ unsafe {
56
+ s2n_connection_set_ctx(connection.as_ptr(), context)
57
+ .into_result()
58
+ .unwrap();
49
59
  }
60
+
61
+ Self { connection }
50
62
  }
63
+
51
64
  pub fn new_client() -> Self {
52
65
  Self::new(s2n_mode::CLIENT)
53
66
  }
@@ -56,6 +69,13 @@ impl Connection {
56
69
  Self::new(s2n_mode::SERVER)
57
70
  }
58
71
 
72
+ /// # Safety
73
+ ///
74
+ /// Caller must ensure s2n_connection is a valid reference to a [`s2n_connection`] object
75
+ pub(crate) unsafe fn from_raw(connection: NonNull<s2n_connection>) -> Self {
76
+ Self { connection }
77
+ }
78
+
59
79
  /// can be used to configure s2n to either use built-in blinding (set blinding
60
80
  /// to S2N_BUILT_IN_BLINDING) or self-service blinding (set blinding to
61
81
  /// S2N_SELF_SERVICE_BLINDING).
@@ -80,18 +100,49 @@ impl Connection {
80
100
  Ok(self)
81
101
  }
82
102
 
103
+ /// Attempts to drop the config on the connection.
104
+ ///
105
+ /// # Safety
106
+ ///
107
+ /// The caller must ensure the config associated with the connection was created
108
+ /// with a [`config::Builder`].
109
+ unsafe fn drop_config(&mut self) -> Result<(), Error> {
110
+ let mut prev_config = core::ptr::null_mut();
111
+
112
+ // A valid non-null pointer is returned only if the application previously called
113
+ // [`Self::set_config()`].
114
+ if s2n_connection_get_config(self.connection.as_ptr(), &mut prev_config)
115
+ .into_result()
116
+ .is_ok()
117
+ {
118
+ let prev_config = NonNull::new(prev_config).expect(
119
+ "config should exist since the call to s2n_connection_get_config was successful",
120
+ );
121
+ drop(Config::from_raw(prev_config));
122
+ }
123
+
124
+ Ok(())
125
+ }
126
+
83
127
  /// Associates a configuration object with a connection.
84
128
  pub fn set_config(&mut self, mut config: Config) -> Result<&mut Self, Error> {
85
129
  unsafe {
86
- s2n_connection_set_config(self.connection.as_ptr(), config.as_mut_ptr()).into_result()
87
- }?;
88
- unsafe {
89
- let mut config = core::ptr::null_mut();
130
+ // attempt to drop the currently set config
131
+ self.drop_config()?;
132
+
133
+ s2n_connection_set_config(self.connection.as_ptr(), config.as_mut_ptr())
134
+ .into_result()?;
135
+
90
136
  debug_assert! {
91
- s2n_connection_get_config(self.connection.as_ptr(), &mut config).into_result().is_ok()
137
+ s2n_connection_get_config(self.connection.as_ptr(), &mut core::ptr::null_mut()).into_result().is_ok(),
138
+ "s2n_connection_set_config was successful"
92
139
  };
140
+
141
+ // Setting the config on the connection creates one additional reference to the config
142
+ // so do not drop so prevent Rust from calling `drop()` at the end of this function.
143
+ mem::forget(config);
93
144
  }
94
- self.config = Some(config);
145
+
95
146
  Ok(self)
96
147
  }
97
148
 
@@ -251,11 +302,98 @@ impl Connection {
251
302
  }
252
303
 
253
304
  /// Sets the server name value for the connection
254
- pub fn set_server_name(&mut self, sni: &[u8]) -> Result<&mut Self, Error> {
255
- let sni = std::ffi::CString::new(sni).map_err(|_| Error::InvalidInput)?;
256
- unsafe { s2n_set_server_name(self.connection.as_ptr(), sni.as_ptr()).into_result() }?;
305
+ pub fn set_server_name(&mut self, server_name: &str) -> Result<&mut Self, Error> {
306
+ let server_name = std::ffi::CString::new(server_name).map_err(|_| Error::InvalidInput)?;
307
+ unsafe {
308
+ s2n_set_server_name(self.connection.as_ptr(), server_name.as_ptr()).into_result()
309
+ }?;
310
+ Ok(self)
311
+ }
312
+
313
+ /// Get the server name from the connection's client hello, or `None` if it is unavailable or not valid UTF-8.
314
+ pub fn server_name(&self) -> Option<&str> {
315
+ unsafe {
316
+ let server_name = s2n_get_server_name(self.connection.as_ptr());
317
+ match server_name.into_result() {
318
+ Ok(server_name) => CStr::from_ptr(server_name).to_str().ok(),
319
+ Err(_) => None,
320
+ }
321
+ }
322
+ }
323
+
324
+ /// Sets a Waker on the connection context or clears it if `None` is passed.
325
+ pub fn set_waker(&mut self, waker: Option<&Waker>) -> Result<&mut Self, Error> {
326
+ let ctx = self.context_mut();
327
+
328
+ if let Some(waker) = waker {
329
+ if let Some(prev_waker) = ctx.waker.as_mut() {
330
+ // only replace the Waker if they dont reference the same task
331
+ if !prev_waker.will_wake(waker) {
332
+ *prev_waker = waker.clone();
333
+ }
334
+ } else {
335
+ ctx.waker = Some(waker.clone());
336
+ }
337
+ } else {
338
+ ctx.waker = None;
339
+ }
257
340
  Ok(self)
258
341
  }
342
+
343
+ /// Returns the Waker set on the connection context.
344
+ pub fn waker(&self) -> Option<&Waker> {
345
+ let ctx = self.context();
346
+ ctx.waker.as_ref()
347
+ }
348
+
349
+ /// Retrieve a mutable reference to the [`Context`] stored on the connection.
350
+ fn context_mut(&mut self) -> &mut Context {
351
+ unsafe {
352
+ let ctx = s2n_connection_get_ctx(self.connection.as_ptr())
353
+ .into_result()
354
+ .unwrap();
355
+ &mut *(ctx.as_ptr() as *mut Context)
356
+ }
357
+ }
358
+
359
+ /// Retrieve a reference to the [`Context`] stored on the connection.
360
+ fn context(&self) -> &Context {
361
+ unsafe {
362
+ let ctx = s2n_connection_get_ctx(self.connection.as_ptr())
363
+ .into_result()
364
+ .unwrap();
365
+ &*(ctx.as_ptr() as *mut Context)
366
+ }
367
+ }
368
+
369
+ /// Mark that the server_name extension was used to configure the connection.
370
+ pub fn server_name_extension_used(&mut self) {
371
+ // TODO: requiring the application to call this method is a pretty sharp edge.
372
+ // Figure out if its possible to automatically call this from the Rust bindings.
373
+ unsafe {
374
+ s2n_connection_server_name_extension_used(self.connection.as_ptr())
375
+ .into_result()
376
+ .unwrap();
377
+ }
378
+ }
379
+
380
+ #[cfg(test)]
381
+ /// Test if a config has been set on the connection.
382
+ ///
383
+ /// `s2n_connection_get_config` should return a NULL pointer if `s2n_connection_set_config`
384
+ /// has not been called by the application.
385
+ pub fn test_config_exists(&mut self) -> Result<(), Error> {
386
+ let mut config = core::ptr::null_mut();
387
+ let _config = unsafe {
388
+ s2n_connection_get_config(self.connection.as_ptr(), &mut config).into_result()?
389
+ };
390
+ Ok(())
391
+ }
392
+ }
393
+
394
+ #[derive(Default)]
395
+ struct Context {
396
+ waker: Option<Waker>,
259
397
  }
260
398
 
261
399
  #[cfg(feature = "quic")]
@@ -309,6 +447,18 @@ impl Connection {
309
447
  impl Drop for Connection {
310
448
  fn drop(&mut self) {
311
449
  // ignore failures since there's not much we can do about it
312
- let _ = unsafe { s2n_connection_free(self.connection.as_ptr()).into_result() };
450
+ unsafe {
451
+ // clean up context
452
+ let prev_ctx = self.context_mut();
453
+ drop(Box::from_raw(prev_ctx));
454
+ let _ = s2n_connection_set_ctx(self.connection.as_ptr(), core::ptr::null_mut())
455
+ .into_result();
456
+
457
+ // cleanup config
458
+ let _ = self.drop_config();
459
+
460
+ // cleanup connection
461
+ let _ = s2n_connection_free(self.connection.as_ptr()).into_result();
462
+ }
313
463
  }
314
464
  }