my-aws-helpers 3.1.0.dev4__tar.gz → 6.0.4__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {my_aws_helpers-3.1.0.dev4 → my_aws_helpers-6.0.4}/PKG-INFO +3 -1
- {my_aws_helpers-3.1.0.dev4 → my_aws_helpers-6.0.4}/my_aws_helpers/api.py +2 -2
- my_aws_helpers-6.0.4/my_aws_helpers/bedrock.py +362 -0
- my_aws_helpers-6.0.4/my_aws_helpers/cognito.py +141 -0
- {my_aws_helpers-3.1.0.dev4 → my_aws_helpers-6.0.4}/my_aws_helpers/dynamo.py +15 -11
- {my_aws_helpers-3.1.0.dev4 → my_aws_helpers-6.0.4}/my_aws_helpers/event.py +15 -13
- {my_aws_helpers-3.1.0.dev4 → my_aws_helpers-6.0.4}/my_aws_helpers/s3.py +157 -5
- {my_aws_helpers-3.1.0.dev4 → my_aws_helpers-6.0.4}/my_aws_helpers.egg-info/PKG-INFO +3 -1
- {my_aws_helpers-3.1.0.dev4 → my_aws_helpers-6.0.4}/my_aws_helpers.egg-info/SOURCES.txt +1 -0
- my_aws_helpers-6.0.4/my_aws_helpers.egg-info/requires.txt +15 -0
- {my_aws_helpers-3.1.0.dev4 → my_aws_helpers-6.0.4}/setup.py +22 -7
- my_aws_helpers-6.0.4/tests/test_cognito.py +20 -0
- {my_aws_helpers-3.1.0.dev4 → my_aws_helpers-6.0.4}/tests/test_event.py +5 -4
- my_aws_helpers-3.1.0.dev4/my_aws_helpers/bedrock.py +0 -217
- my_aws_helpers-3.1.0.dev4/my_aws_helpers/cognito.py +0 -144
- my_aws_helpers-3.1.0.dev4/my_aws_helpers.egg-info/requires.txt +0 -1
- {my_aws_helpers-3.1.0.dev4 → my_aws_helpers-6.0.4}/MANIFEST.in +0 -0
- {my_aws_helpers-3.1.0.dev4 → my_aws_helpers-6.0.4}/README.md +0 -0
- {my_aws_helpers-3.1.0.dev4 → my_aws_helpers-6.0.4}/my_aws_helpers/auth.py +0 -0
- {my_aws_helpers-3.1.0.dev4 → my_aws_helpers-6.0.4}/my_aws_helpers/errors.py +0 -0
- {my_aws_helpers-3.1.0.dev4 → my_aws_helpers-6.0.4}/my_aws_helpers/logging.py +0 -0
- {my_aws_helpers-3.1.0.dev4 → my_aws_helpers-6.0.4}/my_aws_helpers/prompts/__init__.py +0 -0
- {my_aws_helpers-3.1.0.dev4 → my_aws_helpers-6.0.4}/my_aws_helpers/prompts/markdown_system_prompt.txt +0 -0
- {my_aws_helpers-3.1.0.dev4 → my_aws_helpers-6.0.4}/my_aws_helpers/prompts/transactions_headers_prompt.txt +0 -0
- {my_aws_helpers-3.1.0.dev4 → my_aws_helpers-6.0.4}/my_aws_helpers/prompts/transactions_headers_prompt_v2.txt +0 -0
- {my_aws_helpers-3.1.0.dev4 → my_aws_helpers-6.0.4}/my_aws_helpers/prompts/transactions_prompt.txt +0 -0
- {my_aws_helpers-3.1.0.dev4 → my_aws_helpers-6.0.4}/my_aws_helpers/sfn.py +0 -0
- {my_aws_helpers-3.1.0.dev4 → my_aws_helpers-6.0.4}/my_aws_helpers.egg-info/dependency_links.txt +0 -0
- {my_aws_helpers-3.1.0.dev4 → my_aws_helpers-6.0.4}/my_aws_helpers.egg-info/top_level.txt +0 -0
- {my_aws_helpers-3.1.0.dev4 → my_aws_helpers-6.0.4}/my_aws_helpers.egg-info/zip-safe +0 -0
- {my_aws_helpers-3.1.0.dev4 → my_aws_helpers-6.0.4}/setup.cfg +0 -0
{my_aws_helpers-3.1.0.dev4 → my_aws_helpers-6.0.4}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: my_aws_helpers
-Version: 3.1.0.dev4
+Version: 6.0.4
 Summary: AWS Helpers
 Home-page: https://github.com/JarrodMccarthy/aws_helpers.git
 Author: Jarrod McCarthy
@@ -11,6 +11,8 @@ Classifier: Programming Language :: Python
 Classifier: Programming Language :: Python :: 3.8
 Classifier: Programming Language :: Python :: 3.9
 Classifier: Programming Language :: Python :: 3.10
+Provides-Extra: all
+Provides-Extra: bedrock
 
 UNKNOWN
 
{my_aws_helpers-3.1.0.dev4 → my_aws_helpers-6.0.4}/my_aws_helpers/api.py

@@ -10,7 +10,7 @@ class API:
         if isinstance(obj, datetime) or isinstance(obj, date):
             return obj.isoformat()
         return obj
-
+
     def response_serialiser(response: Any):
         if isinstance(response, list):
             return [API.response_serialiser(obj) for obj in response]
@@ -18,7 +18,7 @@ class API:
             for k, v in response.items():
                 response[k] = API.response_serialiser(v)
             return response
-        return API._serialise(response)
+        return API._serialise(response)
 
     def response(code: int, body: Optional[str] = None):
         return {
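For orientation, a minimal sketch of how the serialiser touched above is typically used (the payload is illustrative only and the dict branch of response_serialiser is inferred from the surrounding lines):

```python
from datetime import datetime

from my_aws_helpers.api import API

# Illustrative payload: nested datetimes are flattened to ISO-8601 strings.
payload = {"created_at": datetime(2024, 1, 1), "rows": [{"when": datetime(2024, 1, 2)}]}
serialised = API.response_serialiser(payload)
print(serialised)  # {'created_at': '2024-01-01T00:00:00', 'rows': [{'when': '2024-01-02T00:00:00'}]}
```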
my_aws_helpers-6.0.4/my_aws_helpers/bedrock.py

@@ -0,0 +1,362 @@
+from __future__ import annotations
+import boto3
+from botocore.config import Config
+import json
+import time
+import os
+import io
+import re
+from copy import copy
+from typing import Optional, List, Dict
+from enum import Enum
+import pymupdf
+import concurrent.futures
+from dataclasses import dataclass
+from my_aws_helpers.s3 import S3Serialiser, BaseS3Object, BaseS3Queries, S3, S3Location
+
+from my_aws_helpers.logging import select_powertools_logger
+
+
+logger = select_powertools_logger("aws-helpers-bedrock")
+
+
+class ImageType(str, Enum):
+    gif = "gif"
+    jpeg = "jpeg"
+    png = "png"
+    webp = "webp"
+
+
+class PromptType(str, Enum):
+    transaction_headers = "transactions_headers_prompt_v2.txt"
+    transactions = "transactions_prompt.txt"
+    # json = "json_system_prompt.txt"
+    markdown = "markdown_system_prompt.txt"
+
+
+@dataclass
+class TokenUsage:
+    input_tokens: int
+    output_tokens: int
+    total_tokens: int
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, int]) -> TokenUsage:
+        return cls(
+            input_tokens=data.get("inputTokens", 0),
+            output_tokens=data.get("outputTokens", 0),
+            total_tokens=data.get("totalTokens", 0),
+        )
+
+
+@dataclass
+class OCRResult(BaseS3Object):
+    content: List[Dict[str, str]]
+    token_usage: TokenUsage
+    page_number: int
+
+    @classmethod
+    def from_dict(cls, data: Dict) -> OCRResult:
+        return cls(
+            content=data.get("content", []),
+            token_usage=TokenUsage.from_dict(data.get("token_usage", {})),
+            page_number=data.get("page_number", 0),
+        )
+
+    @classmethod
+    def from_s3_representation(cls, obj: dict) -> OCRResult:
+        obj["token_usage"] = (TokenUsage.from_dict(obj.get("token_usage", {})),)
+        return cls(**obj)
+
+    def to_s3_representation(self) -> dict:
+        obj = copy(vars(self))
+        obj["token_usage"] = S3Serialiser.object_serialiser(
+            obj=vars(obj["token_usage"])
+        )
+        return S3Serialiser.object_serialiser(obj=obj)
+
+    def get_save_location(self, bucket_name: str) -> S3Location:
+        pass
+
+
+class OCRResultQueries(BaseS3Queries):
+    def __init__(self, s3_client: S3, bucket_name: str):
+        super().__init__(s3_client=s3_client, bucket_name=bucket_name)
+
+    def save_ocr_result_to_s3(
+        self, ocr_result: OCRResult, save_location: S3Location
+    ) -> Optional[S3Location]:
+        try:
+            obj = ocr_result.to_s3_representation()
+            return self.s3_client.save_dict_to_s3(
+                content=obj,
+                s3_location=save_location,
+            )
+        except Exception as e:
+            logger.exception(f"Failed to save ocr result to s3 due to {e}")
+            return None
+
+    def _concurrent_s3_read(
+        self, locations: List[S3Location], max_workers: int = 10
+    ) -> List[OCRResult]:
+        results: List[OCRResult] = list()
+        futures = list()
+        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
+            for loc in locations:
+                future = executor.submit(
+                    self.s3_client.read_dict_from_s3,
+                    s3_location=loc,
+                )
+                futures.append(future)
+            for f in futures:
+                results.append(f.result())
+        results = [r for r in results if r is not None]
+        return results
+
+    def get_ocr_results_by_key_prefix(self, prefix: str) -> List[OCRResult]:
+        locations = self.s3_client.list_objects_by_prefix(
+            bucket_name=self.bucket_name, prefix=prefix
+        )
+        objects = self._concurrent_s3_read(locations=locations)
+        ocr_results = [OCRResult.from_s3_representation(obj=obj) for obj in objects]
+        return ocr_results
+
+
+class Bedrock:
+    def __init__(
+        self,
+        model_id: str = "apac.anthropic.claude-3-5-sonnet-20241022-v2:0", # anthropic.claude-sonnet-4-20250514-v1:0
+        logger=None,
+        sleep_time: float = 1.0,
+    ):
+
+        self.session = Bedrock._set_session_params()
+        self.logger = logger
+        region_name = "ap-southeast-2"
+        if self.session is None:
+            self.session = boto3.Session(region_name=region_name)
+        self.sleep_time = sleep_time
+
+        custom_config = Config(
+            retries={
+                "max_attempts": 2, # Total attempts = 1 initial + 1 retry
+                "mode": "standard", # or 'adaptive'
+            }
+        )
+        self.client = self.session.client(
+            "bedrock-runtime", region_name=region_name, config=custom_config
+        )
+        self.model_id = model_id
+
+    @staticmethod
+    def _set_session_params():
+        try:
+            aws_access_key_id = os.environ["AWS_ACCESS_KEY_ID"]
+            aws_secret_access_key = os.environ["AWS_SECRET_ACCESS_KEY"]
+            aws_session_token = os.environ["AWS_SESSION_TOKEN"]
+            region_name = os.environ["AWS_DEFAULT_REGION"]
+            return boto3.Session(
+                aws_access_key_id=aws_access_key_id,
+                aws_secret_access_key=aws_secret_access_key,
+                aws_session_token=aws_session_token,
+                region_name=region_name,
+            )
+        except Exception as e:
+            print(e)
+            return None
+
+    @staticmethod
+    def extract_json_from_markdown(text: str):
+        """
+        Extracts the JSON object from a string that may be wrapped in ```json ... ``` code block
+        """
+        # Match a {...} block anywhere in the text
+        match = re.search(r"\{.*\}", text, re.DOTALL)
+        if match:
+            json_str = match.group(0)
+            return json.loads(json_str)
+        else:
+            raise ValueError("No JSON object found in the text")
+
+    def _get_prompt(self, prompt_type: str) -> Optional[str]:
+        if prompt_type not in list(PromptType):
+            raise Exception(f"Error: Invalid prompt type")
+
+        path = os.path.join(os.path.dirname(__file__), "prompts", prompt_type)
+        try:
+            with open(path, "r") as f:
+                prompt = f.read()
+            return prompt
+        except Exception as e:
+            self.logger.exception(f"Failed to get {prompt_type} prompt due to {e}")
+            return None
+
+    def _ocr(
+        self, prompt: str, image_bytes: bytes, page_number: Optional[int] = 0
+    ) -> Optional[OCRResult]:
+        system_prompt = [{"text": prompt}]
+        message = [
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "image": {
+                            "format": "png",
+                            "source": {
+                                "bytes": image_bytes,
+                            },
+                        }
+                    }
+                ],
+            }
+        ]
+        retries = 3
+        for i in range(retries):
+            self.logger.info(f"Attempt number {i} for {self.model_id} converse")
+            try:
+                response = self.client.converse(
+                    modelId=self.model_id, messages=message, system=system_prompt
+                )
+                if response["ResponseMetadata"]["HTTPStatusCode"] == 200:
+                    break
+            except Exception as e:
+                self.logger.exception(f"Error during conversation due to {e}")
+                if i >= len(retries) - 1:
+                    raise Exception(e)
+                time.sleep(self.sleep_time)
+                continue
+
+        result = {}
+        result["content"] = json.loads(
+            response["output"]["message"]["content"][0]["text"]
+        )
+        result["token_usage"] = response["usage"]
+        result["page_number"] = page_number
+        return OCRResult.from_dict(data=result)
+
+    def _parallel_ocr(
+        self,
+        image_bytes_list: List[bytes],
+        prompt: str,
+        max_workers: int = 10,
+    ):
+        execution_futures = []
+        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
+            for i, img in enumerate(image_bytes_list):
+                self.logger.info(f"Starting OCR for page: {i}")
+                time.sleep(self.sleep_time) # Stagger start time
+                future = executor.submit(self._ocr, prompt=prompt, image_bytes=img)
+                execution_futures.append(future)
+
+            # Wait for all tasks and collect results in order of submission
+            results = [
+                future.result()
+                for future in execution_futures
+                if future.result() is not None
+            ]
+        return results
+
+    def get_ocr_result(
+        self,
+        pdf_bytes: io.BytesIO,
+        prompt_type: str,
+        zoom: int = 7,
+    ) -> List[OCRResult]:
+        try:
+            self.logger.info("Getting OCR Results")
+            document = pymupdf.open(stream=pdf_bytes, filetype="pdf")
+            pages: List[pymupdf.Page] = [p for p in document]
+
+            image_bytes_list: List[bytes] = list()
+            for i, p in enumerate(pages):
+                try:
+                    image_bytes: bytes = p.get_pixmap(
+                        matrix=pymupdf.Matrix(zoom, zoom)
+                    ).tobytes("png")
+                    image_bytes_list.append(image_bytes)
+                except Exception as e:
+                    self.logger.error(f"Could not get pix map for page {i}")
+                    continue
+            skip_page_zero = False
+            header_ocr_result = None
+            if len(image_bytes_list) > 1:
+                headers_prompt = self._get_prompt(
+                    prompt_type=PromptType.transaction_headers.value
+                )
+                for i in range(2):
+                    # try to get headers from the first or second page
+                    header_ocr_result = self._ocr(
+                        prompt=headers_prompt, image_bytes=image_bytes_list[i]
+                    )
+                    if header_ocr_result is None:
+                        self.logger.info(
+                            f"No ocr result returned when getting headers {PromptType.transaction_headers.value}"
+                        )
+                    headers = header_ocr_result.content.get("headers")
+                    if (len(headers) < 1) or (headers is None):
+                        skip_page_zero = True
+                        continue
+                    else:
+                        break
+
+            transactions_prompt = self._get_prompt(prompt_type=prompt_type)
+            if header_ocr_result:
+                transactions_prompt = transactions_prompt.replace(
+                    "#### TABLE HEADERS ####", json.dumps(header_ocr_result.content)
+                )
+
+            self.logger.info("Got Prompt")
+            results = list()
+
+            if skip_page_zero:
+                image_bytes_list = image_bytes_list[
+                    1:
+                ] # page zero often has account summary info
+            results = self._parallel_ocr(
+                image_bytes_list=image_bytes_list, prompt=transactions_prompt
+            )
+
+            # for i, image_bytes in enumerate(image_bytes_list):
+            #     self.logger.info(f"Starting OCR for page: {i}")
+            #     results.append(self._ocr(image_bytes=image_bytes, prompt=transactions_prompt))
+            return results
+        except Exception as e:
+            self.logger.exception(e)
+            return []
+
+    def _get_image_block(self, image: bytes, image_content_type: ImageType) -> dict:
+        return {
+            "image": {
+                "format": image_content_type,
+                "source": {
+                    "bytes": image,
+                },
+            }
+        }
+
+    def image_analysis(
+        self, images: List[bytes], prompt: str, image_content_type: ImageType
+    ) -> OCRResult:
+
+        system_prompt = [{"text": prompt}]
+        message = [
+            {
+                "role": "user",
+                "content": [
+                    self._get_image_block(
+                        image=image, image_content_type=image_content_type
+                    )
+                    for image in images
+                ],
+            }
+        ]
+        response = self.client.converse(
+            modelId=self.model_id, messages=message, system=system_prompt
+        )
+
+        result = {}
+        result["content"] = Bedrock.extract_json_from_markdown(
+            text=response["output"]["message"]["content"][0]["text"]
+        )
+        result["token_usage"] = response["usage"]
+        return OCRResult.from_dict(data=result)
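A hedged usage sketch of the new bedrock.py module, based only on the signatures visible in the hunk above; the PDF path is a placeholder, and a logger is passed explicitly because the class logs through self.logger:

```python
import io
import logging

from my_aws_helpers.bedrock import Bedrock, PromptType

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("bedrock-ocr-demo")

# Placeholder input: any PDF whose pages should be OCR'd via Bedrock converse().
with open("statement.pdf", "rb") as fh:
    pdf_bytes = io.BytesIO(fh.read())

# Credentials come from AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY / AWS_SESSION_TOKEN
# when set; otherwise the default boto3 session in ap-southeast-2 is used.
bedrock = Bedrock(logger=logger, sleep_time=1.0)

# Each page is rendered to a PNG with pymupdf and sent to the model in parallel.
results = bedrock.get_ocr_result(
    pdf_bytes=pdf_bytes,
    prompt_type=PromptType.transactions.value,
)
for ocr_result in results:
    print(ocr_result.page_number, ocr_result.token_usage.total_tokens)
```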
my_aws_helpers-6.0.4/my_aws_helpers/cognito.py

@@ -0,0 +1,141 @@
+import os
+import json
+from typing import Optional
+import boto3
+import urllib.request
+from jose import jwk, jwt
+from jose.utils import base64url_decode
+from my_aws_helpers.logging import select_powertools_logger
+
+logger = select_powertools_logger("aws-helpers-cognito")
+
+
+class Cognito:
+    client: boto3.client
+
+    def __init__(
+        self,
+        client: Optional[boto3.client] = None,
+        user_pool_id: Optional[str] = None,
+    ):
+        self.cognito_user_pool_id = (
+            user_pool_id if user_pool_id else os.environ["COGNITO_USER_POOL_ID"]
+        )
+        self.region = os.environ.get("AWS_DEFAULT_REGION", "ap-southeast-2")
+        self.client = client if client else self._get_client(region=self.region)
+
+    def _get_client(self, region: str) -> boto3.client:
+        return boto3.client("cognito-idp", region_name=region)
+
+    def _verify_signature(self):
+        pass
+
+    def _verify_audience(self):
+        pass
+
+    def _verify_token_use(self):
+        pass
+
+    def validate_token(self, token: str) -> bool:
+        try:
+            if "Bearer" in token:
+                token = token.split(" ")[-1]
+            headers = jwt.get_unverified_headers(token)
+            key_id = headers["kid"]
+            keys = self._get_keys()
+            key = next(k for k in keys if k["kid"] == key_id)
+            public_key = jwk.construct(key)
+            message, encoded_signature = token.rsplit(".", 1)
+            decoded_signature = base64url_decode(encoded_signature.encode("utf-8"))
+            if not public_key.verify(message.encode("utf-8"), decoded_signature):
+                raise Exception("Signature verification failed")
+            return True
+        except Exception as e:
+            logger.exception(f"Failed to validate token due to {e}")
+            return False
+
+    def sign_up(self, username: str, password: str, app_client_id: str):
+        try:
+            response = self.client.sign_up(
+                ClientId=app_client_id,
+                Username=username,
+                Password=password,
+                UserAttributes=[
+                    {"Name": "email", "Value": username},
+                ],
+            )
+            return response
+        except Exception as e:
+            logger.exception(f"Failed to sign up due to {e}")
+            return None
+
+    def confirm_sign_up(
+        self,
+        username: str,
+        client_id: str,
+        confirmation_code: str,
+    ) -> dict:
+        try:
+            response = self.client.confirm_sign_up(
+                ClientId=client_id, Username=username, ConfirmationCode=confirmation_code,
+            )
+            return response
+        except Exception as e:
+            logger.exception(
+                f"Failed to confirm sign up username {username} due to {e}"
+            )
+            return None
+
+    def admin_confirm_sign_up(
+        self,
+        username: str,
+        user_pool_id: str,
+    ) -> dict:
+        try:
+            response = self.client.admin_confirm_sign_up(
+                UserPoolId=user_pool_id, Username=username
+            )
+            return response
+        except Exception as e:
+            logger.exception(
+                f"Failed to confirm admin sign up username {username} due to {e}"
+            )
+            return None
+
+    def refresh_token(self, refresh_token: str, app_client_id: str):
+        response = self.client.initiate_auth(
+            ClientId=app_client_id,
+            AuthFlow="REFRESH_TOKEN_AUTH",
+            AuthParameters={"REFRESH_TOKEN": refresh_token},
+        )
+        return response
+
+    def login(self, username: str, password: str, app_client_id: str):
+        response = self.client.initiate_auth(
+            ClientId=app_client_id,
+            AuthFlow="USER_PASSWORD_AUTH",
+            AuthParameters={"USERNAME": username, "PASSWORD": password},
+        )
+        return response
+
+    def _get_issuer(self) -> str:
+        return f"https://cognito-idp.{self.region}.amazonaws.com/{self.cognito_user_pool_id}"
+
+    def _get_keys(self):
+        issuer = self._get_issuer()
+        keys_url = f"{issuer}/.well-known/jwks.json"
+        with urllib.request.urlopen(keys_url) as f:
+            response = f.read()
+        keys = json.loads(response.decode("utf-8"))["keys"]
+        return keys
+
+    @staticmethod
+    def get_policy(allow: bool, method_arn: str = "*") -> dict:
+        allow = "Allow" if allow else "Deny"
+        return {
+            "principalId": "authenticated-user",
+            "policyDocument": {
+                "Version": "2012-10-17",
+                "Statement": [{"Action": "*", "Effect": allow, "Resource": "*"}],
+            },
+        }
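Likewise, a minimal sketch of how the new cognito.py wrapper might be driven; the pool id, app client id, and user credentials are placeholders, and COGNITO_USER_POOL_ID can be set in the environment instead of passing user_pool_id:

```python
from my_aws_helpers.cognito import Cognito

# Placeholder identifiers for illustration only.
USER_POOL_ID = "ap-southeast-2_EXAMPLE"
APP_CLIENT_ID = "example-app-client-id"

cognito = Cognito(user_pool_id=USER_POOL_ID)

# Self-service sign-up, confirmed with the emailed verification code.
cognito.sign_up(username="user@example.com", password="CorrectHorse9!", app_client_id=APP_CLIENT_ID)
cognito.confirm_sign_up(username="user@example.com", client_id=APP_CLIENT_ID, confirmation_code="123456")

# USER_PASSWORD_AUTH login returns the standard Cognito token set.
tokens = cognito.login(username="user@example.com", password="CorrectHorse9!", app_client_id=APP_CLIENT_ID)
id_token = tokens["AuthenticationResult"]["IdToken"]

# validate_token only checks the signature against the pool's JWKS; audience and
# expiry checks are still stubs (_verify_audience / _verify_token_use) in this version.
assert cognito.validate_token(f"Bearer {id_token}")
```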
{my_aws_helpers-3.1.0.dev4 → my_aws_helpers-6.0.4}/my_aws_helpers/dynamo.py

@@ -2,12 +2,13 @@ from typing import List, Any, Optional
 from datetime import datetime, date
 import boto3
 from abc import ABC, abstractmethod
-from decimal import Decimal, Context
+from decimal import Decimal, Context
 
 from my_aws_helpers.logging import select_powertools_logger
 
 logger = select_powertools_logger("aws-helpers-dynamo")
 
+
 class MetaData:
     """
     This class is a convenience class,
@@ -76,7 +77,6 @@ class BaseTableObject:
 
     def __init__(self) -> None:
         pass
-
 
 
 class DynamoSerialiser:
@@ -86,10 +86,10 @@ class DynamoSerialiser:
         if isinstance(obj, datetime) or isinstance(obj, date):
             return obj.isoformat()
         if isinstance(obj, float):
-            ctx = Context(prec
+            ctx = Context(prec=38)
             return ctx.create_decimal_from_float(obj)
         return obj
-
+
     @staticmethod
     def object_serialiser(obj: Any):
         if isinstance(obj, list):
@@ -98,6 +98,7 @@ class DynamoSerialiser:
             return {k: DynamoSerialiser.object_serialiser(v) for k, v in obj.items()}
         return DynamoSerialiser._serialise(obj=obj)
 
+
 class Dynamo:
     table: boto3.resource
 
@@ -112,7 +113,7 @@ class Dynamo:
         return self.table.get_item(Item=item)
 
     def delete_item(self, item: dict):
-        return self.table.delete_item(
+        return self.table.delete_item(Key=item)
 
     def batch_put(self, items: List[dict]) -> None:
         with self.table.batch_writer() as batch:
@@ -130,14 +131,16 @@ class Dynamo:
         response = self.table.scan()
         items: List = response["Items"]
         while response.get("LastEvaluatedKey") is not None:
-            response = self.table.scan(ExclusiveStartKey
+            response = self.table.scan(ExclusiveStartKey=response["LastEvaluatedKey"])
             if response.get("Items") is not None:
                 items.extend(response["Items"])
             if response.get("LastEvaluatedKey") is None:
                 break
         return items
 
-    def delete_table_items(
+    def delete_table_items(
+        self, partition_key_name: str = "pk", sort_key_name: str = "sk"
+    ) -> bool:
         try:
             items = self._deep_scan()
             delete_repr_items = [
@@ -147,10 +150,10 @@ class Dynamo:
                 }
                 for item in items
             ]
-            self.batch_delete(items
+            self.batch_delete(items=delete_repr_items)
             return True
         except Exception as e:
-            logger.exception(f
+            logger.exception(f"Failed to delete table items due to {e}")
            return False
 
 def to_dynamo_representation(obj: dict):
@@ -177,6 +180,7 @@ def _datatype_map(value: Any):
         return new_obj
     return value
 
+
 class BaseQueries(ABC):
     table_name: str
 
@@ -186,7 +190,7 @@ class BaseQueries(ABC):
 
     def _get_client(self):
         return Dynamo(table_name=self.table_name)
-
+
     def _iterative_query(self, query_kwargs: dict) -> List[dict]:
         results = list()
         last_evaluated_key = "not none"
@@ -198,4 +202,4 @@ class BaseQueries(ABC):
             results += result["Items"]
             last_evaluated_key = result.get("LastEvaluatedKey")
             exclusive_start_key = last_evaluated_key
-        return results
+        return results
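Several of the removed dynamo.py lines above are shown truncated by the diff view; the replacement lines complete calls such as Context(prec=38), delete_item(Key=item), batch_delete(items=delete_repr_items), and the paginated scan. For reference, a standalone sketch of the same LastEvaluatedKey pagination pattern that the fixed _deep_scan uses, against a hypothetical table name:

```python
import boto3

# Hypothetical table for illustration; _deep_scan runs the same loop on self.table.
table = boto3.resource("dynamodb").Table("example-table")

response = table.scan()
items = response["Items"]
# Keep scanning until DynamoDB stops returning a LastEvaluatedKey.
while response.get("LastEvaluatedKey") is not None:
    response = table.scan(ExclusiveStartKey=response["LastEvaluatedKey"])
    items.extend(response.get("Items", []))
print(f"Scanned {len(items)} items")
```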