clear-skies-aws 1.9.18__py3-none-any.whl → 1.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,540 @@
1
+ import base64
2
+ import json
3
+ import re
4
+ import unittest
5
+ from decimal import Decimal
6
+ from unittest.mock import MagicMock, call, patch
7
+
8
+ from boto3.session import Session as Boto3Session
9
+ from botocore.exceptions import ClientError
10
+ from clearskies import Model
11
+ from clearskies.autodoc.schema import String as AutoDocString
12
+
13
+ from clearskies_aws.backends.dynamo_db_parti_ql_backend import (
14
+ DynamoDBPartiQLBackend,
15
+ DynamoDBPartiQLCursor,
16
+ )
17
+
18
+
19
@patch("clearskies_aws.backends.dynamo_db_parti_ql_backend.logger")
class TestDynamoDBPartiQLBackend(unittest.TestCase):
    """Tests for DynamoDBPartiQLBackend using a mocked boto3 session and client.

    The class-level ``patch`` swaps out the backend module's ``logger``; each
    ``test_*`` method receives that mock as ``mock_logger_arg``.
    """

    # Minimal table description: a base table keyed on "id" with no GSI entry.
    _ID_ONLY_KEY_SCHEMA = [{"AttributeName": "id", "KeyType": "HASH"}]

    def setUp(self):
        """Create the backend under test together with its mocked collaborators."""
        client = MagicMock()
        session = MagicMock(spec=Boto3Session)
        session.client.return_value = client
        self.mock_boto3_session = session
        self.mock_dynamodb_client = client

        self.cursor_under_test = DynamoDBPartiQLCursor(session)
        self.backend = DynamoDBPartiQLBackend(self.cursor_under_test)

        model = MagicMock(spec=Model)
        model.table_name = "my_test_table"
        model.id_column_name = "id"
        model.schema = MagicMock()
        model.schema.return_value.indexes = {}
        self.mock_model = model

        # Stub the table lookup so the tests control the key schema directly.
        self.backend._get_table_description = MagicMock(
            return_value={
                "KeySchema": [{"AttributeName": "id", "KeyType": "HASH"}],
                "GlobalSecondaryIndexes": [],
            }
        )

    def _get_base_config(self, table_name="test_table", **overrides):
        """Return a default query configuration with ``overrides`` applied on top."""
        return {
            "table_name": table_name,
            "wheres": [],
            "sorts": [],
            "limit": None,
            "pagination": {},
            "model_columns": [],
            "select_all": False,
            "selects": [],
            "group_by_column": None,
            "joins": [],
            **overrides,
        }

    def test_as_sql_simple_select_all(self, mock_logger_arg):
        """select_all alone produces a plain 'SELECT *' with no parameters or limit."""
        config = self._get_base_config(
            table_name="users", select_all=True, _chosen_index_name=None
        )
        sql, parameters, limit, next_token = self.backend.as_sql(config)
        self.assertEqual('SELECT * FROM "users"', sql)
        self.assertEqual([], parameters)
        self.assertIsNone(limit)
        self.assertEqual(next_token, config.get("pagination", {}).get("next_token"))

    def test_as_sql_select_specific_columns(self, mock_logger_arg):
        """Explicit selects are rendered as a quoted column list."""
        config = self._get_base_config(
            table_name="products", selects=["name", "price"], _chosen_index_name=None
        )
        sql, parameters, _limit, _next_token = self.backend.as_sql(config)
        self.assertEqual('SELECT "name", "price" FROM "products"', sql)
        self.assertEqual([], parameters)

    def test_as_sql_select_all_and_specific_columns_uses_specific(
        self, mock_logger_arg
    ):
        """When select_all and selects are both given, selects win and a warning is logged."""
        config = self._get_base_config(
            table_name="inventory",
            select_all=True,
            selects=["item_id", "stock_count"],
            _chosen_index_name=None,
        )
        sql, _params, _limit, _next_token = self.backend.as_sql(config)
        self.assertEqual('SELECT "item_id", "stock_count" FROM "inventory"', sql)
        mock_logger_arg.warning.assert_any_call(
            "Both 'select_all=True' and specific 'selects' were provided. Using specific 'selects'."
        )

    def test_as_sql_default_select_if_no_select_all_or_selects(self, mock_logger_arg):
        """With neither select_all nor selects, the statement falls back to 'SELECT *'."""
        config = self._get_base_config(table_name="orders", _chosen_index_name=None)
        sql, parameters, _limit, _next_token = self.backend.as_sql(config)
        self.assertEqual('SELECT * FROM "orders"', sql)
        self.assertEqual([], parameters)

    def test_as_sql_with_wheres(self, mock_logger_arg):
        """WHERE conditions are ANDed together and their values become typed parameters."""
        conditions = [
            {"column": "city", "operator": "=", "values": ["New York"]},
            {"column": "age", "operator": ">", "values": [30]},
        ]
        config = self._get_base_config(
            table_name="customers",
            select_all=True,
            wheres=conditions,
            _chosen_index_name=None,
        )
        sql, parameters, _limit, _next_token = self.backend.as_sql(config)
        self.assertEqual(
            'SELECT * FROM "customers" WHERE "city" = ? AND "age" > ?', sql
        )
        self.assertEqual([{"S": "New York"}, {"N": "30"}], parameters)

    def test_as_sql_with_sorts(self, mock_logger_arg):
        """Sorts become an ORDER BY clause whose columns carry no table prefix."""
        ordering = [
            {"column": "name", "direction": "ASC"},
            {"column": "created_at", "direction": "DESC"},
        ]
        config = self._get_base_config(
            table_name="items",
            select_all=True,
            sorts=ordering,
            _chosen_index_name=None,
        )
        sql, _params, _limit, _next_token = self.backend.as_sql(config)
        self.assertEqual(
            'SELECT * FROM "items" ORDER BY "name" ASC, "created_at" DESC', sql
        )

    def test_as_sql_with_index_name(self, mock_logger_arg):
        """A chosen index name is appended to the table in the FROM clause."""
        config = self._get_base_config(
            table_name="my_table", select_all=True, _chosen_index_name="my_gsi"
        )
        sql, _params, _limit, _next_token = self.backend.as_sql(config)
        self.assertEqual('SELECT * FROM "my_table"."my_gsi"', sql)

    def test_as_sql_ignores_group_by_and_joins(self, mock_logger_arg):
        """GROUP BY and JOIN settings never reach the SQL but each triggers a warning."""
        config = self._get_base_config(
            table_name="log_data",
            group_by_column="level",
            joins=["some_join_info"],
            _chosen_index_name=None,
        )

        sql, _, _, _ = self.backend.as_sql(config)
        upper_sql = sql.upper()
        self.assertNotIn("GROUP BY", upper_sql)
        self.assertNotIn("JOIN", upper_sql)

        group_by_warning = (
            "Configuration included 'group_by_column=level', "
            "but GROUP BY is not supported by this DynamoDB PartiQL backend and will be ignored for SQL generation."
        )
        joins_warning = (
            "Configuration included 'joins=['some_join_info']', "
            "but JOINs are not supported by this DynamoDB PartiQL backend and will be ignored for SQL generation."
        )
        mock_logger_arg.warning.assert_any_call(group_by_warning)
        mock_logger_arg.warning.assert_any_call(joins_warning)

    def test_check_query_configuration_sort_with_base_table_hash_key_equality(
        self, mock_logger_arg
    ):
        """Sorting is accepted when WHERE has an equality on the base table's hash key."""
        self.backend._get_table_description.return_value = {
            "KeySchema": [{"AttributeName": "id", "KeyType": "HASH"}],
            "GlobalSecondaryIndexes": [],
        }
        config = self._get_base_config(
            table_name="my_test_table",
            sorts=[{"column": "name", "direction": "ASC"}],
            wheres=[{"column": "id", "operator": "=", "values": ["some_id"]}],
        )
        checked = self.backend._check_query_configuration(config, self.mock_model)
        self.assertIsNone(checked.get("_chosen_index_name"))
        self.assertEqual(checked.get("_partition_key_for_target"), "id")

    def test_check_query_configuration_sort_raises_error_if_no_hash_key_equality(
        self, mock_logger_arg
    ):
        """Sorting without an equality on the hash key is rejected with ValueError."""
        self.backend._get_table_description.return_value = {
            "KeySchema": [{"AttributeName": "id", "KeyType": "HASH"}],
            "GlobalSecondaryIndexes": [],
        }
        config = self._get_base_config(
            table_name="my_test_table",
            sorts=[{"column": "name", "direction": "ASC"}],
            wheres=[{"column": "status", "operator": "=", "values": ["active"]}],
        )
        message = (
            "DynamoDB PartiQL queries with ORDER BY on 'my_test_table' "
            "require an equality condition on its partition key ('id') "
            "in the WHERE clause."
        )
        # The message contains regex metacharacters, so escape it before matching.
        with self.assertRaisesRegex(ValueError, re.escape(message)):
            self.backend._check_query_configuration(config, self.mock_model)

    def test_check_query_configuration_sort_uses_gsi_if_partition_key_matches(
        self, mock_logger_arg
    ):
        """A GSI is chosen when its partition key is in WHERE and its range key is the sort."""
        domain_status_index = {
            "IndexName": "domain-status-index",
            "KeySchema": [
                {"AttributeName": "domain", "KeyType": "HASH"},
                {"AttributeName": "status", "KeyType": "RANGE"},
            ],
            "Projection": {"ProjectionType": "ALL"},
        }
        self.backend._get_table_description.return_value = {
            "KeySchema": [{"AttributeName": "id", "KeyType": "HASH"}],
            "GlobalSecondaryIndexes": [domain_status_index],
        }
        config = self._get_base_config(
            table_name="my_test_table",
            sorts=[{"column": "status", "direction": "DESC"}],
            wheres=[{"column": "domain", "operator": "=", "values": ["example.com"]}],
        )
        checked = self.backend._check_query_configuration(config, self.mock_model)
        self.assertEqual(checked.get("_chosen_index_name"), "domain-status-index")
        self.assertEqual(checked.get("_partition_key_for_target"), "domain")

    def test_count_uses_native_query_with_pk_condition(self, mock_logger_arg):
        """count() issues a native query when WHERE has a partition-key equality."""
        self.backend._get_table_description.return_value = {
            "KeySchema": [{"AttributeName": "id", "KeyType": "HASH"}]
        }
        config = self._get_base_config(
            table_name="users",
            wheres=[{"column": "id", "operator": "=", "values": ["user123"]}],
        )
        client = self.mock_dynamodb_client
        client.query.return_value = {"Count": 10, "Items": []}

        total = self.backend.count(config, self.mock_model)

        self.assertEqual(total, 10)
        client.query.assert_called_once()
        client.scan.assert_not_called()
        _, query_kwargs = client.query.call_args
        self.assertEqual(query_kwargs.get("TableName"), "users")
        self.assertEqual(query_kwargs.get("Select"), "COUNT")
        self.assertIn("KeyConditionExpression", query_kwargs)

    def test_count_uses_native_scan_without_pk_condition(self, mock_logger_arg):
        """count() falls back to a native scan when no partition-key equality exists."""
        self.backend._get_table_description.return_value = {
            "KeySchema": [{"AttributeName": "id", "KeyType": "HASH"}]
        }
        config = self._get_base_config(
            table_name="users",
            wheres=[{"column": "status", "operator": "=", "values": ["active"]}],
        )
        client = self.mock_dynamodb_client
        client.scan.return_value = {"Count": 5, "Items": []}

        total = self.backend.count(config, self.mock_model)

        self.assertEqual(total, 5)
        client.scan.assert_called_once()
        client.query.assert_not_called()
        _, scan_kwargs = client.scan.call_args
        self.assertEqual(scan_kwargs.get("TableName"), "users")
        self.assertEqual(scan_kwargs.get("Select"), "COUNT")
        self.assertIn("FilterExpression", scan_kwargs)

    def test_count_paginates_native_results(self, mock_logger_arg):
        """count() keeps paging while LastEvaluatedKey is returned and sums every Count."""
        self.backend._get_table_description.return_value = {
            "KeySchema": [{"AttributeName": "id", "KeyType": "HASH"}]
        }
        config = self._get_base_config(table_name="large_table")

        # Three pages: the last one has no LastEvaluatedKey, which ends the loop.
        pages = [
            {"Count": 100, "LastEvaluatedKey": {"id": {"S": "page1_end"}}},
            {"Count": 50, "LastEvaluatedKey": {"id": {"S": "page2_end"}}},
            {"Count": 25},
        ]
        self.mock_dynamodb_client.scan.side_effect = pages

        total = self.backend.count(config, self.mock_model)

        self.assertEqual(total, 175)
        self.assertEqual(self.mock_dynamodb_client.scan.call_count, 3)

    def test_records_simple_fetch(self, mock_logger_arg):
        """records() returns one unpaginated page of results when no limit is set."""
        config = self._get_base_config(table_name="users", select_all=True)
        self.mock_dynamodb_client.execute_statement.return_value = {
            "Items": [
                {"id": {"S": "user1"}, "name": {"S": "Alice"}, "age": {"N": "30"}},
                {"id": {"S": "user2"}, "name": {"S": "Bob"}, "age": {"N": "24"}},
            ]
        }

        rows = list(self.backend.records(config, self.mock_model))

        self.mock_dynamodb_client.execute_statement.assert_called_once_with(
            Statement='SELECT * FROM "users"', Parameters=[]
        )
        self.assertEqual(len(rows), 2)
        # Expectations mirror what _map_from_boto3 currently produces.
        self.assertEqual(
            rows[0], {"id": "user1", "name": "Alice", "age": Decimal("30")}
        )
        self.assertEqual(rows[1], {"id": "user2", "name": "Bob", "age": Decimal("24")})
        self.assertIsNone(config["pagination"].get("next_page_token_for_response"))

    def test_records_with_limit(self, mock_logger_arg):
        """records() forwards the limit to DynamoDB and exposes the next token."""
        raw_next_token = "fakeDDBNextToken"
        config = self._get_base_config(table_name="products", limit=1, select_all=True)
        self.mock_dynamodb_client.execute_statement.return_value = {
            "Items": [{"id": {"S": "prod1"}, "price": {"N": "10.99"}}],
            "NextToken": raw_next_token,
        }

        next_page_data = {}
        rows = list(self.backend.records(config, self.mock_model, next_page_data))

        self.mock_dynamodb_client.execute_statement.assert_called_once_with(
            Statement='SELECT * FROM "products"', Parameters=[], Limit=1
        )
        self.assertEqual(len(rows), 1)
        self.assertEqual(rows[0], {"id": "prod1", "price": Decimal("10.99")})
        self.assertEqual(
            next_page_data["next_token"],
            self.backend.serialize_next_token_for_response(raw_next_token),
        )

    def test_records_pagination_flow(self, mock_logger_arg):
        """records() decodes the client's token for DynamoDB and re-encodes DynamoDB's reply."""
        statement = 'SELECT * FROM "events"'
        execute = self.mock_dynamodb_client.execute_statement

        # --- first call: client supplies a token from an earlier response ---
        previous_raw_token = "start_token_from_ddb_previously"
        raw_token_for_page2 = "ddb_token_for_page2"
        first_config = self._get_base_config(
            table_name="events",
            select_all=True,
            pagination={
                "next_token": self.backend.serialize_next_token_for_response(
                    previous_raw_token
                )
            },
        )
        execute.return_value = {
            "Items": [{"event_id": {"S": "evt1"}}],
            "NextToken": raw_token_for_page2,
        }

        next_page_data = {}
        first_page = list(
            self.backend.records(first_config, self.mock_model, next_page_data)
        )

        execute.assert_called_once_with(
            Statement=statement, Parameters=[], NextToken=previous_raw_token
        )
        self.assertEqual(len(first_page), 1)
        self.assertEqual(first_page[0], {"event_id": "evt1"})
        token_for_second_call = next_page_data["next_token"]
        self.assertIsNotNone(token_for_second_call)
        self.assertEqual(
            self.backend.restore_next_token_from_config(token_for_second_call),
            raw_token_for_page2,
        )

        # --- second call: feed the returned token back in; no further pages ---
        execute.reset_mock()
        second_config = self._get_base_config(
            table_name="events",
            select_all=True,
            pagination={"next_token": token_for_second_call},
        )
        execute.return_value = {"Items": [{"event_id": {"S": "evt2"}}]}

        next_page_data = {}
        second_page = list(
            self.backend.records(second_config, self.mock_model, next_page_data)
        )

        execute.assert_called_once_with(
            Statement=statement, Parameters=[], NextToken=raw_token_for_page2
        )
        self.assertEqual(len(second_page), 1)
        self.assertEqual(second_page[0], {"event_id": "evt2"})
        self.assertEqual(next_page_data, {})

    def test_records_no_items_returned_with_next_token(self, mock_logger_arg):
        """An empty page that still carries a NextToken reports that token to the caller."""
        raw_next_token = "ddb_has_more_but_current_page_empty_after_filter"
        config = self._get_base_config(table_name="filtered_items", select_all=True)
        self.mock_dynamodb_client.execute_statement.return_value = {
            "Items": [],
            "NextToken": raw_next_token,
        }

        next_page_data = {}
        rows = list(self.backend.records(config, self.mock_model, next_page_data))

        self.mock_dynamodb_client.execute_statement.assert_called_once_with(
            Statement='SELECT * FROM "filtered_items"', Parameters=[]
        )
        self.assertEqual(len(rows), 0)
        self.assertEqual(
            next_page_data["next_token"],
            self.backend.serialize_next_token_for_response(raw_next_token),
        )

    def test_records_limit_cuts_off_ddb_page(self, mock_logger_arg):
        """A limit that truncates the page still yields the token for the remainder."""
        raw_next_token = "ddb_still_has_more_after_limit"
        config = self._get_base_config(
            table_name="many_items", limit=1, select_all=True
        )
        self.mock_dynamodb_client.execute_statement.return_value = {
            "Items": [{"id": {"S": "item1"}}],
            "NextToken": raw_next_token,
        }

        next_page_data = {}
        rows = list(self.backend.records(config, self.mock_model, next_page_data))

        self.mock_dynamodb_client.execute_statement.assert_called_once_with(
            Statement='SELECT * FROM "many_items"', Parameters=[], Limit=1
        )
        self.assertEqual(len(rows), 1)
        self.assertEqual(rows[0], {"id": "item1"})
        self.assertEqual(
            next_page_data["next_token"],
            self.backend.serialize_next_token_for_response(raw_next_token),
        )

    def test_create_record(self, mock_logger_arg):
        """create() issues a single INSERT and hands back the data it was given."""
        new_record = {"id": "new_user_123", "name": "Jane Doe", "age": 28}
        self.mock_dynamodb_client.execute_statement.return_value = {}

        returned = self.backend.create(new_record, self.mock_model)

        self.assertEqual(returned, new_record)
        self.mock_dynamodb_client.execute_statement.assert_called_once_with(
            Statement='INSERT INTO "my_test_table" VALUE ?',
            Parameters=[
                {
                    "id": {"S": "new_user_123"},
                    "name": {"S": "Jane Doe"},
                    "age": {"N": "28"},
                }
            ],
        )

    def test_update_record(self, mock_logger_arg):
        """update() issues an UPDATE ... RETURNING and returns the mapped new item."""
        record_id = "user_to_update"
        changes = {"age": 35, "status": "active"}
        self.mock_dynamodb_client.execute_statement.return_value = {
            "Items": [
                {
                    "id": {"S": record_id},
                    "name": {"S": "Original Name"},
                    "age": {"N": "35"},
                    "status": {"S": "active"},
                }
            ]
        }

        returned = self.backend.update(record_id, changes, self.mock_model)

        # SET values come first, followed by the id used in the WHERE clause.
        self.mock_dynamodb_client.execute_statement.assert_called_once_with(
            Statement='UPDATE "my_test_table" SET "age" = ?, "status" = ? WHERE "id" = ? RETURNING ALL NEW *',
            Parameters=[{"N": "35"}, {"S": "active"}, {"S": "user_to_update"}],
        )
        self.assertEqual(
            returned,
            {
                "id": "user_to_update",
                "name": "Original Name",
                "age": Decimal("35"),
                "status": "active",
            },
        )

    def test_delete_record(self, mock_logger_arg):
        """delete() issues a DELETE keyed on the id and reports success."""
        record_id = "user_to_delete"
        self.mock_dynamodb_client.execute_statement.return_value = {}

        outcome = self.backend.delete(record_id, self.mock_model)

        self.assertTrue(outcome)
        self.mock_dynamodb_client.execute_statement.assert_called_once_with(
            Statement='DELETE FROM "my_test_table" WHERE "id" = ?',
            Parameters=[{"S": "user_to_delete"}],
        )
