boto3-assist 0.32.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67)
  1. boto3_assist/__init__.py +0 -0
  2. boto3_assist/aws_config.py +199 -0
  3. boto3_assist/aws_lambda/event_info.py +414 -0
  4. boto3_assist/aws_lambda/mock_context.py +5 -0
  5. boto3_assist/boto3session.py +87 -0
  6. boto3_assist/cloudwatch/cloudwatch_connection.py +84 -0
  7. boto3_assist/cloudwatch/cloudwatch_connection_tracker.py +17 -0
  8. boto3_assist/cloudwatch/cloudwatch_log_connection.py +62 -0
  9. boto3_assist/cloudwatch/cloudwatch_logs.py +39 -0
  10. boto3_assist/cloudwatch/cloudwatch_query.py +191 -0
  11. boto3_assist/cognito/cognito_authorizer.py +169 -0
  12. boto3_assist/cognito/cognito_connection.py +59 -0
  13. boto3_assist/cognito/cognito_utility.py +514 -0
  14. boto3_assist/cognito/jwks_cache.py +21 -0
  15. boto3_assist/cognito/user.py +27 -0
  16. boto3_assist/connection.py +146 -0
  17. boto3_assist/connection_tracker.py +120 -0
  18. boto3_assist/dynamodb/dynamodb.py +1206 -0
  19. boto3_assist/dynamodb/dynamodb_connection.py +113 -0
  20. boto3_assist/dynamodb/dynamodb_helpers.py +333 -0
  21. boto3_assist/dynamodb/dynamodb_importer.py +102 -0
  22. boto3_assist/dynamodb/dynamodb_index.py +507 -0
  23. boto3_assist/dynamodb/dynamodb_iservice.py +29 -0
  24. boto3_assist/dynamodb/dynamodb_key.py +130 -0
  25. boto3_assist/dynamodb/dynamodb_model_base.py +382 -0
  26. boto3_assist/dynamodb/dynamodb_model_base_interfaces.py +34 -0
  27. boto3_assist/dynamodb/dynamodb_re_indexer.py +165 -0
  28. boto3_assist/dynamodb/dynamodb_reindexer.py +165 -0
  29. boto3_assist/dynamodb/dynamodb_reserved_words.py +52 -0
  30. boto3_assist/dynamodb/dynamodb_reserved_words.txt +573 -0
  31. boto3_assist/dynamodb/readme.md +68 -0
  32. boto3_assist/dynamodb/troubleshooting.md +7 -0
  33. boto3_assist/ec2/ec2_connection.py +57 -0
  34. boto3_assist/environment_services/__init__.py +0 -0
  35. boto3_assist/environment_services/environment_loader.py +128 -0
  36. boto3_assist/environment_services/environment_variables.py +219 -0
  37. boto3_assist/erc/__init__.py +64 -0
  38. boto3_assist/erc/ecr_connection.py +57 -0
  39. boto3_assist/errors/custom_exceptions.py +46 -0
  40. boto3_assist/http_status_codes.py +80 -0
  41. boto3_assist/models/serializable_model.py +9 -0
  42. boto3_assist/role_assumption_mixin.py +38 -0
  43. boto3_assist/s3/s3.py +64 -0
  44. boto3_assist/s3/s3_bucket.py +67 -0
  45. boto3_assist/s3/s3_connection.py +76 -0
  46. boto3_assist/s3/s3_event_data.py +168 -0
  47. boto3_assist/s3/s3_object.py +695 -0
  48. boto3_assist/securityhub/securityhub.py +150 -0
  49. boto3_assist/securityhub/securityhub_connection.py +57 -0
  50. boto3_assist/session_setup_mixin.py +70 -0
  51. boto3_assist/ssm/connection.py +57 -0
  52. boto3_assist/ssm/parameter_store/parameter_store.py +116 -0
  53. boto3_assist/utilities/datetime_utility.py +349 -0
  54. boto3_assist/utilities/decimal_conversion_utility.py +140 -0
  55. boto3_assist/utilities/dictionary_utility.py +32 -0
  56. boto3_assist/utilities/file_operations.py +135 -0
  57. boto3_assist/utilities/http_utility.py +48 -0
  58. boto3_assist/utilities/logging_utility.py +0 -0
  59. boto3_assist/utilities/numbers_utility.py +329 -0
  60. boto3_assist/utilities/serialization_utility.py +664 -0
  61. boto3_assist/utilities/string_utility.py +337 -0
  62. boto3_assist/version.py +1 -0
  63. boto3_assist-0.32.0.dist-info/METADATA +76 -0
  64. boto3_assist-0.32.0.dist-info/RECORD +67 -0
  65. boto3_assist-0.32.0.dist-info/WHEEL +4 -0
  66. boto3_assist-0.32.0.dist-info/licenses/LICENSE-EXPLAINED.txt +11 -0
  67. boto3_assist-0.32.0.dist-info/licenses/LICENSE.txt +21 -0
