boto3-refresh-session 1.0.36__py3-none-any.whl → 6.2.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,570 @@
1
"""Refreshable boto3 session backed by the AWS IoT Core credential provider (X.509 / mTLS)."""

# Public API of this module.
__all__ = [
    "IOTX509RefreshableSession",
]
4
+
5
+ import json
6
+ import re
7
+ from atexit import register
8
+ from pathlib import Path
9
+ from tempfile import NamedTemporaryFile
10
+ from typing import cast, get_args
11
+ from urllib.parse import ParseResult, urlparse
12
+
13
+ from awscrt import auth, io
14
+ from awscrt.exceptions import AwsCrtError
15
+ from awscrt.http import HttpClientConnection, HttpRequest
16
+ from awscrt.io import (
17
+ ClientBootstrap,
18
+ ClientTlsContext,
19
+ DefaultHostResolver,
20
+ EventLoopGroup,
21
+ LogLevel,
22
+ Pkcs11Lib,
23
+ TlsConnectionOptions,
24
+ TlsContextOptions,
25
+ init_logging,
26
+ )
27
+ from awscrt.mqtt import Connection
28
+ from awsiot import mqtt_connection_builder
29
+
30
+ from ...exceptions import BRSError, BRSWarning
31
+ from ...utils import (
32
+ PKCS11,
33
+ AWSCRTResponse,
34
+ Identity,
35
+ TemporaryCredentials,
36
+ Transport,
37
+ refreshable_session,
38
+ )
39
+ from .core import BaseIoTRefreshableSession
40
+
41
# Paths of temporary files written by
# IOTX509RefreshableSession._bytes_to_tempfile; they are unlinked at
# interpreter exit by IOTX509RefreshableSession._cleanup_tempfiles.
_TEMP_PATHS: list[str] = list()
42
+
43
+
44
+ @refreshable_session
45
+ class IOTX509RefreshableSession(
46
+ BaseIoTRefreshableSession, registry_key="x509"
47
+ ):
48
+ """A :class:`boto3.session.Session` object that automatically refreshes
49
+ temporary credentials returned by the IoT Core credential provider.
50
+
51
+ Parameters
52
+ ----------
53
+ endpoint : str
54
+ The endpoint URL for the IoT Core credential provider. Must contain
55
+ '.credentials.iot.'.
56
+ role_alias : str
57
+ The IAM role alias to use when requesting temporary credentials.
58
+ certificate : str | bytes
59
+ The X.509 certificate to use when requesting temporary credentials.
60
+ ``str`` represents the file path to the certificate, while ``bytes``
61
+ represents the actual certificate data.
62
+ thing_name : str, optional
63
+ The name of the IoT thing to use when requesting temporary
64
+ credentials. Default is None.
65
+ private_key : str | bytes | None, optional
66
+ The private key to use when requesting temporary credentials. ``str``
67
+ represents the file path to the private key, while ``bytes``
68
+ represents the actual private key data. Optional only if ``pkcs11``
69
+ is provided. Default is None.
70
+ pkcs11 : PKCS11, optional
71
+ The PKCS#11 library to use when requesting temporary credentials. If
72
+ provided, ``private_key`` must be None.
73
+ ca : str | bytes | None, optional
74
+ The CA certificate to use when verifying the IoT Core endpoint. ``str``
75
+ represents the file path to the CA certificate, while ``bytes``
76
+ represents the actual CA certificate data. Default is None.
77
+ verify_peer : bool, optional
78
+ Whether to verify the CA certificate when establishing the TLS
79
+ connection. Default is True.
80
+ timeout : float | int | None, optional
81
+ The timeout for the TLS connection in seconds. Default is 10.0.
82
+ duration_seconds : int | None, optional
83
+ The duration for which the temporary credentials are valid, in
84
+ seconds. Cannot exceed the value declared in the IAM policy.
85
+ Default is None.
86
+ awscrt_log_level : awscrt.LogLevel | None, optional
87
+ The logging level for the AWS CRT library, e.g.
88
+ ``awscrt.LogLevel.INFO``. Default is None.
89
+ defer_refresh : bool, optional
90
+ If ``True`` then temporary credentials are not automatically refreshed
91
+ until they are explicitly needed. If ``False`` then temporary
92
+ credentials refresh immediately upon expiration. It is highly
93
+ recommended that you use ``True``. Default is ``True``.
94
+ advisory_timeout : int, optional
95
+ USE THIS ARGUMENT WITH CAUTION!!!
96
+
97
+ Botocore will attempt to refresh credentials early according to
98
+ this value (in seconds), but will continue using the existing
99
+ credentials if refresh fails. Default is 15 minutes (900 seconds).
100
+ mandatory_timeout : int, optional
101
+ USE THIS ARGUMENT WITH CAUTION!!!
102
+
103
+ Botocore requires a successful refresh before continuing. If
104
+ refresh fails in this window (in seconds), API calls may fail.
105
+ Default is 10 minutes (600 seconds).
106
+ cache_clients : bool, optional
107
+ If ``True`` then clients created by this session will be cached and
108
+ reused for subsequent calls to :meth:`client()` with the same
109
+ parameter signatures. Due to the memory overhead of clients, the
110
+ default is ``True`` in order to protect system resources.
111
+
112
+ Other Parameters
113
+ ----------------
114
+ kwargs : dict, optional
115
+ Optional keyword arguments for the :class:`boto3.session.Session`
116
+ object.
117
+
118
+ Notes
119
+ -----
120
+ Gavin Adams at AWS was a major influence on this implementation.
121
+ Thank you, Gavin!
122
+ """
123
+
124
+ def __init__(
125
+ self,
126
+ endpoint: str,
127
+ role_alias: str,
128
+ certificate: str | bytes,
129
+ thing_name: str | None = None,
130
+ private_key: str | bytes | None = None,
131
+ pkcs11: PKCS11 | None = None,
132
+ ca: str | bytes | None = None,
133
+ verify_peer: bool = True,
134
+ timeout: float | int | None = None,
135
+ duration_seconds: int | None = None,
136
+ awscrt_log_level: LogLevel | None = None,
137
+ **kwargs,
138
+ ):
139
+ # initializing BRSSession
140
+ super().__init__(refresh_method="iot-x509", **kwargs)
141
+
142
+ # logging
143
+ if awscrt_log_level:
144
+ init_logging(log_level=awscrt_log_level, file_name="stdout")
145
+
146
+ # initializing public attributes
147
+ self.endpoint = self._normalize_iot_credential_endpoint(
148
+ endpoint=endpoint
149
+ )
150
+ self.role_alias = role_alias
151
+ self.certificate = self._read_maybe_path_to_bytes(
152
+ certificate, fallback=None, name="certificate"
153
+ )
154
+ self.thing_name = thing_name
155
+ self.private_key = self._read_maybe_path_to_bytes(
156
+ private_key, fallback=None, name="private_key"
157
+ )
158
+ self.pkcs11 = self._validate_pkcs11(pkcs11) if pkcs11 else None
159
+ self.ca = self._read_maybe_path_to_bytes(ca, fallback=None, name="ca")
160
+ self.verify_peer = verify_peer
161
+ self.timeout = 10.0 if timeout is None else timeout
162
+ self.duration_seconds = duration_seconds
163
+
164
+ # either private_key or pkcs11 must be provided
165
+ if self.private_key is None and self.pkcs11 is None:
166
+ raise BRSError(
167
+ "Either 'private_key' or 'pkcs11' must be provided."
168
+ )
169
+
170
+ # . . . but both cannot be provided!
171
+ if self.private_key is not None and self.pkcs11 is not None:
172
+ raise BRSError(
173
+ "Only one of 'private_key' or 'pkcs11' can be provided."
174
+ )
175
+
176
+ def _get_credentials(self) -> TemporaryCredentials:
177
+ url = urlparse(
178
+ f"https://{self.endpoint}/role-aliases/{self.role_alias}"
179
+ "/credentials"
180
+ )
181
+ request = HttpRequest("GET", url.path)
182
+ request.headers.add("host", str(url.hostname))
183
+ if self.thing_name:
184
+ request.headers.add("x-amzn-iot-thingname", self.thing_name)
185
+ if self.duration_seconds:
186
+ request.headers.add(
187
+ "x-amzn-iot-credential-duration-seconds",
188
+ str(self.duration_seconds),
189
+ )
190
+ response = AWSCRTResponse()
191
+ port = 443 if not url.port else url.port
192
+ connection = (
193
+ self._mtls_client_connection(url=url, port=port)
194
+ if not self.pkcs11
195
+ else self._mtls_pkcs11_client_connection(url=url, port=port)
196
+ )
197
+
198
+ try:
199
+ stream = connection.request(
200
+ request, response.on_response, response.on_body
201
+ )
202
+ stream.activate()
203
+ stream.completion_future.result(float(self.timeout))
204
+ finally:
205
+ try:
206
+ connection.close()
207
+ except Exception:
208
+ ...
209
+
210
+ if response.status_code == 200:
211
+ credentials = json.loads(response.body.decode("utf-8"))[
212
+ "credentials"
213
+ ]
214
+ return {
215
+ "access_key": credentials["accessKeyId"],
216
+ "secret_key": credentials["secretAccessKey"],
217
+ "token": credentials["sessionToken"],
218
+ "expiry_time": credentials["expiration"],
219
+ }
220
+ else:
221
+ raise BRSError(
222
+ "Error getting credentials: "
223
+ f"{json.loads(response.body.decode())}"
224
+ )
225
+
226
+ def _mtls_client_connection(
227
+ self, url: ParseResult, port: int
228
+ ) -> HttpClientConnection:
229
+ event_loop_group: EventLoopGroup = EventLoopGroup()
230
+ host_resolver: DefaultHostResolver = DefaultHostResolver(
231
+ event_loop_group
232
+ )
233
+ bootstrap: ClientBootstrap = ClientBootstrap(
234
+ event_loop_group, host_resolver
235
+ )
236
+ tls_ctx_opt = TlsContextOptions.create_client_with_mtls(
237
+ cert_buffer=self.certificate, key_buffer=self.private_key
238
+ )
239
+
240
+ if self.ca:
241
+ tls_ctx_opt.override_default_trust_store(self.ca)
242
+
243
+ tls_ctx_opt.verify_peer = self.verify_peer
244
+ tls_ctx = ClientTlsContext(tls_ctx_opt)
245
+ tls_conn_opt: TlsConnectionOptions = cast(
246
+ TlsConnectionOptions, tls_ctx.new_connection_options()
247
+ )
248
+ tls_conn_opt.set_server_name(str(url.hostname))
249
+
250
+ try:
251
+ connection_future = HttpClientConnection.new(
252
+ host_name=str(url.hostname),
253
+ port=port,
254
+ bootstrap=bootstrap,
255
+ tls_connection_options=tls_conn_opt,
256
+ )
257
+ return connection_future.result(self.timeout)
258
+ except AwsCrtError as err:
259
+ raise BRSError(
260
+ "Error completing mTLS connection to endpoint "
261
+ f"'{url.hostname}'"
262
+ ) from err
263
+
264
+ def _mtls_pkcs11_client_connection(
265
+ self, url: ParseResult, port: int
266
+ ) -> HttpClientConnection:
267
+ event_loop_group: EventLoopGroup = EventLoopGroup()
268
+ host_resolver: DefaultHostResolver = DefaultHostResolver(
269
+ event_loop_group
270
+ )
271
+ bootstrap: ClientBootstrap = ClientBootstrap(
272
+ event_loop_group, host_resolver
273
+ )
274
+
275
+ if not self.pkcs11:
276
+ raise BRSError(
277
+ "Attempting to establish mTLS connection using PKCS#11"
278
+ "but 'pkcs11' parameter is 'None'!"
279
+ )
280
+
281
+ tls_ctx_opt = TlsContextOptions.create_client_with_mtls_pkcs11(
282
+ pkcs11_lib=Pkcs11Lib(file=self.pkcs11["pkcs11_lib"]),
283
+ user_pin=self.pkcs11["user_pin"],
284
+ slot_id=self.pkcs11["slot_id"],
285
+ token_label=self.pkcs11["token_label"],
286
+ private_key_label=self.pkcs11["private_key_label"],
287
+ cert_file_contents=self.certificate,
288
+ )
289
+
290
+ if self.ca:
291
+ tls_ctx_opt.override_default_trust_store(self.ca)
292
+
293
+ tls_ctx_opt.verify_peer = self.verify_peer
294
+ tls_ctx = ClientTlsContext(tls_ctx_opt)
295
+ tls_conn_opt: TlsConnectionOptions = cast(
296
+ TlsConnectionOptions, tls_ctx.new_connection_options()
297
+ )
298
+ tls_conn_opt.set_server_name(str(url.hostname))
299
+
300
+ try:
301
+ connection_future = HttpClientConnection.new(
302
+ host_name=str(url.hostname),
303
+ port=port,
304
+ bootstrap=bootstrap,
305
+ tls_connection_options=tls_conn_opt,
306
+ )
307
+ return connection_future.result(self.timeout)
308
+ except AwsCrtError as err:
309
+ raise BRSError("Error completing mTLS connection.") from err
310
+
311
+ def get_identity(self) -> Identity:
312
+ """Returns metadata about the current caller identity.
313
+
314
+ Returns
315
+ -------
316
+ Identity
317
+ Dict containing information about the current calleridentity.
318
+ """
319
+
320
+ return self.client("sts").get_caller_identity()
321
+
322
+ @staticmethod
323
+ def _normalize_iot_credential_endpoint(endpoint: str) -> str:
324
+ if ".credentials.iot." in endpoint:
325
+ return endpoint
326
+
327
+ if ".iot." in endpoint and "-ats." in endpoint:
328
+ logged_data_endpoint = re.sub(r"^[^. -]+", "***", endpoint)
329
+ logged_credential_endpoint = re.sub(
330
+ r"^[^. -]+",
331
+ "***",
332
+ (endpoint := endpoint.replace("-ats.iot", ".credentials.iot")),
333
+ )
334
+ BRSWarning.warn(
335
+ "The 'endpoint' parameter you provided represents the data "
336
+ "endpoint for IoT not the credentials endpoint! The endpoint "
337
+ "you provided was therefore modified from "
338
+ f"'{logged_data_endpoint}' -> '{logged_credential_endpoint}'"
339
+ )
340
+ return endpoint
341
+
342
+ raise BRSError(
343
+ "Invalid IoT endpoint provided for credentials provider. "
344
+ "Expected '<id>.credentials.iot.<region>.amazonaws.com'"
345
+ )
346
+
347
+ @staticmethod
348
+ def _validate_pkcs11(pkcs11: PKCS11) -> PKCS11:
349
+ if "pkcs11_lib" not in pkcs11:
350
+ raise BRSError(
351
+ "PKCS#11 library path must be provided as 'pkcs11_lib'"
352
+ " in 'pkcs11'."
353
+ )
354
+ elif not Path(pkcs11["pkcs11_lib"]).expanduser().resolve().is_file():
355
+ raise BRSError(
356
+ f"'{pkcs11['pkcs11_lib']}' is not a valid file path for "
357
+ "'pkcs11_lib' in 'pkcs11'."
358
+ )
359
+ pkcs11.setdefault("user_pin", None)
360
+ pkcs11.setdefault("slot_id", None)
361
+ pkcs11.setdefault("token_label", None)
362
+ pkcs11.setdefault("private_key_label", None)
363
+ return pkcs11
364
+
365
+ @staticmethod
366
+ def _read_maybe_path_to_bytes(
367
+ v: str | bytes | None, fallback: bytes | None, name: str
368
+ ) -> bytes | None:
369
+ match v:
370
+ case None:
371
+ return fallback
372
+ case bytes():
373
+ return v
374
+ case str() as p if Path(p).expanduser().resolve().is_file():
375
+ return Path(p).expanduser().resolve().read_bytes()
376
+ case _:
377
+ raise BRSError(f"Invalid {name} provided.")
378
+
379
+ @staticmethod
380
+ def _bytes_to_tempfile(b: bytes, suffix: str = ".pem") -> str:
381
+ f = NamedTemporaryFile("wb", suffix=suffix, delete=False)
382
+ f.write(b)
383
+ f.flush()
384
+ f.close()
385
+ _TEMP_PATHS.append(f.name)
386
+ return f.name
387
+
388
+ @staticmethod
389
+ @register
390
+ def _cleanup_tempfiles():
391
+ for p in _TEMP_PATHS:
392
+ try:
393
+ Path(p).unlink(missing_ok=True)
394
+ except Exception:
395
+ ...
396
+
397
+ def mqtt(
398
+ self,
399
+ *,
400
+ endpoint: str,
401
+ client_id: str,
402
+ transport: Transport = "x509",
403
+ certificate: str | bytes | None = None,
404
+ private_key: str | bytes | None = None,
405
+ ca: str | bytes | None = None,
406
+ pkcs11: PKCS11 | None = None,
407
+ region: str | None = None,
408
+ keep_alive_secs: int = 60,
409
+ clean_start: bool = True,
410
+ port: int | None = None,
411
+ use_alpn: bool = False,
412
+ ) -> Connection:
413
+ """Establishes an MQTT connection using the specified parameters.
414
+
415
+ .. versionadded:: 5.1.0
416
+
417
+ Parameters
418
+ ----------
419
+ endpoint: str
420
+ The MQTT endpoint to connect to.
421
+ client_id: str
422
+ The client ID to use for the MQTT connection.
423
+ transport: Transport
424
+ The transport protocol to use (e.g., "x509" or "ws").
425
+ certificate: str | bytes | None, optional
426
+ The client certificate to use for the connection. Defaults to the
427
+ session certificate.
428
+ private_key: str | bytes | None, optional
429
+ The private key to use for the connection. Defaults to the
430
+ session private key.
431
+ ca: str | bytes | None, optional
432
+ The CA certificate to use for the connection. Defaults to the
433
+ session CA certificate.
434
+ pkcs11: PKCS11 | None, optional
435
+ PKCS#11 configuration for hardware-backed keys. Defaults to the
436
+ session PKCS#11 configuration.
437
+ region: str | None, optional
438
+ The AWS region to use for the connection. Defaults to the
439
+ session region.
440
+ keep_alive_secs: int, optional
441
+ The keep-alive interval for the MQTT connection. Default is 60
442
+ seconds.
443
+ clean_start: bool, optional
444
+ Whether to start a clean session. Default is True.
445
+ port: int | None, optional
446
+ The port to use for the MQTT connection. Default is 8883 if not
447
+ using ALPN, otherwise 443.
448
+ use_alpn: bool, optional
449
+ Whether to use ALPN for the connection. Default is False.
450
+
451
+ Returns
452
+ -------
453
+ awscrt.mqtt.Connection
454
+ The established MQTT connection.
455
+ """
456
+
457
+ # Validate transport
458
+ if transport not in list(get_args(Transport)):
459
+ raise BRSError("Transport must be 'x509' or 'ws'")
460
+
461
+ # Region default (WS only)
462
+ if region is None:
463
+ region = self.region_name
464
+
465
+ # Normalize inputs to bytes using session defaults
466
+ cert_bytes = self._read_maybe_path_to_bytes(
467
+ certificate, getattr(self, "certificate", None), "certificate"
468
+ )
469
+ key_bytes = self._read_maybe_path_to_bytes(
470
+ private_key, getattr(self, "private_key", None), "private_key"
471
+ )
472
+ ca_bytes = self._read_maybe_path_to_bytes(
473
+ ca, getattr(self, "ca", None), "ca"
474
+ )
475
+
476
+ # Validate PKCS#11
477
+ match pkcs11:
478
+ case None:
479
+ pkcs11 = getattr(self, "pkcs11", None)
480
+ case dict():
481
+ pkcs11 = self._validate_pkcs11(pkcs11)
482
+ case _:
483
+ raise BRSError("Invalid PKCS#11 configuration provided.")
484
+
485
+ # X.509 invariants
486
+ if transport == "x509":
487
+ has_key = key_bytes is not None
488
+ has_hsm = pkcs11 is not None
489
+ if not has_key and not has_hsm:
490
+ raise BRSError(
491
+ "For transport='x509', provide either 'private_key' "
492
+ "(bytes/path) or 'pkcs11'."
493
+ )
494
+ if has_key and has_hsm:
495
+ raise BRSError(
496
+ "Provide only one of 'private_key' or 'pkcs11' for "
497
+ "transport='x509'."
498
+ )
499
+ if cert_bytes is None:
500
+ raise BRSError("Certificate is required for transport='x509'")
501
+
502
+ # CRT bootstrap
503
+ event_loop = io.EventLoopGroup(1)
504
+ host_resolver = io.DefaultHostResolver(event_loop)
505
+ bootstrap = io.ClientBootstrap(event_loop, host_resolver)
506
+
507
+ # Build connection
508
+ if transport == "x509":
509
+ if pkcs11 is not None:
510
+ # Cert must be a filepath for PKCS#11 builder → write temp
511
+ cert_path = self._bytes_to_tempfile(
512
+ cast(bytes, cert_bytes), ".crt"
513
+ )
514
+ ca_path = (
515
+ self._bytes_to_tempfile(ca_bytes, ".pem")
516
+ if ca_bytes
517
+ else None
518
+ )
519
+
520
+ return mqtt_connection_builder.mtls_with_pkcs11(
521
+ endpoint=endpoint,
522
+ client_bootstrap=bootstrap,
523
+ pkcs11_lib=Pkcs11Lib(file=pkcs11["pkcs11_lib"]),
524
+ user_pin=pkcs11.get("user_pin"),
525
+ slot_id=pkcs11.get("slot_id"),
526
+ token_label=pkcs11.get("token_label"),
527
+ private_key_object=pkcs11.get("private_key_label"),
528
+ cert_filepath=cert_path,
529
+ ca_filepath=ca_path,
530
+ client_id=client_id,
531
+ clean_session=clean_start,
532
+ keep_alive_secs=keep_alive_secs,
533
+ port=port or (443 if use_alpn else 8883),
534
+ alpn_list=["x-amzn-mqtt-ca"] if use_alpn else None,
535
+ )
536
+ else:
537
+ # pure mTLS with in-memory cert/key/CA
538
+ return mqtt_connection_builder.mtls_from_bytes(
539
+ endpoint=endpoint,
540
+ cert_bytes=cert_bytes,
541
+ pri_key_bytes=key_bytes,
542
+ ca_bytes=ca_bytes,
543
+ client_bootstrap=bootstrap,
544
+ client_id=client_id,
545
+ clean_session=clean_start,
546
+ keep_alive_secs=keep_alive_secs,
547
+ port=port or (443 if use_alpn else 8883),
548
+ alpn_list=["x-amzn-mqtt-ca"] if use_alpn else None,
549
+ )
550
+
551
+ else: # transport == "ws"
552
+ # WebSockets + SigV4
553
+ creds_provider = auth.AwsCredentialsProvider.new_delegate(
554
+ self._credentials
555
+ )
556
+ ca_path = (
557
+ self._bytes_to_tempfile(ca_bytes, ".pem") if ca_bytes else None
558
+ )
559
+
560
+ return mqtt_connection_builder.websockets_with_default_aws_signing(
561
+ endpoint=endpoint,
562
+ client_bootstrap=bootstrap,
563
+ region=region,
564
+ credentials_provider=creds_provider,
565
+ client_id=client_id,
566
+ clean_session=clean_start,
567
+ keep_alive_secs=keep_alive_secs,
568
+ ca_filepath=ca_path,
569
+ port=port or 443,
570
+ )