my-aws-helpers 3.1.0.dev4__tar.gz → 6.0.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (31)
  1. {my_aws_helpers-3.1.0.dev4 → my_aws_helpers-6.0.4}/PKG-INFO +3 -1
  2. {my_aws_helpers-3.1.0.dev4 → my_aws_helpers-6.0.4}/my_aws_helpers/api.py +2 -2
  3. my_aws_helpers-6.0.4/my_aws_helpers/bedrock.py +362 -0
  4. my_aws_helpers-6.0.4/my_aws_helpers/cognito.py +141 -0
  5. {my_aws_helpers-3.1.0.dev4 → my_aws_helpers-6.0.4}/my_aws_helpers/dynamo.py +15 -11
  6. {my_aws_helpers-3.1.0.dev4 → my_aws_helpers-6.0.4}/my_aws_helpers/event.py +15 -13
  7. {my_aws_helpers-3.1.0.dev4 → my_aws_helpers-6.0.4}/my_aws_helpers/s3.py +157 -5
  8. {my_aws_helpers-3.1.0.dev4 → my_aws_helpers-6.0.4}/my_aws_helpers.egg-info/PKG-INFO +3 -1
  9. {my_aws_helpers-3.1.0.dev4 → my_aws_helpers-6.0.4}/my_aws_helpers.egg-info/SOURCES.txt +1 -0
  10. my_aws_helpers-6.0.4/my_aws_helpers.egg-info/requires.txt +15 -0
  11. {my_aws_helpers-3.1.0.dev4 → my_aws_helpers-6.0.4}/setup.py +22 -7
  12. my_aws_helpers-6.0.4/tests/test_cognito.py +20 -0
  13. {my_aws_helpers-3.1.0.dev4 → my_aws_helpers-6.0.4}/tests/test_event.py +5 -4
  14. my_aws_helpers-3.1.0.dev4/my_aws_helpers/bedrock.py +0 -217
  15. my_aws_helpers-3.1.0.dev4/my_aws_helpers/cognito.py +0 -144
  16. my_aws_helpers-3.1.0.dev4/my_aws_helpers.egg-info/requires.txt +0 -1
  17. {my_aws_helpers-3.1.0.dev4 → my_aws_helpers-6.0.4}/MANIFEST.in +0 -0
  18. {my_aws_helpers-3.1.0.dev4 → my_aws_helpers-6.0.4}/README.md +0 -0
  19. {my_aws_helpers-3.1.0.dev4 → my_aws_helpers-6.0.4}/my_aws_helpers/auth.py +0 -0
  20. {my_aws_helpers-3.1.0.dev4 → my_aws_helpers-6.0.4}/my_aws_helpers/errors.py +0 -0
  21. {my_aws_helpers-3.1.0.dev4 → my_aws_helpers-6.0.4}/my_aws_helpers/logging.py +0 -0
  22. {my_aws_helpers-3.1.0.dev4 → my_aws_helpers-6.0.4}/my_aws_helpers/prompts/__init__.py +0 -0
  23. {my_aws_helpers-3.1.0.dev4 → my_aws_helpers-6.0.4}/my_aws_helpers/prompts/markdown_system_prompt.txt +0 -0
  24. {my_aws_helpers-3.1.0.dev4 → my_aws_helpers-6.0.4}/my_aws_helpers/prompts/transactions_headers_prompt.txt +0 -0
  25. {my_aws_helpers-3.1.0.dev4 → my_aws_helpers-6.0.4}/my_aws_helpers/prompts/transactions_headers_prompt_v2.txt +0 -0
  26. {my_aws_helpers-3.1.0.dev4 → my_aws_helpers-6.0.4}/my_aws_helpers/prompts/transactions_prompt.txt +0 -0
  27. {my_aws_helpers-3.1.0.dev4 → my_aws_helpers-6.0.4}/my_aws_helpers/sfn.py +0 -0
  28. {my_aws_helpers-3.1.0.dev4 → my_aws_helpers-6.0.4}/my_aws_helpers.egg-info/dependency_links.txt +0 -0
  29. {my_aws_helpers-3.1.0.dev4 → my_aws_helpers-6.0.4}/my_aws_helpers.egg-info/top_level.txt +0 -0
  30. {my_aws_helpers-3.1.0.dev4 → my_aws_helpers-6.0.4}/my_aws_helpers.egg-info/zip-safe +0 -0
  31. {my_aws_helpers-3.1.0.dev4 → my_aws_helpers-6.0.4}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: my_aws_helpers
3
- Version: 3.1.0.dev4
3
+ Version: 6.0.4
4
4
  Summary: AWS Helpers
5
5
  Home-page: https://github.com/JarrodMccarthy/aws_helpers.git
6
6
  Author: Jarrod McCarthy
@@ -11,6 +11,8 @@ Classifier: Programming Language :: Python
11
11
  Classifier: Programming Language :: Python :: 3.8