@@ -0,0 +1,87 @@
1
+ """
2
+ Geek Cafe, LLC
3
+ Maintainers: Eric Wilson
4
+ MIT License. See Project Root for the license information.
5
+ """
6
+
7
+ from typing import Optional, List, Any
8
+ import boto3
9
+ from botocore.config import Config
10
+ from .session_setup_mixin import SessionSetupMixin
11
+ from .role_assumption_mixin import RoleAssumptionMixin
12
+
13
+
14
class Boto3SessionManager(SessionSetupMixin, RoleAssumptionMixin):
    """
    Builds a boto3 session for one service and lazily hands out a client/resource.

    A base session is created from the optional profile, region and static
    credentials. If a role chain is configured, those roles are assumed (via
    ``_assume_roles_in_chain``) and the resulting session replaces the base one.
    """

    def __init__(
        self,
        service_name: str,
        *,
        aws_profile: Optional[str] = None,
        aws_region: Optional[str] = None,
        assume_role_arn: Optional[str] = None,
        assume_role_chain: Optional[List[str]] = None,
        assume_role_session_name: Optional[str] = None,
        assume_role_duration_seconds: Optional[int] = 3600,
        config: Optional[Config] = None,
        aws_endpoint_url: Optional[str] = None,
        aws_access_key_id: Optional[str] = None,
        aws_secret_access_key: Optional[str] = None,
        aws_session_token: Optional[str] = None,
    ):
        """
        Args:
            service_name: boto3 service name passed to ``client()``/``resource()``.
            aws_profile: Profile used for the base session.
            aws_region: Region for the base session; also forwarded to role assumption.
            assume_role_arn: A single role to assume. Ignored when a non-empty
                ``assume_role_chain`` is supplied.
            assume_role_chain: Roles to assume, in order.
            assume_role_session_name: Session name used when assuming roles.
                Defaults to ``AssumeRoleSessionFor<service_name>``.
            assume_role_duration_seconds: Requested duration for assumed-role credentials.
            config: Optional botocore ``Config`` passed to client/resource creation.
            aws_endpoint_url: Optional endpoint URL passed to client/resource creation.
            aws_access_key_id: Optional static access key for the base session.
            aws_secret_access_key: Optional static secret key for the base session.
            aws_session_token: Optional session token for the base session.
        """
        # Service / connection settings.
        self.service_name = service_name
        self.aws_profile = aws_profile
        self.aws_region = aws_region
        self.config = config
        self.endpoint_url = aws_endpoint_url

        # An explicit chain wins; otherwise fall back to the single ARN, or none.
        if assume_role_chain:
            role_chain = assume_role_chain
        elif assume_role_arn:
            role_chain = [assume_role_arn]
        else:
            role_chain = []
        self.assume_role_chain = role_chain

        default_session_name = f"AssumeRoleSessionFor{service_name}"
        self.assume_role_session_name = assume_role_session_name or default_session_name
        self.assume_role_duration_seconds = assume_role_duration_seconds

        # Optional static credentials for the base session.
        self.aws_access_key_id = aws_access_key_id
        self.aws_secret_access_key = aws_secret_access_key
        self.aws_session_token = aws_session_token

        # Handles; client/resource are created on first property access.
        self.__session: Optional[boto3.Session] = None
        self.__client: Any = None
        self.__resource: Any = None

        self.__initialize()

    def __initialize(self):
        """Create the base session, then assume the role chain when one is configured."""
        base_session = self._create_base_session(
            self.aws_profile,
            self.aws_region,
            self.aws_access_key_id,
            self.aws_secret_access_key,
            self.aws_session_token,
        )

        if not self.assume_role_chain:
            self.__session = base_session
            return

        self.__session = self._assume_roles_in_chain(
            base_session,
            self.assume_role_chain,
            self.assume_role_session_name,
            self.assume_role_duration_seconds,
            self.aws_region,
        )

    @property
    def client(self) -> Any:
        """boto3 client for ``service_name``; created on first access and then reused."""
        if self.__client:
            return self.__client
        self.__client = self.__session.client(
            self.service_name, config=self.config, endpoint_url=self.endpoint_url
        )
        return self.__client

    @property
    def resource(self) -> Any:
        """boto3 resource for ``service_name``; created on first access and then reused."""
        if self.__resource:
            return self.__resource
        self.__resource = self.__session.resource(
            self.service_name, config=self.config, endpoint_url=self.endpoint_url
        )
        return self.__resource