537
+
538
+
539
if __name__ == "__main__":
    # Run with the default unittest entry point so that command-line arguments
    # (e.g. -v or a specific test name) are honoured and a failing run exits
    # with a non-zero status. The previous call overrode argv and passed
    # exit=False, so the process always exited 0 even when tests failed.
    unittest.main()
@@ -1,8 +1,11 @@
1
1
  import json
2
2
  import unittest
3
- from unittest.mock import MagicMock, call
4
3
  from types import SimpleNamespace
4
+ from unittest.mock import MagicMock, call
5
+
5
6
  from .lambda_sqs_standard_partial_batch import lambda_sqs_standard_partial_batch
7
+
8
+
6
9
  class LambdaSqsStandardPartialBatchTest(unittest.TestCase):
7
10
  def setUp(self):
8
11
  self.calls = []
@@ -29,14 +32,10 @@ class LambdaSqsStandardPartialBatchTest(unittest.TestCase):
29
32
  },
30
33
  {},
31
34
  )
32
- self.assertEquals(
35
+ self.assertEqual(
33
36
  [
34
- {
35
- 'hey': 'sup'
36
- },
37
- {
38
- 'cool': 'yo'
39
- },
37
+ {"hey": "sup"},
38
+ {"cool": "yo"},
40
39
  ],
41
40
  self.calls,
42
41
  )
@@ -55,17 +54,13 @@ class LambdaSqsStandardPartialBatchTest(unittest.TestCase):
55
54
  },