12
12
  Classifier: Programming Language :: Python :: 3.9
13
13
  Classifier: Programming Language :: Python :: 3.10
14
+ Provides-Extra: all
15
+ Provides-Extra: bedrock
14
16
 
15
17
  UNKNOWN
16
18
 
@@ -10,7 +10,7 @@ class API:
10
10
  if isinstance(obj, datetime) or isinstance(obj, date):
11
11
  return obj.isoformat()
12
12
  return obj
13
-
13
+
14
14
  def response_serialiser(response: Any):
15
15
  if isinstance(response, list):
16
16
  return [API.response_serialiser(obj) for obj in response]
@@ -18,7 +18,7 @@ class API:
18
18
  for k, v in response.items():
19
19
  response[k] = API.response_serialiser(v)
20
20
  return response
21
- return API._serialise(response)
21
+ return API._serialise(response)
22
22
 
23
23
  def response(code: int, body: Optional[str] = None):
24
24
  return {
@@ -0,0 +1,362 @@
1
+ from __future__ import annotations
2
+ import boto3
3
+ from botocore.config import Config
4
+ import json
5
+ import time
6
+ import os
7
+ import io
8
+ import re
9
+ from copy import copy
10
+ from typing import Optional, List, Dict
11
+ from enum import Enum
12
+ import pymupdf
13
+ import concurrent.futures
14
+ from dataclasses import dataclass
15
+ from my_aws_helpers.s3 import S3Serialiser, BaseS3Object, BaseS3Queries, S3, S3Location
16
+
17
+ from my_aws_helpers.logging import select_powertools_logger
18
+
19
+
20
+ logger = select_powertools_logger("aws-helpers-bedrock")
21
+
22
+
23
# Image formats accepted in a Bedrock ``converse`` image content block. Built with
# the functional Enum API; the dict literal fixes both member order and values, and
# ``type=str`` keeps the str mixin so members compare equal to plain strings.
ImageType = Enum(
    "ImageType",
    {"gif": "gif", "jpeg": "jpeg", "png": "png", "webp": "webp"},
    type=str,
    module=__name__,
)
ImageType.__doc__ = "Image formats usable as the ``format`` of an image content block."
28
+
29
+
30
class PromptType(str, Enum):
    """File names of the system prompts bundled in the package's ``prompts`` directory.

    Values are file names (not prompt text); ``Bedrock._get_prompt`` joins them onto
    the ``prompts`` directory path and reads the file.
    """

    # Prompt used to extract transaction table headers from the first pages.
    transaction_headers = "transactions_headers_prompt_v2.txt"
    # Prompt used to extract transactions from each page.
    transactions = "transactions_prompt.txt"
    # Prompt for markdown output.
    markdown = "markdown_system_prompt.txt"
35
+
36
+
37
@dataclass
class TokenUsage:
    """Token counts taken from the ``usage`` section of a Bedrock ``converse`` response."""

    input_tokens: int
    output_tokens: int
    total_tokens: int

    @classmethod
    def from_dict(cls, data: Dict[str, int]) -> TokenUsage:
        """Build from a camelCase usage mapping; any counter that is absent becomes 0."""
        camel_case_keys = {
            "input_tokens": "inputTokens",
            "output_tokens": "outputTokens",
            "total_tokens": "totalTokens",
        }
        counts = {attr: data.get(key, 0) for attr, key in camel_case_keys.items()}
        return cls(**counts)
50
+
51
+
52
@dataclass
class OCRResult(BaseS3Object):
    """Parsed model output for one OCR or image-analysis request."""

    # JSON decoded from the model reply. NOTE(review): annotated as a list of string
    # maps, but Bedrock.get_ocr_result calls ``.get("headers")`` on it, so for some
    # prompts it is a dict — confirm the intended type.
    content: List[Dict[str, str]]
    # Token counts from the ``usage`` section of the ``converse`` response.
    token_usage: TokenUsage
    # Page index this result belongs to (0 when the producer did not supply one).
    page_number: int

    @classmethod
    def from_dict(cls, data: Dict) -> OCRResult:
        """Build from an in-memory dict whose ``token_usage`` is Bedrock's camelCase usage.

        Missing keys default to ``[]`` content, zero token counts and page 0.
        """
        return cls(
            content=data.get("content", []),
            token_usage=TokenUsage.from_dict(data.get("token_usage", {})),
            page_number=data.get("page_number", 0),
        )

    @classmethod
    def from_s3_representation(cls, obj: dict) -> OCRResult:
        """Rebuild an instance from a dict written by ``to_s3_representation``.

        ``to_s3_representation`` stores ``token_usage`` under the dataclass field
        names (snake_case), whereas ``TokenUsage.from_dict`` expects Bedrock's
        camelCase keys, so both spellings are accepted here. ``obj`` itself is not
        modified; all of its keys are forwarded to the constructor.
        """
        data = dict(obj)
        raw_usage = data.get("token_usage") or {}
        if isinstance(raw_usage, TokenUsage):
            token_usage = raw_usage
        else:
            token_usage = TokenUsage(
                input_tokens=raw_usage.get(
                    "input_tokens", raw_usage.get("inputTokens", 0)
                ),
                output_tokens=raw_usage.get(
                    "output_tokens", raw_usage.get("outputTokens", 0)
                ),
                total_tokens=raw_usage.get(
                    "total_tokens", raw_usage.get("totalTokens", 0)
                ),
            )
        # Assign the TokenUsage itself (not a 1-tuple) before constructing.
        data["token_usage"] = token_usage
        return cls(**data)

    def to_s3_representation(self) -> dict:
        """Return a serialised dict for S3; ``token_usage`` becomes a plain dict of its fields."""
        obj = copy(vars(self))
        # Copy the nested __dict__ so the serialiser can never alter the live TokenUsage.
        obj["token_usage"] = S3Serialiser.object_serialiser(
            obj=dict(vars(obj["token_usage"]))
        )
        return S3Serialiser.object_serialiser(obj=obj)

    def get_save_location(self, bucket_name: str) -> S3Location:
        """Not implemented: always returns ``None``.

        NOTE(review): presumably required by ``BaseS3Object``; callers currently pass an
        explicit location to ``OCRResultQueries.save_ocr_result_to_s3`` instead.
        """
        return None
80
+
81
+
82
class OCRResultQueries(BaseS3Queries):
    """Read/write helpers for ``OCRResult`` objects stored as dicts in an S3 bucket."""

    def __init__(self, s3_client: S3, bucket_name: str):
        super().__init__(s3_client=s3_client, bucket_name=bucket_name)

    def save_ocr_result_to_s3(
        self, ocr_result: OCRResult, save_location: S3Location
    ) -> Optional[S3Location]:
        """Serialise ``ocr_result`` and write it to ``save_location``.

        Returns:
            Whatever ``save_dict_to_s3`` returns, or ``None`` if serialisation or the
            upload raised (the error is logged, not propagated).
        """
        try:
            obj = ocr_result.to_s3_representation()
            return self.s3_client.save_dict_to_s3(
                content=obj,
                s3_location=save_location,
            )
        except Exception as e:
            logger.exception(f"Failed to save ocr result to s3 due to {e}")
            return None

    def _concurrent_s3_read(
        self, locations: List[S3Location], max_workers: int = 10
    ) -> List[dict]:
        """Read the dict stored at each location in parallel, preserving input order.

        Returns the raw dicts (callers convert them with
        ``OCRResult.from_s3_representation``); ``None`` reads are dropped. An
        exception from any read propagates to the caller.
        """
        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
            objects = list(
                executor.map(
                    lambda loc: self.s3_client.read_dict_from_s3(s3_location=loc),
                    locations,
                )
            )
        return [obj for obj in objects if obj is not None]

    def get_ocr_results_by_key_prefix(self, prefix: str) -> List[OCRResult]:
        """Load every OCR result stored under ``prefix`` in this queries' bucket."""
        locations = self.s3_client.list_objects_by_prefix(
            bucket_name=self.bucket_name, prefix=prefix
        )
        objects = self._concurrent_s3_read(locations=locations)
        return [OCRResult.from_s3_representation(obj=obj) for obj in objects]
123
+
124
+
125
class Bedrock:
    """Wrapper around the Bedrock Runtime ``converse`` API for PDF OCR and image analysis.

    The runtime client is always created in ``ap-southeast-2``, whatever region the
    underlying boto3 session was configured with.
    """

    def __init__(
        self,
        model_id: str = "apac.anthropic.claude-3-5-sonnet-20241022-v2:0",  # anthropic.claude-sonnet-4-20250514-v1:0
        logger=None,
        sleep_time: float = 1.0,
    ):
        """Create the Bedrock Runtime client.

        Args:
            model_id: Model / inference-profile id passed to every ``converse`` call.
            logger: Logger used by this instance. When ``None`` (the default) the
                module-level ``aws-helpers-bedrock`` logger is used; storing ``None``
                would make every ``self.logger`` call raise ``AttributeError``.
            sleep_time: Seconds slept between ``converse`` retries and before each
                page is submitted in ``_parallel_ocr``.
        """
        self.session = Bedrock._set_session_params()
        # The ``logger`` parameter shadows the module-level logger inside this method,
        # so the default is fetched through a helper that sees the module global.
        self.logger = logger if logger is not None else Bedrock._default_logger()
        region_name = "ap-southeast-2"
        if self.session is None:
            self.session = boto3.Session(region_name=region_name)
        self.sleep_time = sleep_time

        custom_config = Config(
            retries={
                "max_attempts": 2,  # Total attempts = 1 initial + 1 retry
                "mode": "standard",  # or 'adaptive'
            }
        )
        self.client = self.session.client(
            "bedrock-runtime", region_name=region_name, config=custom_config
        )
        self.model_id = model_id

    @staticmethod
    def _default_logger():
        """Return the module-level logger (shadowed inside ``__init__`` by its parameter)."""
        return logger

    @staticmethod
    def _set_session_params():
        """Return a session built from explicit environment credentials, or ``None``.

        ``AWS_ACCESS_KEY_ID``, ``AWS_SECRET_ACCESS_KEY``, ``AWS_SESSION_TOKEN`` and
        ``AWS_DEFAULT_REGION`` must all be set. If any is missing, or session creation
        fails, ``None`` is returned and ``__init__`` falls back to ``boto3.Session``.
        """
        try:
            aws_access_key_id = os.environ["AWS_ACCESS_KEY_ID"]
            aws_secret_access_key = os.environ["AWS_SECRET_ACCESS_KEY"]
            aws_session_token = os.environ["AWS_SESSION_TOKEN"]
            region_name = os.environ["AWS_DEFAULT_REGION"]
            return boto3.Session(
                aws_access_key_id=aws_access_key_id,
                aws_secret_access_key=aws_secret_access_key,
                aws_session_token=aws_session_token,
                region_name=region_name,
            )
        except Exception as e:
            # A missing variable is an expected fallback path: log it instead of
            # printing the bare exception (e.g. "'AWS_SESSION_TOKEN'") to stdout.
            logger.debug(
                f"Explicit AWS credentials unavailable, using default session: {e}"
            )
            return None

    @staticmethod
    def extract_json_from_markdown(text: str):
        """
        Extracts the JSON object from a string that may be wrapped in ```json ... ``` code block.

        The span from the first ``{`` to the last ``}`` is parsed, so surrounding prose
        or code fences are ignored. Only objects are found, not top-level arrays.

        Raises:
            ValueError: no ``{...}`` span is present (``json.JSONDecodeError``, also a
                ``ValueError``, if the span is not valid JSON).
        """
        match = re.search(r"\{.*\}", text, re.DOTALL)
        if match:
            return json.loads(match.group(0))
        raise ValueError("No JSON object found in the text")

    def _get_prompt(self, prompt_type: str) -> Optional[str]:
        """Read a system prompt bundled in the package's ``prompts`` directory.

        Args:
            prompt_type: A ``PromptType`` member or its file-name value.

        Returns:
            The prompt text, or ``None`` if the file could not be read (logged).

        Raises:
            ValueError: ``prompt_type`` is not a ``PromptType`` value (a subclass of
                ``Exception``, so existing ``except Exception`` handlers still catch it).
        """
        # list() membership is used deliberately: str-mixin Enum members hash by name,
        # so a set lookup would not match plain strings against members.
        if prompt_type not in list(PromptType):
            raise ValueError(f"Error: Invalid prompt type {prompt_type!r}")

        path = os.path.join(os.path.dirname(__file__), "prompts", prompt_type)
        try:
            with open(path, "r", encoding="utf-8") as f:
                return f.read()
        except Exception as e:
            self.logger.exception(f"Failed to get {prompt_type} prompt due to {e}")
            return None

    def _ocr(
        self, prompt: str, image_bytes: bytes, page_number: Optional[int] = 0
    ) -> Optional[OCRResult]:
        """OCR one PNG page image using ``prompt`` as the system prompt.

        ``converse`` is attempted up to 3 times, sleeping ``self.sleep_time`` after each
        failure; after the last failed attempt the original exception is re-raised.

        Returns:
            An ``OCRResult`` whose ``content`` is ``json.loads`` of the first text block
            of the reply (this method never actually returns ``None``).

        Raises:
            The last ``converse`` exception, or ``json.JSONDecodeError`` if the reply
            text is not JSON.
        """
        system_prompt = [{"text": prompt}]
        message = [
            {
                "role": "user",
                "content": [
                    {
                        "image": {
                            "format": "png",
                            "source": {
                                "bytes": image_bytes,
                            },
                        }
                    }
                ],
            }
        ]
        retries = 3
        response = None
        for i in range(retries):
            self.logger.info(f"Attempt number {i} for {self.model_id} converse")
            try:
                response = self.client.converse(
                    modelId=self.model_id, messages=message, system=system_prompt
                )
                if response["ResponseMetadata"]["HTTPStatusCode"] == 200:
                    break
            except Exception as e:
                self.logger.exception(f"Error during conversation due to {e}")
                # ``retries`` is an int; compare against it directly (len() raised TypeError).
                if i >= retries - 1:
                    raise
                time.sleep(self.sleep_time)
                continue

        result = {}
        result["content"] = json.loads(
            response["output"]["message"]["content"][0]["text"]
        )
        result["token_usage"] = response["usage"]
        result["page_number"] = page_number
        return OCRResult.from_dict(data=result)

    def _parallel_ocr(
        self,
        image_bytes_list: List[bytes],
        prompt: str,
        max_workers: int = 10,
        page_numbers: Optional[List[int]] = None,
    ):
        """OCR every image concurrently and return results in submission order.

        Args:
            image_bytes_list: PNG page images.
            prompt: System prompt for every page.
            max_workers: Thread-pool size.
            page_numbers: Page number recorded on each result, parallel to
                ``image_bytes_list``; defaults to each image's index in the list.

        Raises:
            ValueError: ``page_numbers`` and ``image_bytes_list`` differ in length.
            Any exception raised by ``_ocr`` for a page.
        """
        if page_numbers is None:
            page_numbers = list(range(len(image_bytes_list)))
        if len(page_numbers) != len(image_bytes_list):
            raise ValueError("page_numbers must have one entry per image")

        execution_futures = []
        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
            for page_number, img in zip(page_numbers, image_bytes_list):
                self.logger.info(f"Starting OCR for page: {page_number}")
                time.sleep(self.sleep_time)  # Stagger start time
                execution_futures.append(
                    executor.submit(
                        self._ocr,
                        prompt=prompt,
                        image_bytes=img,
                        page_number=page_number,
                    )
                )

        # Collect in submission order; each future's result is fetched once.
        results = []
        for future in execution_futures:
            result = future.result()
            if result is not None:
                results.append(result)
        return results

    def _render_pages(self, pdf_bytes: io.BytesIO, zoom: int):
        """Render each PDF page to PNG bytes; returns (images, 0-based page numbers).

        Pages whose pixmap cannot be rendered are logged and skipped, and the returned
        page numbers keep track of which pages survived. The document is closed on exit.
        """
        image_bytes_list: List[bytes] = []
        page_numbers: List[int] = []
        with pymupdf.open(stream=pdf_bytes, filetype="pdf") as document:
            for i, page in enumerate(document):
                try:
                    image_bytes = page.get_pixmap(
                        matrix=pymupdf.Matrix(zoom, zoom)
                    ).tobytes("png")
                except Exception as e:
                    self.logger.error(f"Could not get pix map for page {i} due to {e}")
                    continue
                image_bytes_list.append(image_bytes)
                page_numbers.append(i)
        return image_bytes_list, page_numbers

    def _find_table_headers(self, image_bytes_list: List[bytes]):
        """Ask the model for table headers on the first image, then the second if needed.

        Only attempted when there are at least two images. Returns
        ``(header_ocr_result, skip_page_zero)``. ``skip_page_zero`` is True once the
        first image yields no headers. If neither image yields headers,
        ``header_ocr_result`` is still the second image's result.
        """
        skip_page_zero = False
        header_ocr_result = None
        if len(image_bytes_list) < 2:
            return header_ocr_result, skip_page_zero

        headers_prompt = self._get_prompt(
            prompt_type=PromptType.transaction_headers.value
        )
        for i in range(2):
            # try to get headers from the first or second page
            header_ocr_result = self._ocr(
                prompt=headers_prompt, image_bytes=image_bytes_list[i]
            )
            if header_ocr_result is None:
                self.logger.info(
                    f"No ocr result returned when getting headers {PromptType.transaction_headers.value}"
                )
                skip_page_zero = True
                continue
            content = header_ocr_result.content
            headers = content.get("headers") if isinstance(content, dict) else None
            # ``not headers`` covers both None and empty (len(None) used to raise).
            if not headers:
                skip_page_zero = True
                continue
            break
        return header_ocr_result, skip_page_zero

    def get_ocr_result(
        self,
        pdf_bytes: io.BytesIO,
        prompt_type: str,
        zoom: int = 7,
    ) -> List[OCRResult]:
        """OCR every page of a PDF with the ``prompt_type`` system prompt.

        For multi-page documents, table headers are extracted first and substituted
        for ``#### TABLE HEADERS ####`` in the prompt. If the first page has no headers
        it is excluded from the transaction OCR. Each result's ``page_number`` is the
        0-based index of its page in the PDF.

        Returns:
            One ``OCRResult`` per processed page, or ``[]`` if anything fails (logged).
        """
        try:
            self.logger.info("Getting OCR Results")
            image_bytes_list, page_numbers = self._render_pages(
                pdf_bytes=pdf_bytes, zoom=zoom
            )
            header_ocr_result, skip_page_zero = self._find_table_headers(
                image_bytes_list
            )

            transactions_prompt = self._get_prompt(prompt_type=prompt_type)
            if header_ocr_result is not None:
                transactions_prompt = transactions_prompt.replace(
                    "#### TABLE HEADERS ####", json.dumps(header_ocr_result.content)
                )
            self.logger.info("Got Prompt")

            if skip_page_zero:
                # page zero often has account summary info
                image_bytes_list = image_bytes_list[1:]
                page_numbers = page_numbers[1:]
            return self._parallel_ocr(
                image_bytes_list=image_bytes_list,
                prompt=transactions_prompt,
                page_numbers=page_numbers,
            )
        except Exception as e:
            self.logger.exception(e)
            return []

    def _get_image_block(self, image: bytes, image_content_type: ImageType) -> dict:
        """Build one ``converse`` image content block for raw image bytes."""
        return {
            "image": {
                "format": image_content_type,
                "source": {
                    "bytes": image,
                },
            }
        }

    def image_analysis(
        self, images: List[bytes], prompt: str, image_content_type: ImageType
    ) -> OCRResult:
        """Send all ``images`` in one user message with ``prompt`` as the system prompt.

        The first JSON object in the reply becomes ``content``. No retries are made
        beyond the client's own retry config, and ``page_number`` is left at 0.

        Raises:
            ValueError: the reply contains no JSON object.
        """
        system_prompt = [{"text": prompt}]
        message = [
            {
                "role": "user",
                "content": [
                    self._get_image_block(
                        image=image, image_content_type=image_content_type
                    )
                    for image in images
                ],
            }
        ]
        response = self.client.converse(
            modelId=self.model_id, messages=message, system=system_prompt
        )

        result = {}
        result["content"] = Bedrock.extract_json_from_markdown(
            text=response["output"]["message"]["content"][0]["text"]
        )
        result["token_usage"] = response["usage"]
        return OCRResult.from_dict(data=result)
@@ -0,0 +1,141 @@
1
+ import os
2
+ import json
3
+ from typing import Optional
4
+ import boto3
5
+ import urllib.request
6
+ from jose import jwk, jwt
7
+ from jose.utils import base64url_decode
8
+ from my_aws_helpers.logging import select_powertools_logger
9
+
10
+ logger = select_powertools_logger("aws-helpers-cognito")
11
+
12
+
13
class Cognito:
    """Helpers for one Cognito user pool: JWT validation, sign-up/login calls and
    allow/deny IAM policy documents."""

    client: boto3.client

    def __init__(
        self,
        client: Optional[boto3.client] = None,
        user_pool_id: Optional[str] = None,
        app_client_id: Optional[str] = None,
    ):
        """
        Args:
            client: ``cognito-idp`` client; one is created in ``self.region`` if omitted.
            user_pool_id: Pool id; falls back to the ``COGNITO_USER_POOL_ID`` env var
                (``KeyError`` if neither is given).
            app_client_id: Optional. When set, ``validate_token`` also requires the
                token's audience (``aud`` for ID tokens, ``client_id`` for access
                tokens) to equal it.
        """
        self.cognito_user_pool_id = (
            user_pool_id if user_pool_id else os.environ["COGNITO_USER_POOL_ID"]
        )
        self.region = os.environ.get("AWS_DEFAULT_REGION", "ap-southeast-2")
        self.app_client_id = app_client_id
        # Cached JWKS "keys" list; refreshed when a token's kid is not found.
        self._jwks = None
        self.client = client if client else self._get_client(region=self.region)

    def _get_client(self, region: str) -> boto3.client:
        return boto3.client("cognito-idp", region_name=region)

    def _verify_signature(self, token: Optional[str] = None) -> None:
        """Raise unless ``token``'s signature verifies with the pool key named by its ``kid``.

        Calling with no token is a no-op.
        """
        if token is None:
            return
        key_id = jwt.get_unverified_headers(token)["kid"]
        public_key = jwk.construct(self._get_signing_key(key_id))
        message, encoded_signature = token.rsplit(".", 1)
        decoded_signature = base64url_decode(encoded_signature.encode("utf-8"))
        if not public_key.verify(message.encode("utf-8"), decoded_signature):
            raise ValueError("Signature verification failed")

    def _verify_audience(self, claims: Optional[dict] = None) -> None:
        """Raise if an app client id is configured and the token's audience differs.

        ID tokens carry the audience in ``aud``; access tokens carry it in ``client_id``.
        A no-op when no claims are given or no ``app_client_id`` is configured.
        """
        expected = getattr(self, "app_client_id", None)
        if claims is None or not expected:
            return
        if claims.get("token_use") == "id":
            audience = claims.get("aud")
        else:
            audience = claims.get("client_id")
        if audience != expected:
            raise ValueError("Token audience does not match app client id")

    def _verify_token_use(self, claims: Optional[dict] = None) -> None:
        """Raise unless ``token_use`` is ``"id"`` or ``"access"``; no-op without claims."""
        if claims is None:
            return
        if claims.get("token_use") not in ("id", "access"):
            raise ValueError("Token token_use must be 'id' or 'access'")

    def _verify_claims(self, claims: dict) -> None:
        """Raise if the (signature-verified) claims are expired or not for this pool.

        Checks ``exp`` (required) is in the future, ``iss`` equals this pool's issuer,
        then ``token_use`` and, if configured, the audience.
        """
        import time

        if time.time() >= float(claims["exp"]):
            raise ValueError("Token has expired")
        if claims.get("iss") != self._get_issuer():
            raise ValueError("Token issuer does not match user pool")
        self._verify_token_use(claims)
        self._verify_audience(claims)

    def validate_token(self, token: str) -> bool:
        """Return True if ``token`` is a valid, unexpired JWT issued by this user pool.

        Any text up to the last space is dropped when "Bearer" appears in the value.
        The signature is verified against the pool's JWKS, then expiry, issuer,
        ``token_use`` and (if ``app_client_id`` was given) audience are checked.
        Failures are logged and reported as False.
        """
        try:
            if "Bearer" in token:
                token = token.split(" ")[-1]
            self._verify_signature(token)
            # Reading claims without verification is safe: the signature was checked above.
            claims = jwt.get_unverified_claims(token)
            self._verify_claims(claims)
            return True
        except Exception as e:
            logger.exception(f"Failed to validate token due to {e}")
            return False

    def sign_up(self, username: str, password: str, app_client_id: str) -> Optional[dict]:
        """Register ``username`` (also stored as the ``email`` attribute).

        Returns the Cognito response, or ``None`` on error (logged).
        """
        try:
            response = self.client.sign_up(
                ClientId=app_client_id,
                Username=username,
                Password=password,
                UserAttributes=[
                    {"Name": "email", "Value": username},
                ],
            )
            return response
        except Exception as e:
            logger.exception(f"Failed to sign up due to {e}")
            return None

    def confirm_sign_up(
        self,
        username: str,
        client_id: str,
        confirmation_code: str,
    ) -> Optional[dict]:
        """Confirm a sign-up with the emailed code; ``None`` on error (logged)."""
        try:
            response = self.client.confirm_sign_up(
                ClientId=client_id,
                Username=username,
                ConfirmationCode=confirmation_code,
            )
            return response
        except Exception as e:
            logger.exception(
                f"Failed to confirm sign up username {username} due to {e}"
            )
            return None

    def admin_confirm_sign_up(
        self,
        username: str,
        user_pool_id: Optional[str] = None,
    ) -> Optional[dict]:
        """Confirm a user as an admin; ``None`` on error (logged).

        Args:
            user_pool_id: Pool to use; defaults to this instance's user pool.
        """
        pool_id = user_pool_id if user_pool_id else self.cognito_user_pool_id
        try:
            response = self.client.admin_confirm_sign_up(
                UserPoolId=pool_id, Username=username
            )
            return response
        except Exception as e:
            logger.exception(
                f"Failed to confirm admin sign up username {username} due to {e}"
            )
            return None

    def refresh_token(self, refresh_token: str, app_client_id: str):
        """Exchange a refresh token via ``REFRESH_TOKEN_AUTH``; errors propagate."""
        response = self.client.initiate_auth(
            ClientId=app_client_id,
            AuthFlow="REFRESH_TOKEN_AUTH",
            AuthParameters={"REFRESH_TOKEN": refresh_token},
        )
        return response

    def login(self, username: str, password: str, app_client_id: str):
        """Authenticate via ``USER_PASSWORD_AUTH``; errors propagate."""
        response = self.client.initiate_auth(
            ClientId=app_client_id,
            AuthFlow="USER_PASSWORD_AUTH",
            AuthParameters={"USERNAME": username, "PASSWORD": password},
        )
        return response

    def _get_issuer(self) -> str:
        return f"https://cognito-idp.{self.region}.amazonaws.com/{self.cognito_user_pool_id}"

    def _get_keys(self):
        """Fetch the pool's JWKS ``keys`` list (10 s timeout so a stalled call cannot hang)."""
        issuer = self._get_issuer()
        keys_url = f"{issuer}/.well-known/jwks.json"
        with urllib.request.urlopen(keys_url, timeout=10) as f:
            response = f.read()
        keys = json.loads(response.decode("utf-8"))["keys"]
        return keys

    def _get_signing_key(self, key_id: str) -> dict:
        """Return the JWKS entry for ``key_id``, refetching once if it is not cached.

        Raises:
            KeyError: no key with that ``kid`` exists in the pool's JWKS.
        """
        cached = getattr(self, "_jwks", None)
        if cached:
            match = next((k for k in cached if k.get("kid") == key_id), None)
            if match is not None:
                return match
        keys = self._get_keys()
        self._jwks = keys
        match = next((k for k in keys if k.get("kid") == key_id), None)
        if match is None:
            raise KeyError(f"No signing key with kid {key_id!r} in user pool JWKS")
        return match

    @staticmethod
    def get_policy(allow: bool, method_arn: str = "*") -> dict:
        """Build an allow/deny policy for principal ``authenticated-user``.

        NOTE(review): ``method_arn`` is currently unused — the statement always targets
        ``"*"`` for every action (presumably so one cached decision covers all routes;
        confirm before scoping it to ``method_arn``).
        """
        effect = "Allow" if allow else "Deny"
        return {
            "principalId": "authenticated-user",
            "policyDocument": {
                "Version": "2012-10-17",
                "Statement": [{"Action": "*", "Effect": effect, "Resource": "*"}],
            },
        }
@@ -2,12 +2,13 @@ from typing import List, Any, Optional
2
2
  from datetime import datetime, date
3
3
  import boto3
4
4
  from abc import ABC, abstractmethod
5
- from decimal import Decimal, Context
5
+ from decimal import Decimal, Context
6
6
 
7
7
  from my_aws_helpers.logging import select_powertools_logger
8
8
 
9
9
  logger = select_powertools_logger("aws-helpers-dynamo")
10
10
 
11
+
11
12
  class MetaData:
12
13
  """
13
14
  This class is a convenience class,
@@ -76,7 +77,6 @@ class BaseTableObject:
76
77
 
77
78
  def __init__(self) -> None:
78
79
  pass
79
-
80
80
 
81
81
 
82
82
  class DynamoSerialiser:
@@ -86,10 +86,10 @@ class DynamoSerialiser:
86
86
  if isinstance(obj, datetime) or isinstance(obj, date):
87
87
  return obj.isoformat()
88
88
  if isinstance(obj, float):
89
- ctx = Context(prec = 38)
89
+ ctx = Context(prec=38)
90
90
  return ctx.create_decimal_from_float(obj)
91
91
  return obj
92
-
92
+
93
93
  @staticmethod
94
94
  def object_serialiser(obj: Any):
95
95
  if isinstance(obj, list):
@@ -98,6 +98,7 @@ class DynamoSerialiser:
98
98
  return {k: DynamoSerialiser.object_serialiser(v) for k, v in obj.items()}
99
99
  return DynamoSerialiser._serialise(obj=obj)
100
100
 
101
+
101
102
  class Dynamo:
102
103
  table: boto3.resource
103
104
 
@@ -112,7 +113,7 @@ class Dynamo:
112
113
  return self.table.get_item(Item=item)
113
114
 
114
115
  def delete_item(self, item: dict):
115
- return self.table.delete_item(Item=item)
116
+ return self.table.delete_item(Key=item)
116
117
 
117
118
  def batch_put(self, items: List[dict]) -> None:
118
119
  with self.table.batch_writer() as batch:
@@ -130,14 +131,16 @@ class Dynamo:
130
131
  response = self.table.scan()
131
132
  items: List = response["Items"]
132
133
  while response.get("LastEvaluatedKey") is not None:
133
- response = self.table.scan(ExclusiveStartKey = response["LastEvaluatedKey"])
134
+ response = self.table.scan(ExclusiveStartKey=response["LastEvaluatedKey"])
134
135
  if response.get("Items") is not None:
135
136
  items.extend(response["Items"])
136
137
  if response.get("LastEvaluatedKey") is None:
137
138
  break
138
139
  return items
139
140
 
140
- def delete_table_items(self, partition_key_name: str = "pk", sort_key_name: str = "sk") -> bool:
141
+ def delete_table_items(
142
+ self, partition_key_name: str = "pk", sort_key_name: str = "sk"
143
+ ) -> bool:
141
144
  try:
142
145
  items = self._deep_scan()
143
146
  delete_repr_items = [
@@ -147,10 +150,10 @@ class Dynamo:
147
150
  }
148
151
  for item in items
149
152
  ]
150
- self.batch_delete(items = delete_repr_items)
153
+ self.batch_delete(items=delete_repr_items)
151
154
  return True
152
155
  except Exception as e:
153
- logger.exception(f'Failed to delete table items due to {e}')
156
+ logger.exception(f"Failed to delete table items due to {e}")
154
157
  return False
155
158
 
156
159
  def to_dynamo_representation(obj: dict):
@@ -177,6 +180,7 @@ def _datatype_map(value: Any):
177
180
  return new_obj
178
181
  return value
179
182
 
183
+
180
184
  class BaseQueries(ABC):
181
185
  table_name: str
182
186
 
@@ -186,7 +190,7 @@ class BaseQueries(ABC):
186
190
 
187
191
  def _get_client(self):
188
192
  return Dynamo(table_name=self.table_name)
189
-
193
+
190
194
  def _iterative_query(self, query_kwargs: dict) -> List[dict]:
191
195
  results = list()
192
196
  last_evaluated_key = "not none"
@@ -198,4 +202,4 @@ class BaseQueries(ABC):
198
202
  results += result["Items"]
199
203
  last_evaluated_key = result.get("LastEvaluatedKey")
200
204
  exclusive_start_key = last_evaluated_key
201
- return results
205
+ return results