@@ -0,0 +1,84 @@
1
+ """
2
+ Geek Cafe, LLC
3
+ Maintainers: Eric Wilson
4
+ MIT License. See Project Root for the license information.
5
+ """
6
+
7
+ from typing import Optional
8
+ from typing import TYPE_CHECKING
9
+
10
+ from aws_lambda_powertools import Logger
11
+
12
+
13
+ from boto3_assist.cloudwatch.cloudwatch_connection_tracker import (
14
+ CloudWatchConnectionTracker,
15
+ )
16
+ from boto3_assist.connection import Connection
17
+
18
if TYPE_CHECKING:
    from mypy_boto3_cloudwatch import CloudWatchClient, CloudWatchServiceResource
else:
    # The stub package is only needed by type checkers; at runtime these are
    # plain placeholders so the annotations below still resolve.
    CloudWatchClient = object
    CloudWatchServiceResource = object


logger = Logger()
tracker: CloudWatchConnectionTracker = CloudWatchConnectionTracker()


class CloudWatchConnection(Connection):
    """
    CloudWatch (``cloudwatch`` service) connection.

    Exposes ``client`` and ``resource`` properties that are pulled from the
    underlying session on first access and can also be assigned explicitly.
    """

    def __init__(
        self,
        *,
        aws_profile: Optional[str] = None,
        aws_region: Optional[str] = None,
        aws_access_key_id: Optional[str] = None,
        aws_secret_access_key: Optional[str] = None,
    ) -> None:
        super().__init__(
            service_name="cloudwatch",
            aws_profile=aws_profile,
            aws_region=aws_region,
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
        )

        # Cached handles, filled in lazily by the properties (or by their setters).
        self.__client: CloudWatchClient | None = None
        self.__resource: CloudWatchServiceResource | None = None

        # When True, reading a handle that is still None raises RuntimeError.
        self.raise_on_error: bool = True

    @property
    def client(self) -> CloudWatchClient:
        """CloudWatch client, taken from the session on first access."""
        current = self.__client
        if current is None:
            logger.info("Creating CloudWatch Client")
            current = self.session.client
            self.__client = current

        if self.raise_on_error and current is None:
            raise RuntimeError("CloudWatch Client is not available")
        return current

    @client.setter
    def client(self, value: CloudWatchClient):
        """Replace the cached CloudWatch client."""
        logger.info("Setting CloudWatch Client")
        self.__client = value

    @property
    def resource(self) -> CloudWatchServiceResource:
        """CloudWatch resource, taken from the session on first access."""
        current = self.__resource
        if current is None:
            logger.info("Creating CloudWatch Resource")
            current = self.session.resource
            self.__resource = current

        if self.raise_on_error and current is None:
            raise RuntimeError("CloudWatch Resource is not available")
        return current

    @resource.setter
    def resource(self, value: CloudWatchServiceResource):
        """Replace the cached CloudWatch resource."""
        logger.info("Setting CloudWatch Resource")
        self.__resource = value
@@ -0,0 +1,17 @@
1
+ """
2
+ Geek Cafe, LLC
3
+ Maintainers: Eric Wilson
4
+ MIT License. See Project Root for the license information.
5
+ """
6
+
7
+ from boto3_assist.connection_tracker import ConnectionTracker
8
+
9
+
10
class CloudWatchConnectionTracker(ConnectionTracker):
    """
    ConnectionTracker specialised for CloudWatch.

    Tracks CloudWatch connection requests, which is useful for performance
    tuning and debugging. All behavior is currently inherited from
    ``ConnectionTracker``.
    """

    def __init__(self) -> None:
        # No CloudWatch-specific state; defer entirely to the base tracker.
        super(CloudWatchConnectionTracker, self).__init__()