56
55
  ]
57
56
  }, {})
58
- self.assertEquals(
57
+ self.assertEqual(
59
58
  [
60
- {
61
- 'hey': 'sup'
62
- },
59
+ {"hey": "sup"},
63
60
  ],
64
61
  self.calls,
65
62
  )
66
- self.assertEquals(
67
- {'batchItemFailures': [{
68
- 'itemIdentifier': '2-3-4-5'
69
- }]},
63
+ self.assertEqual(
64
+ {"batchItemFailures": [{"itemIdentifier": "2-3-4-5"}]},
70
65
  results,
71
66
  )
@@ -1,18 +1,60 @@
1
+ from types import ModuleType
2
+
3
+ import boto3 as boto3_module
4
+ from boto3.session import Session as Boto3Session
5
+ from clearskies import Environment
1
6
  from clearskies.di import StandardDependencies as DefaultStandardDependencies
2
- from ..backends import DynamoDBBackend, SqsBackend
7
+
8
+ from ..backends import (
9
+ DynamoDBBackend,
10
+ DynamoDBPartiQLBackend,
11
+ DynamoDBPartiQLCursor,
12
+ SqsBackend,
13
+ )
3
14
  from ..secrets import ParameterStore
15
+
16
+
4
17
class StandardDependencies(DefaultStandardDependencies):
    """Extends the clearskies StandardDependencies with providers for AWS services.

    Each ``provide_*`` method builds one object from the arguments it is
    given: the DynamoDB, DynamoDB PartiQL and SQS backends, the boto3 module
    and session, and the Parameter Store secrets object.
    """

    def provide_dynamo_db_parti_ql_cursor(
        self, boto3_session: Boto3Session
    ) -> DynamoDBPartiQLCursor:
        """Return a DynamoDBPartiQLCursor built on the given boto3 session."""
        return DynamoDBPartiQLCursor(boto3_session)

    def provide_dynamo_db_backend(
        self, boto3: ModuleType, environment: Environment
    ) -> DynamoDBBackend:
        """Return a DynamoDBBackend built from the boto3 module and the environment."""
        return DynamoDBBackend(boto3, environment)

    def provide_dynamo_db_parti_ql_backend(
        self, dynamo_db_parti_ql_cursor: DynamoDBPartiQLCursor, environment: Environment
    ) -> DynamoDBPartiQLBackend:
        """Return a DynamoDBPartiQLBackend wrapping the given cursor.

        ``environment`` is accepted but not used by this provider; it stays in
        the signature so existing callers keep working.
        """
        return DynamoDBPartiQLBackend(dynamo_db_parti_ql_cursor)

    def provide_sqs_backend(
        self, boto3: ModuleType, environment: Environment
    ) -> SqsBackend:
        """Return an SqsBackend built from the boto3 module and the environment."""
        return SqsBackend(boto3, environment)

    def provide_boto3(self) -> ModuleType:
        """Return the boto3 module itself."""
        import boto3

        return boto3

    def provide_secrets(
        self, boto3: ModuleType, environment: Environment
    ) -> ParameterStore:
        """Return a ParameterStore to act as the default secrets provider."""
        # This is just here so that we can auto-inject the secrets into the environment without having
        # to force the developer to define a secrets manager
        return ParameterStore(boto3, environment)

    def provide_boto3_session(
        self, boto3: ModuleType, environment: Environment
    ) -> Boto3Session:
        """Return a boto3 session for the region named by ``AWS_REGION``.

        Raises:
            ValueError: if the environment yields a falsy value for ``AWS_REGION``.
        """
        # Read the region once and reuse it for both the check and the session.
        # NOTE(review): the second positional argument is presumably the
        # "silent" flag, so a missing key returns a falsy value instead of
        # raising - confirm against clearskies.Environment.get.
        region = environment.get("AWS_REGION", True)
        if not region:
            raise ValueError(
                "To use AWS Session you must set AWS_REGION in the .env file or an environment variable"
            )

        return boto3.session.Session(region_name=region)