@@ -0,0 +1,62 @@
1
+ """
2
+ Geek Cafe, LLC
3
+ Maintainers: Eric Wilson
4
+ MIT License. See Project Root for the license information.
5
+ """
6
+
7
+ from typing import Optional
8
+ from typing import TYPE_CHECKING
9
+
10
+ from aws_lambda_powertools import Logger
11
+
12
+ from boto3_assist.cloudwatch.cloudwatch_connection_tracker import (
13
+ CloudWatchConnectionTracker,
14
+ )
15
+ from boto3_assist.connection import Connection
16
+
17
if TYPE_CHECKING:
    from mypy_boto3_logs import CloudWatchLogsClient
else:
    # Runtime placeholder so the annotations below resolve without the
    # type-checking-only stub package.
    CloudWatchLogsClient = object


logger = Logger()
tracker: CloudWatchConnectionTracker = CloudWatchConnectionTracker()


class CloudWatchConnection(Connection):
    """
    CloudWatch Logs (``logs`` service) connection.

    Exposes a ``client`` property that is pulled from the underlying session
    on first access and can also be assigned explicitly.
    """

    def __init__(
        self,
        *,
        aws_profile: Optional[str] = None,
        aws_region: Optional[str] = None,
        aws_access_key_id: Optional[str] = None,
        aws_secret_access_key: Optional[str] = None,
    ) -> None:
        super().__init__(
            service_name="logs",
            aws_profile=aws_profile,
            aws_region=aws_region,
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
        )

        # Cached client, filled in lazily by the property (or by its setter).
        self.__client: CloudWatchLogsClient | None = None

    @property
    def client(self) -> CloudWatchLogsClient:
        """CloudWatch Logs client, taken from the session on first access."""
        current = self.__client
        if current is None:
            logger.debug("Creating CloudWatch Client")
            current = self.session.client
            self.__client = current

        # NOTE(review): raise_on_error is not assigned in this class; it is
        # assumed to be provided by ``Connection`` - confirm.
        if self.raise_on_error and current is None:
            raise RuntimeError("CloudWatch Client is not available")
        return current

    @client.setter
    def client(self, value: CloudWatchLogsClient):
        """Replace the cached CloudWatch Logs client."""
        logger.debug("Setting CloudWatch Client")
        self.__client = value
@@ -0,0 +1,39 @@
1
+ """
2
+ Geek Cafe, LLC
3
+ Maintainers: Eric Wilson
4
+ MIT License. See Project Root for the license information.
5
+ """
6
+
7
+ from typing import Optional, List, Dict, Any
8
+ from boto3_assist.cloudwatch.cloudwatch_log_connection import CloudWatchConnection
9
+
10
+
11
class CloudWatchLogs(CloudWatchConnection):
    """Convenience operations on CloudWatch Logs."""

    def __init__(
        self,
        *,
        aws_profile: Optional[str] = None,
        aws_region: Optional[str] = None,
        aws_access_key_id: Optional[str] = None,
        aws_secret_access_key: Optional[str] = None,
    ) -> None:
        super().__init__(
            aws_profile=aws_profile,
            aws_region=aws_region,
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
        )

    def list_log_groups(self):
        """
        Return every log group in the AWS account as a list of dicts.

        Walks all pages of ``describe_log_groups`` so the result is complete,
        not just the first page.
        """
        pages = self.client.get_paginator("describe_log_groups").paginate()
        all_groups: List[Dict[str, Any]] = [  # type: ignore
            group for page in pages for group in page["logGroups"]
        ]
        return all_groups
34
+
35
+
36
def main():
    """Ad-hoc example: print every log group using a default-constructed CloudWatchLogs."""
    print(CloudWatchLogs().list_log_groups())
@@ -0,0 +1,191 @@
1
+ """
2
+ Geek Cafe, LLC
3
+ Maintainers: Eric Wilson
4
+ MIT License. See Project Root for the license information.
5
+ """
6
+
7
+ import os
8
+ from datetime import datetime, timedelta, UTC
9
+ from typing import Optional, Dict, Any, List
10
+ from boto3_assist.cloudwatch.cloudwatch_connection import CloudWatchConnection
11
+ from boto3_assist.cloudwatch.cloudwatch_logs import CloudWatchLogs
12
+
13
+
14
class CloudWatchQuery(CloudWatchConnection):
    """
    Query CloudWatch metrics about CloudWatch Logs log groups.

    "Size" here is the sum of the ``AWS/Logs`` ``IncomingBytes`` metric over a
    time window, i.e. the volume of log data ingested in that window - not the
    number of bytes currently stored in the log group.
    """

    def __init__(
        self,
        *,
        aws_profile: Optional[str] = None,
        aws_region: Optional[str] = None,
        aws_access_key_id: Optional[str] = None,
        aws_secret_access_key: Optional[str] = None,
    ) -> None:
        super().__init__(
            aws_profile=aws_profile,
            aws_region=aws_region,
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
        )

        # CloudWatch Logs connection, created lazily by ``cw_logs``.
        self.__cw_logs: CloudWatchLogs | None = None

    @property
    def cw_logs(self) -> CloudWatchLogs:
        """CloudWatch Logs connection using this instance's profile, region and credentials."""
        if self.__cw_logs is None:
            self.__cw_logs = CloudWatchLogs(
                aws_profile=self.aws_profile,
                aws_region=self.aws_region,
                aws_access_key_id=self.aws_access_key_id,
                aws_secret_access_key=self.aws_secret_access_key,
            )
        return self.__cw_logs

    def get_log_group_size(
        self, log_group_name: str, start_time: datetime, end_time: datetime
    ) -> Dict[str, Any]:
        """
        Get the ingested size of one log group over a time window.

        Sums the daily ``IncomingBytes`` datapoints returned by
        ``get_metric_data`` for the given log group.

        Args:
            log_group_name: Name of the CloudWatch Logs log group.
            start_time: Start of the window.
            end_time: End of the window.

        Returns:
            A dict with ``LogGroupName``; ``Size`` holding ``Bytes``, ``MB`` and
            ``GB`` (floats, binary units); and ``StartDate``/``EndDate`` as
            ISO-8601 strings. A log group with no datapoints reports 0.0.
        """
        response = self.client.get_metric_data(
            MetricDataQueries=[
                {
                    "Id": "storedBytes",
                    "MetricStat": {
                        "Metric": {
                            "Namespace": "AWS/Logs",
                            # IncomingBytes = ingested volume. For the current
                            # stored size use ``storedBytes`` from describe_log_groups.
                            "MetricName": "IncomingBytes",
                            "Dimensions": [
                                {"Name": "LogGroupName", "Value": log_group_name}
                            ],
                        },
                        "Period": 86400,  # one datapoint per day
                        "Stat": "Sum",
                    },
                    "ReturnData": True,
                },
            ],
            StartTime=start_time,
            EndTime=end_time,
        )

        # A single query is issued, so only the first result is relevant.
        results = response.get("MetricDataResults") or []
        values = (results[0].get("Values") or []) if results else []
        size = float(sum(values))

        size_mb = size / (1024 * 1024)
        size_gb = size_mb / 1024
        return {
            "LogGroupName": log_group_name,
            "Size": {
                "Bytes": size,
                "MB": size_mb,
                "GB": size_gb,
            },
            "StartDate": start_time.isoformat(),
            "EndDate": end_time.isoformat(),
        }

    def get_log_sizes(
        self,
        start_date_time: datetime | None = None,
        end_date_time: datetime | None = None,
        days: int | None = 7,
        top: int = 0,
    ) -> List[Dict[str, Any]]:
        """
        Get the ingested size of every log group, largest first.

        Args:
            start_date_time: Start of the window. Defaults to the current UTC
                time minus ``days``.
            end_date_time: End of the window. Defaults to the current UTC time.
            days: Window length used when ``start_date_time`` is omitted.
                A falsy value (None or 0) falls back to 7.
            top: When greater than zero, only the ``top`` largest log groups
                are returned.

        Returns:
            A list of ``get_log_group_size`` results sorted by ``Size.Bytes``
            in descending order.
        """
        if not days:
            days = 7
        # Take "now" once so the default window is exactly ``days`` long.
        now = datetime.now(UTC)
        start_time = start_date_time or (now - timedelta(days=days))
        end_time = end_date_time or now

        log_group_sizes = [
            self.get_log_group_size(log_group["logGroupName"], start_time, end_time)
            for log_group in self.cw_logs.list_log_groups()
        ]

        log_group_sizes.sort(
            key=lambda item: item.get("Size", {}).get("Bytes", 0),
            reverse=True,
        )

        if top and top > 0:
            return log_group_sizes[:top]
        return log_group_sizes
159
+
160
+
161
+ def main():
162
+ log_group = os.environ.get("LOG_GROUP_QUERY_SAMPLE", "<enter-log-group-here>")
163
+ start = datetime.now() - timedelta(days=7) # Last 30 days
164
+ end = datetime.now()
165
+ cw_query: CloudWatchQuery = CloudWatchQuery()
166
+ result = cw_query.get_log_group_size(log_group, start, end)
167
+ print(result)
168
+
169
+ top = 25
170
+ days = 7
171
+ top_log_groups = cw_query.get_log_sizes(top=top, days=days)
172
+ print(f"Top {top} log groups by size for the last week:")
173
+
174
+ for top_log_group in top_log_groups:
175
+ log_group_name = top_log_group["LogGroupName"]
176
+ size_in_bytes = top_log_group.get("Size", {}).get("Bytes", 0)
177
+ size_in_megs = top_log_group.get("Size", {}).get("MB", 0)
178
+ size_in_gigs = top_log_group.get("Size", {}).get("GB", 0)
179
+ size: str = ""
180
+ if size_in_gigs > 1:
181
+ size = f"{size_in_gigs:.2f} GB"
182
+ elif size_in_megs > 1:
183
+ size = f"{size_in_megs:.2f} MB"
184
+ else:
185
+ size = f"{size_in_bytes} bytes"
186
+
187
+ print(f"{size}: {log_group_name}")
188
+
189
+
190
+ if __name__ == "__main__":
191
+ main()
@@ -0,0 +1,169 @@
1
+ """
2
+ Geek Cafe, LLC
3
+ Maintainers: Eric Wilson
4
+ MIT License. See Project Root for the license information.
5
+ """
6
+
7
+ import time
8
+ from typing import Any, Dict, List
9
+
10
+ import jwt # PyJWT
11
+ from aws_lambda_powertools import Logger
12
+ from jwt import InvalidTokenError, PyJWKClient
13
+
14
+ from boto3_assist.boto3session import Boto3SessionManager
15
+ from boto3_assist.cognito.jwks_cache import JwksCache
16
+
17
# Module-level logger used by the authorizer.
logger: Logger = Logger()

# Process-wide cache of PyJWKClient entries, keyed by issuer URL in
# CognitoCustomAuthorizer.get_jwks_client.
jwks_cache: JwksCache = JwksCache()
20
+
21
+
22
+ class CognitoCustomAuthorizer:
23
+ """Cognito Custom Authorizer"""
24
+
25
+ def __init__(self):
26
+ self.__client_connections: Dict[str, Any] = {}
27
+
28
+ def __get_client_connection(
29
+ self, user_pool_id: str, refresh_client: bool = False
30
+ ) -> Any:
31
+ """Get the client connection to cognito"""
32
+ region = user_pool_id.split("_")[0]
33
+ client = self.__client_connections.get(region)
34
+ if refresh_client:
35
+ client = None
36
+ if not client:
37
+ session = Boto3SessionManager(service_name="cognito-idp", aws_region=region)
38
+ client = session.client
39
+ # boto3.client("cognito-idp", region_name=region)
40
+ self.__client_connections[region] = client
41
+
42
+ return client
43
+
44
+ def generate_policy(
45
+ self, user_pools: str | List[str], event: Dict[str, Any]
46
+ ) -> Dict[str, Any]:
47
+ """Generates the policy for the authorizer"""
48
+
49
+ token = event["authorizationToken"]
50
+ user_pools = self.__to_list(user_pools=user_pools)
51
+ for user_pool_id in user_pools:
52
+ try:
53
+ if not user_pool_id:
54
+ continue
55
+ # up_id = self.__to_id(user_pool_id=user_pool_id)
56
+ # Decode the token, assuming RS256 (used by Cognito)
57
+ # decoded_token = self.decode_jwt(token=token, user_pool_id=up_id)
58
+ issuer = self.build_issuer_url(user_pool_id)
59
+ claims = self.decode_jwt(token, issuer)
60
+ # Token is valid, return an IAM policy
61
+ return self.__generate_policy_doc(
62
+ principal_id=claims["sub"],
63
+ effect="Allow",
64
+ method_arn=event["methodArn"],
65
+ )
66
+
67
+ except InvalidTokenError as e:
68
+ # Token is not valid for this user pool, try the next one
69
+ logger.debug(str(e))
70
+ continue
71
+ except Exception as e: # pylint: disable=w0718
72
+ logger.error(str(e))
73
+
74
+ # if we get here we deny it
75
+ return self.__generate_policy_doc(
76
+ principal_id="user",
77
+ effect="Deny",
78
+ method_arn=event["methodArn"],
79
+ )
80
+
81
+ def __generate_policy_doc(self, *, principal_id, effect, method_arn):
82
+ """Generate the policy doc"""
83
+ auth_response: Dict[str, Any] = {"principalId": principal_id}
84
+
85
+ if effect and method_arn:
86
+ policy_document = {
87
+ "Version": "2012-10-17",
88
+ "Statement": [
89
+ {
90
+ "Action": "execute-api:Invoke",
91
+ "Effect": effect,
92
+ "Resource": method_arn,
93
+ }
94
+ ],
95
+ }
96
+ auth_response["policyDocument"] = policy_document
97
+
98
+ return auth_response
99
+
100
+ def build_issuer_url(self, user_pool_id: str) -> str:
101
+ """Build the issuer URL"""
102
+
103
+ # Extract region from user pool ID format, e.g., "us-east-1_ABC123"
104
+ region = user_pool_id.split("_")[0]
105
+ return f"https://cognito-idp.{region}.amazonaws.com/{user_pool_id}"
106
+
107
+ def __to_list(self, user_pools: str | List[str]) -> List[str]:
108
+ if isinstance(user_pools, str):
109
+ user_pools = str(user_pools).replace(";", ",").replace(" ", "")
110
+ user_pools = str(user_pools).split(",")
111
+ elif isinstance(user_pools, list):
112
+ pass
113
+ else:
114
+ logger.warning(
115
+ f"Missing/ Invalid user pool: {user_pools}, type: {type(user_pools)}"
116
+ )
117
+
118
+ return user_pools
119
+
120
+ def parse_jwt(self, token: str) -> dict:
121
+ """Parse the JWT"""
122
+ if "Bearer" in token:
123
+ token = token.replace("Bearer ", "")
124
+
125
+ decoded_jwt: dict = jwt.decode(token, options={"verify_signature": False})
126
+
127
+ return decoded_jwt
128
+
129
+ def decode_jwt(self, token: str, issuer) -> dict:
130
+ """Decode the JWT"""
131
+ # Get the public keys
132
+ # Get the JWKS client
133
+ jwks_client = self.get_jwks_client(issuer)
134
+ if "Bearer" in token:
135
+ token = token.replace("Bearer ", "")
136
+ # Fetch the signing key using the PyJWKClient
137
+ signing_key = jwks_client.get_signing_key_from_jwt(token)
138
+
139
+ # Decode and verify the token
140
+ claims = jwt.decode(
141
+ token,
142
+ signing_key.key,
143
+ algorithms=["RS256"],
144
+ # audience=user_pool_id,
145
+ issuer=issuer,
146
+ options={"verify_aud": False}, # Disable audience verification
147
+ )
148
+
149
+ # Optional claim checks
150
+ if claims["token_use"] != "id":
151
+ # we are currently only using ID tokens
152
+ raise RuntimeError("Not an id token")
153
+
154
+ return claims
155
+
156
+ def get_jwks_client(self, issuer) -> PyJWKClient:
157
+ """Get the JWT Client"""
158
+ if (
159
+ issuer in jwks_cache.cache
160
+ and (time.time() - jwks_cache.cache.get(issuer, {})["timestamp"]) < 3600
161
+ ):
162
+ # Return cached JWKS client if it’s less than an hour old
163
+ return jwks_cache.cache[issuer]["client"]
164
+ else:
165
+ # Create a new PyJWKClient and cache it
166
+ jwks_url = f"{issuer}/.well-known/jwks.json"
167
+ jwks_client = PyJWKClient(jwks_url)
168
+ jwks_cache.cache[issuer] = {"client": jwks_client, "timestamp": time.time()}
169
+ return jwks_client