amplify-excel-migrator 1.1.5__py3-none-any.whl → 1.2.15__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the package contents exactly as they appear in the public registry.
Files changed (45)
  1. amplify_excel_migrator/__init__.py +17 -0
  2. amplify_excel_migrator/auth/__init__.py +6 -0
  3. amplify_excel_migrator/auth/cognito_auth.py +306 -0
  4. amplify_excel_migrator/auth/provider.py +42 -0
  5. amplify_excel_migrator/cli/__init__.py +5 -0
  6. amplify_excel_migrator/cli/commands.py +165 -0
  7. amplify_excel_migrator/client.py +47 -0
  8. amplify_excel_migrator/core/__init__.py +5 -0
  9. amplify_excel_migrator/core/config.py +98 -0
  10. amplify_excel_migrator/data/__init__.py +7 -0
  11. amplify_excel_migrator/data/excel_reader.py +23 -0
  12. amplify_excel_migrator/data/transformer.py +119 -0
  13. amplify_excel_migrator/data/validator.py +48 -0
  14. amplify_excel_migrator/graphql/__init__.py +8 -0
  15. amplify_excel_migrator/graphql/client.py +137 -0
  16. amplify_excel_migrator/graphql/executor.py +405 -0
  17. amplify_excel_migrator/graphql/mutation_builder.py +80 -0
  18. amplify_excel_migrator/graphql/query_builder.py +194 -0
  19. amplify_excel_migrator/migration/__init__.py +8 -0
  20. amplify_excel_migrator/migration/batch_uploader.py +23 -0
  21. amplify_excel_migrator/migration/failure_tracker.py +92 -0
  22. amplify_excel_migrator/migration/orchestrator.py +143 -0
  23. amplify_excel_migrator/migration/progress_reporter.py +57 -0
  24. amplify_excel_migrator/schema/__init__.py +6 -0
  25. model_field_parser.py → amplify_excel_migrator/schema/field_parser.py +100 -22
  26. amplify_excel_migrator/schema/introspector.py +95 -0
  27. {amplify_excel_migrator-1.1.5.dist-info → amplify_excel_migrator-1.2.15.dist-info}/METADATA +121 -26
  28. amplify_excel_migrator-1.2.15.dist-info/RECORD +40 -0
  29. amplify_excel_migrator-1.2.15.dist-info/entry_points.txt +2 -0
  30. amplify_excel_migrator-1.2.15.dist-info/top_level.txt +2 -0
  31. tests/__init__.py +1 -0
  32. tests/test_cli_commands.py +292 -0
  33. tests/test_client.py +187 -0
  34. tests/test_cognito_auth.py +363 -0
  35. tests/test_config_manager.py +347 -0
  36. tests/test_field_parser.py +615 -0
  37. tests/test_mutation_builder.py +391 -0
  38. tests/test_query_builder.py +384 -0
  39. amplify_client.py +0 -941
  40. amplify_excel_migrator-1.1.5.dist-info/RECORD +0 -9
  41. amplify_excel_migrator-1.1.5.dist-info/entry_points.txt +0 -2
  42. amplify_excel_migrator-1.1.5.dist-info/top_level.txt +0 -3
  43. migrator.py +0 -437
  44. {amplify_excel_migrator-1.1.5.dist-info → amplify_excel_migrator-1.2.15.dist-info}/WHEEL +0 -0
  45. {amplify_excel_migrator-1.1.5.dist-info → amplify_excel_migrator-1.2.15.dist-info}/licenses/LICENSE +0 -0
amplify_client.py DELETED
@@ -1,941 +0,0 @@
- import asyncio
- import logging
- import sys
- from getpass import getpass
- from typing import Dict, Any
-
- import aiohttp
- import boto3
- import requests
- import jwt
- import inflect
- from botocore.exceptions import NoCredentialsError, ProfileNotFound, NoRegionError, ClientError
- from pycognito import Cognito, MFAChallengeException
- from pycognito.exceptions import ForceChangePasswordException
-
- logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
- logger = logging.getLogger(__name__)
-
-
- class AuthenticationError(Exception):
-     """Raised when authentication is required but not completed"""
-
-     pass
-
-
- class GraphQLError(Exception):
-     """Raised when GraphQL query returns errors"""
-
-     pass
-
-
- class AmplifyClient:
-     """
-     Client for Amplify GraphQL using ADMIN_USER_PASSWORD_AUTH flow
-     """
-
-     def __init__(self, api_endpoint: str, user_pool_id: str, region: str, client_id: str):
-         """
-         Initialize the client
-
-         Args:
-             api_endpoint: Amplify GraphQL endpoint
-             user_pool_id: Cognito User Pool ID
-             region: AWS region
-             client_id: Cognito App Client ID
-         """
-
-         self.api_endpoint = api_endpoint
-         self.user_pool_id = user_pool_id
-         self.region = region
-         self.client_id = client_id
-
-         self.batch_size = 20
-         self.cognito_client = None
-         self.boto_cognito_admin_client = None
-         self.id_token = None
-         self.mfa_tokens = None
-         self.admin_group_name = "ADMINS"
-
-         self.records_cache = {}
-
-     def init_cognito_client(self, is_aws_admin: bool, username: str = None, aws_profile: str = None):
-         try:
-             if is_aws_admin:
-                 if aws_profile:
-                     session = boto3.Session(profile_name=aws_profile)
-                     self.boto_cognito_admin_client = session.client("cognito-idp", region_name=self.region)
-                 else:
-                     # Use default AWS credentials (from ~/.aws/credentials, env vars, or IAM role)
-                     self.boto_cognito_admin_client = boto3.client("cognito-idp", region_name=self.region)
-
-             else:
-                 self.cognito_client = Cognito(
-                     user_pool_id=self.user_pool_id,
-                     client_id=self.client_id,
-                     user_pool_region=self.region,
-                     username=username,
-                 )
-
-         except NoCredentialsError:
-             logger.error("AWS credentials not found. Please configure AWS credentials.")
-             logger.error("Options: 1) AWS CLI: 'aws configure', 2) Environment variables, 3) Pass credentials directly")
-             raise RuntimeError(
-                 "Failed to initialize client: No AWS credentials found. "
-                 "Run 'aws configure' or set AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY environment variables."
-             )
-
-         except ProfileNotFound:
-             logger.error(f"AWS profile '{aws_profile}' not found")
-             raise RuntimeError(
-                 f"Failed to initialize client: AWS profile '{aws_profile}' not found. "
-                 f"Available profiles can be found in ~/.aws/credentials"
-             )
-
-         except NoRegionError:
-             logger.error("No AWS region specified")
-             raise RuntimeError(
-                 f"Failed to initialize client: No AWS region specified. "
-                 f"Provide region parameter or set AWS_DEFAULT_REGION environment variable."
-             )
-
-         except ValueError as e:
-             logger.error(f"Invalid parameter: {e}")
-             raise
-
-         except ClientError as e:
-             error_code = e.response.get("Error", {}).get("Code", "Unknown")
-             error_msg = e.response.get("Error", {}).get("Message", str(e))
-             logger.error(f"AWS Client Error [{error_code}]: {error_msg}")
-             raise RuntimeError(f"Failed to initialize client: AWS error [{error_code}]: {error_msg}")
-
-         except Exception as e:
-             logger.error(f"Error during client initialization: {e}")
-             raise RuntimeError(f"Failed to initialize client: {e}")
-
-     def authenticate(self, username: str, password: str, mfa_code: str = None) -> bool:
-         try:
-             if not self.cognito_client:
-                 self.init_cognito_client(is_aws_admin=False, username=username)
-
-             if mfa_code and self.mfa_tokens:
-                 if not self._complete_mfa_challenge(mfa_code):
-                     return False
-             else:
-                 self.cognito_client.authenticate(password=password)
-
-             self.id_token = self.cognito_client.id_token
-
-             self._check_user_in_admins_group(self.id_token)
-
-             logger.info("✅ Authentication successful")
-             return True
-
-         except MFAChallengeException as e:
-             logger.warning("MFA required")
-             if hasattr(e, "get_tokens"):
-                 self.mfa_tokens = e.get_tokens()
-
-                 mfa_code = input("Enter MFA code: ").strip()
-                 if mfa_code:
-                     return self.authenticate(username, password, mfa_code)
-                 else:
-                     logger.error("MFA code required but not provided")
-                     return False
-             else:
-                 logger.error("MFA challenge received but no session tokens available")
-                 return False
-
-         except ForceChangePasswordException:
-             logger.warning("Password change required")
-             new_password = input("Your password has expired. Enter new password: ").strip()
-             confirm_password = input("Confirm new password: ").strip()
-             if new_password != confirm_password:
-                 logger.error("Passwords do not match")
-                 return False
-
-             try:
-                 self.cognito_client.new_password_challenge(password, new_password)
-                 return self.authenticate(username, new_password)
-
-             except Exception as e:
-                 logger.error(f"Failed to change password: {e}")
-                 return False
-
-         except Exception as e:
-             logger.error(f"Authentication failed: {e}")
-             return False
-
-     def aws_admin_authenticate(self, username: str, password: str) -> bool:
-         """
-         Requires AWS credentials with cognito-idp:ListUserPoolClients permission
-         """
-         try:
-             if not self.boto_cognito_admin_client:
-                 self.init_cognito_client(is_aws_admin=True)
-
-             print(f"Authenticating {username} using ADMIN_USER_PASSWORD_AUTH flow...")
-
-             response = self.boto_cognito_admin_client.admin_initiate_auth(
-                 UserPoolId=self.user_pool_id,
-                 ClientId=self.client_id,
-                 AuthFlow="ADMIN_USER_PASSWORD_AUTH",
-                 AuthParameters={"USERNAME": username, "PASSWORD": password},
-             )
-
-             self._check_for_mfa_challenges(response, username)
-
-             if "AuthenticationResult" in response:
-                 self.id_token = response["AuthenticationResult"]["IdToken"]
-             else:
-                 logger.error("❌ Authentication failed: No AuthenticationResult in response")
-                 return False
-
-             self._check_user_in_admins_group(self.id_token)
-
-             print(f"✅ Authentication successful")
-             return True
-
-         except self.cognito_client.exceptions.NotAuthorizedException as e:
-             logger.error(f"❌ Authentication failed: {e}")
-             return False
-
-         except self.cognito_client.exceptions.UserNotFoundException:
-             logger.error(f"❌ User not found: {username}")
-             return False
-
-         except Exception as e:
-             logger.error(f"❌ Error during authentication: {e}")
-             return False
-
-     def _complete_mfa_challenge(self, mfa_code: str) -> bool:
-         try:
-             if not self.mfa_tokens:
-                 logger.error("No MFA session tokens available")
-                 return False
-
-             challenge_name = self.mfa_tokens.get("ChallengeName", "SMS_MFA")
-
-             if "SOFTWARE_TOKEN" in challenge_name:
-                 self.cognito_client.respond_to_software_token_mfa_challenge(code=mfa_code, mfa_tokens=self.mfa_tokens)
-             else:
-                 self.cognito_client.respond_to_sms_mfa_challenge(code=mfa_code, mfa_tokens=self.mfa_tokens)
-
-             logger.info("✅ MFA challenge successful")
-             return True
-
-         except Exception as e:
-             logger.error(f"MFA challenge failed: {e}")
-             return False
-
-     def _get_client_id(self) -> str:
-         if self.client_id:
-             return self.client_id
-
-         try:
-             if not self.boto_cognito_admin_client:
-                 self.boto_cognito_admin_client(is_aws_admin=True)
-             response = self.boto_cognito_admin_client.list_user_pool_clients(UserPoolId=self.user_pool_id, MaxResults=1)
-
-             if response["UserPoolClients"]:
-                 client_id = response["UserPoolClients"][0]["ClientId"]
-                 return client_id
-
-             raise Exception("No User Pool clients found")
-
-         except self.cognito_client.exceptions.ResourceNotFoundException:
-             raise Exception(f"User Pool not found or AWS credentials lack permission")
-         except Exception as e:
-             raise Exception(f"Failed to get Client ID: {e}")
-
-     def _check_for_mfa_challenges(self, response, username: str) -> bool:
-         if "ChallengeName" in response:
-             challenge = response["ChallengeName"]
-
-             if challenge == "MFA_SETUP":
-                 logger.error("MFA setup required")
-                 return False
-
-             elif challenge == "SMS_MFA" or challenge == "SOFTWARE_TOKEN_MFA":
-                 mfa_code = input("Enter MFA code: ")
-                 _ = self.cognito_client.admin_respond_to_auth_challenge(
-                     UserPoolId=self.user_pool_id,
-                     ClientId=self.client_id,
-                     ChallengeName=challenge,
-                     Session=response["Session"],
-                     ChallengeResponses={
-                         "USERNAME": username,
-                         "SMS_MFA_CODE" if challenge == "SMS_MFA" else "SOFTWARE_TOKEN_MFA_CODE": mfa_code,
-                     },
-                 )
-
-             elif challenge == "NEW_PASSWORD_REQUIRED":
-                 new_password = getpass("Enter new password: ")
-                 _ = self.cognito_client.admin_respond_to_auth_challenge(
-                     UserPoolId=self.user_pool_id,
-                     ClientId=self.client_id,
-                     ChallengeName=challenge,
-                     Session=response["Session"],
-                     ChallengeResponses={"USERNAME": username, "NEW_PASSWORD": new_password},
-                 )
-
-         return False
-
-     def _check_user_in_admins_group(self, id_token: str):
-         print(jwt.__version__)
-
-         claims = jwt.decode(id_token, options={"verify_signature": False})
-         groups = claims.get("cognito:groups", [])
-
-         if self.admin_group_name not in groups:
-             raise PermissionError("User is not in ADMINS group")
-
-     def _request(self, query: str, variables: Dict = None, context: str = None) -> Any | None:
-         """
-         Make a GraphQL request using the ID token
-
-         Args:
-             query: GraphQL query or mutation
-             variables: Variables for the query
-             context: Optional context string to include in error messages (e.g., row identifier)
-
-         Returns:
-             Response data
-         """
-         if not self.id_token:
-             raise AuthenticationError("Not authenticated. Call authenticate() first.")
-
-         headers = {"Authorization": self.id_token, "Content-Type": "application/json"}
-
-         payload = {"query": query, "variables": variables or {}}
-
-         context_msg = f" [{context}]" if context else ""
-
-         try:
-             response = requests.post(self.api_endpoint, headers=headers, json=payload)
-
-             if response.status_code == 200:
-                 result = response.json()
-
-                 if "errors" in result:
-                     raise GraphQLError(f"GraphQL errors{context_msg}: {result['errors']}")
-
-                 return result
-             else:
-                 logger.error(f"HTTP Error {response.status_code}{context_msg}: {response.text}")
-                 return None
-
-         except requests.exceptions.ConnectionError as e:
-             logger.error(
-                 f"Connection error{context_msg}: Unable to connect to API endpoint. Check your internet connection or the API endpoint URL."
-             )
-             sys.exit(1)
-
-         except requests.exceptions.Timeout as e:
-             logger.error(f"Request timeout{context_msg}: {e}")
-             return None
-
-         except requests.exceptions.HTTPError as e:
-             logger.error(f"HTTP error{context_msg}: {e}")
-             return None
-
-         except GraphQLError as e:
-             logger.error(str(e))
-             return None
-
-         except requests.exceptions.RequestException as e:
-             logger.error(f"Request error{context_msg}: {e}")
-             return None
-
-     async def _request_async(
-         self, session: aiohttp.ClientSession, query: str, variables: Dict = None, context: str = None
-     ) -> Any | None:
-         """
-         Async version of _request for parallel GraphQL requests
-
-         Args:
-             session: aiohttp ClientSession
-             query: GraphQL query or mutation
-             variables: Variables for the query
-             context: Optional context string to include in error messages (e.g., row identifier)
-
-         Returns:
-             Response data
-         """
-         if not self.id_token:
-             raise AuthenticationError("Not authenticated. Call authenticate() first.")
-
-         headers = {"Authorization": self.id_token, "Content-Type": "application/json"}
-
-         payload = {"query": query, "variables": variables or {}}
-
-         context_msg = f" [{context}]" if context else ""
-
-         try:
-             async with session.post(self.api_endpoint, headers=headers, json=payload) as response:
-                 if response.status == 200:
-                     result = await response.json()
-
-                     if "errors" in result:
-                         raise GraphQLError(f"GraphQL errors{context_msg}: {result['errors']}")
-
-                     return result
-                 else:
-                     text = await response.text()
-                     error_msg = f"HTTP Error {response.status}{context_msg}: {text}"
-                     logger.error(error_msg)
-                     raise aiohttp.ClientError(error_msg)
-
-         except aiohttp.ServerTimeoutError as e:
-             error_msg = f"Request timeout{context_msg}: {e}"
-             logger.error(error_msg)
-             raise aiohttp.ServerTimeoutError(error_msg)
-
-         except aiohttp.ClientConnectionError as e:
-             error_msg = f"Connection error{context_msg}: Unable to connect to API endpoint. {e}"
-             logger.error(error_msg)
-             raise aiohttp.ClientConnectionError(error_msg)
-
-         except aiohttp.ClientResponseError as e:
-             error_msg = f"HTTP response error{context_msg}: {e}"
-             logger.error(error_msg)
-             raise aiohttp.ClientResponseError(
-                 request_info=e.request_info, history=e.history, status=e.status, message=error_msg
-             )
-
-         except GraphQLError as e:
-             logger.error(str(e))
-             raise
-
-         except aiohttp.ClientError as e:
-             error_msg = f"Client error{context_msg}: {e}"
-             logger.error(error_msg)
-             raise aiohttp.ClientError(error_msg)
-
-     async def create_record_async(
-         self, session: aiohttp.ClientSession, data: Dict, model_name: str, primary_field: str
-     ) -> Dict | None:
-         mutation = f"""
-             mutation Create{model_name}($input: Create{model_name}Input!) {{
-                 create{model_name}(input: $input) {{
-                     id
-                     {primary_field}
-                 }}
-             }}
-         """
-
-         context = f"{model_name}: {primary_field}={data.get(primary_field)}"
-         result = await self._request_async(session, mutation, {"input": data}, context)
-
-         if result and "data" in result:
-             created = result["data"].get(f"create{model_name}")
-             if created:
-                 logger.info(f'Created {model_name} with {primary_field}="{data[primary_field]}" (ID: {created["id"]})')
-                 return created
-             else:
-                 logger.error(f'Failed to create {model_name} with {primary_field}="{data[primary_field]}"')
-
-         return None
-
-     async def check_record_exists_async(
-         self,
-         session: aiohttp.ClientSession,
-         model_name: str,
-         primary_field: str,
-         value: str,
-         is_secondary_index: bool,
-         record: Dict,
-         field_type: str = "String",
-     ) -> Dict | None:
-         context = f"{model_name}: {primary_field}={value}"
-
-         if is_secondary_index:
-             query_name = f"list{model_name}By{primary_field[0].upper() + primary_field[1:]}"
-             query = f"""
-                 query {query_name}(${primary_field}: {field_type}!) {{
-                     {query_name}({primary_field}: ${primary_field}) {{
-                         items {{
-                             id
-                         }}
-                     }}
-                 }}
-             """
-             result = await self._request_async(session, query, {primary_field: value}, context)
-             if result and "data" in result:
-                 items = result["data"].get(query_name, {}).get("items", [])
-                 if len(items) > 0:
-                     logger.warning(f'Record with {primary_field}="{value}" already exists in {model_name}')
-                     return None
-         else:
-             query_name = self._get_list_query_name(model_name)
-             query = f"""
-                 query List{model_name}s($filter: Model{model_name}FilterInput) {{
-                     {query_name}(filter: $filter) {{
-                         items {{
-                             id
-                         }}
-                     }}
-                 }}
-             """
-             filter_input = {primary_field: {"eq": value}}
-             result = await self._request_async(session, query, {"filter": filter_input}, context)
-             if result and "data" in result:
-                 items = result["data"].get(query_name, {}).get("items", [])
-                 if len(items) > 0:
-                     logger.error(f'Record with {primary_field}="{value}" already exists in {model_name}')
-                     return None
-
-         return record
-
-     async def upload_batch_async(
-         self, batch: list, model_name: str, primary_field: str, is_secondary_index: bool, field_type: str = "String"
-     ) -> tuple[int, int, list[Dict]]:
-         async with aiohttp.ClientSession() as session:
-             duplicate_checks = [
-                 self.check_record_exists_async(
-                     session, model_name, primary_field, record[primary_field], is_secondary_index, record, field_type
-                 )
-                 for record in batch
-             ]
-             check_results = await asyncio.gather(*duplicate_checks, return_exceptions=True)
-
-             filtered_batch = []
-             failed_records = []
-
-             for i, result in enumerate(check_results):
-                 if isinstance(result, Exception):
-                     error_msg = str(result)
-                     failed_records.append(
-                         {
-                             "primary_field": primary_field,
-                             "primary_field_value": batch[i].get(primary_field, "Unknown"),
-                             "error": f"Duplicate check error: {error_msg}",
-                         }
-                     )
-                     logger.error(f"Error checking duplicate: {result}")
-                 elif result is not None:
-                     filtered_batch.append(result)
-
-             if not filtered_batch:
-                 return 0, len(batch), failed_records
-
-             create_tasks = [
-                 self.create_record_async(session, record, model_name, primary_field) for record in filtered_batch
-             ]
-             results = await asyncio.gather(*create_tasks, return_exceptions=True)
-
-             for i, result in enumerate(results):
-                 if isinstance(result, Exception):
-                     error_msg = str(result)
-                     failed_records.append(
-                         {
-                             "primary_field": primary_field,
-                             "primary_field_value": filtered_batch[i].get(primary_field, "Unknown"),
-                             "error": error_msg,
-                         }
-                     )
-                 elif not result:
-                     failed_records.append(
-                         {
-                             "primary_field": primary_field,
-                             "primary_field_value": filtered_batch[i].get(primary_field, "Unknown"),
-                             "error": "Creation failed - no response",
-                         }
-                     )
-
-             success_count = sum(1 for r in results if r and not isinstance(r, Exception))
-             error_count = len(batch) - success_count
-
-             return success_count, error_count, failed_records
-
-     def get_model_structure(self, model_type: str) -> Dict:
-         query = f"""
-             query GetModelType {{
-                 __type(name: "{model_type}") {{
-                     name
-                     kind
-                     description
-                     fields {{
-                         name
-                         type {{
-                             name
-                             kind
-                             ofType {{
-                                 name
-                                 kind
-                                 ofType {{
-                                     name
-                                     kind
-                                 }}
-                             }}
-                         }}
-                         description
-                     }}
-                 }}
-             }}
-         """
-
-         response = self._request(query)
-         if response and "data" in response and "__type" in response["data"]:
-             return response["data"]["__type"]
-
-         return {}
-
-     def get_primary_field_name(self, model_name: str, parsed_model_structure: Dict[str, Any]) -> tuple[str, bool, str]:
-         """
-         Returns: (field_name, is_secondary_index, field_type)
-         """
-         secondary_index = self._get_secondary_index(model_name)
-         if secondary_index:
-             # Find the field type in parsed_model_structure
-             field_type = "String"
-             for field in parsed_model_structure["fields"]:
-                 if field["name"] == secondary_index:
-                     field_type = field["type"]
-                     break
-             return secondary_index, True, field_type
-
-         for field in parsed_model_structure["fields"]:
-             if field["is_required"] and field["is_scalar"] and field["name"] != "id":
-                 return field["name"], False, field["type"]
-
-         logger.error("No suitable primary field found (required scalar field other than id)")
-         return "", False, "String"
-
-     def _get_secondary_index(self, model_name: str) -> str:
-         query_structure = self.get_model_structure("Query")
-         if not query_structure:
-             logger.error("Query type not found in schema")
-             return ""
-
-         query_fields = query_structure["fields"]
-
-         pattern = f"{model_name}By"
-
-         for query in query_fields:
-             query_name = query["name"]
-             if pattern in query_name:
-                 pattern_index = query_name.index(pattern)
-                 field_name = query_name[pattern_index + len(pattern) :]
-                 return field_name[0].lower() + field_name[1:] if field_name else ""
-
-         return ""
-
-     def _get_list_query_name(self, model_name: str) -> str | None:
-         query_structure = self.get_model_structure("Query")
-         if not query_structure:
-             logger.error("Query type not found in schema")
-             return f"list{model_name}s"
-
-         query_fields = query_structure["fields"]
-         p = inflect.engine()
-
-         candidates = [f"list{model_name}"]
-         capitals = [i for i, c in enumerate(model_name) if c.isupper()]
-
-         if len(capitals) > 1:
-             last_word_start = capitals[-1]
-             prefix = model_name[:last_word_start]
-             last_word = model_name[last_word_start:]
-
-             last_word_plural = str(p.plural(last_word.lower())) # type: ignore[arg-type]
-             last_word_plural_cap = last_word_plural[0].upper() + last_word_plural[1:] if last_word_plural else ""
-
-             pascal_plural = f"{prefix}{last_word_plural_cap}"
-             candidates.append(f"list{pascal_plural}")
-
-         full_plural = str(p.plural(model_name.lower())) # type: ignore[arg-type]
-         full_plural_cap = full_plural[0].upper() + full_plural[1:] if full_plural else ""
-         candidates.append(f"list{full_plural_cap}")
-
-         for query in query_fields:
-             query_name = query["name"]
-             if query_name in candidates and "By" not in query_name:
-                 return query_name
-
-         logger.error(f"No list query found for model {model_name}, tried: {candidates}")
-         return None
-
-     def upload(
-         self, records: list, model_name: str, parsed_model_structure: Dict[str, Any]
-     ) -> tuple[int, int, list[Dict]]:
-         logger.info("Uploading to Amplify backend...")
-
-         success_count = 0
-         error_count = 0
-         all_failed_records = []
-         num_of_batches = (len(records) + self.batch_size - 1) // self.batch_size
-
-         primary_field, is_secondary_index, field_type = self.get_primary_field_name(model_name, parsed_model_structure)
-         if not primary_field:
-             logger.error(f"Aborting upload for model {model_name}")
-             return 0, len(records), []
-
-         for i in range(0, len(records), self.batch_size):
-             batch = records[i : i + self.batch_size]
-             logger.info(f"Uploading batch {i // self.batch_size + 1} / {num_of_batches} ({len(batch)} items)...")
-
-             batch_success, batch_error, batch_failed_records = asyncio.run(
-                 self.upload_batch_async(batch, model_name, primary_field, is_secondary_index, field_type)
-             )
-             success_count += batch_success
-             error_count += batch_error
-             all_failed_records.extend(batch_failed_records)
-
-             logger.info(
-                 f"Processed batch {i // self.batch_size + 1} of model {model_name}: {success_count} success, {error_count} errors"
-             )
-
-         return success_count, error_count, all_failed_records
-
-     def list_records_by_secondary_index(
-         self, model_name: str, secondary_index: str, value: str = None, fields: list = None, field_type: str = "String"
-     ) -> Dict | None:
-         if fields is None:
-             fields = ["id", secondary_index]
-
-         fields_str = "\n".join(fields)
-         all_items = []
-         next_token = None
-
-         if not value:
-             query_name = self._get_list_query_name(model_name)
-
-             while True:
-                 query = f"""
-                     query List{model_name}s($limit: Int, $nextToken: String) {{
-                         {query_name}(limit: $limit, nextToken: $nextToken) {{
-                             items {{
-                                 {fields_str}
-                             }}
-                             nextToken
-                         }}
-                     }}
-                 """
-                 variables = {"limit": 1000, "nextToken": next_token}
-                 result = self._request(query, variables)
-
-                 if result and "data" in result:
-                     data = result["data"].get(query_name, {})
-                     items = data.get("items", [])
-                     all_items.extend(items)
-                     next_token = data.get("nextToken")
-
-                     if not next_token:
-                         break
-                 else:
-                     break
-         else:
-             query_name = f"list{model_name}By{secondary_index[0].upper() + secondary_index[1:]}"
-
-             while True:
-                 query = f"""
-                     query {query_name}(${secondary_index}: {field_type}!, $limit: Int, $nextToken: String) {{
-                         {query_name}({secondary_index}: ${secondary_index}, limit: $limit, nextToken: $nextToken) {{
-                             items {{
-                                 {fields_str}
-                             }}
-                             nextToken
-                         }}
-                     }}
-                 """
-                 variables = {secondary_index: value, "limit": 1000, "nextToken": next_token}
-                 result = self._request(query, variables)
-
-                 if result and "data" in result:
-                     data = result["data"].get(query_name, {})
-                     items = data.get("items", [])
-                     all_items.extend(items)
-                     next_token = data.get("nextToken")
-
-                     if not next_token:
-                         break
-                 else:
-                     break
-
-         return all_items if all_items else None
-
-     def list_records_by_field(
-         self, model_name: str, field_name: str, value: str = None, fields: list = None
-     ) -> Dict | None:
-         if fields is None:
-             fields = ["id", field_name]
-
-         fields_str = "\n".join(fields)
-         all_items = []
-         next_token = None
-
-         query_name = self._get_list_query_name(model_name)
-
-         if not value:
-             while True:
-                 query = f"""
-                     query List{model_name}s($limit: Int, $nextToken: String) {{
-                         {query_name}(limit: $limit, nextToken: $nextToken) {{
-                             items {{
-                                 {fields_str}
-                             }}
-                             nextToken
-                         }}
-                     }}
-                 """
-                 variables = {"limit": 1000, "nextToken": next_token}
-                 result = self._request(query, variables)
-
-                 if result and "data" in result:
-                     data = result["data"].get(query_name, {})
-                     items = data.get("items", [])
-                     all_items.extend(items)
-                     next_token = data.get("nextToken")
-
-                     if not next_token:
-                         break
-                 else:
-                     break
-         else:
-             while True:
-                 query = f"""
-                     query List{model_name}s($filter: Model{model_name}FilterInput, $limit: Int, $nextToken: String) {{
-                         {query_name}(filter: $filter, limit: $limit, nextToken: $nextToken) {{
-                             items {{
-                                 {fields_str}
-                             }}
-                             nextToken
-                         }}
-                     }}
-                 """
-                 filter_input = {field_name: {"eq": value}}
-                 variables = {"filter": filter_input, "limit": 1000, "nextToken": next_token}
-                 result = self._request(query, variables)
-
-                 if result and "data" in result:
-                     data = result["data"].get(query_name, {})
-                     items = data.get("items", [])
-                     all_items.extend(items)
-                     next_token = data.get("nextToken")
-
-                     if not next_token:
-                         break
-                 else:
-                     break
-
-         return all_items if all_items else None
-
-     def get_record_by_id(self, model_name: str, record_id: str, fields: list = None) -> Dict | None:
-         if fields is None:
-             fields = ["id"]
-
-         fields_str = "\n".join(fields)
-
-         query_name = f"get{model_name}"
-         query = f"""
-             query Get{model_name}($id: ID!) {{
-                 {query_name}(id: $id) {{
-                     {fields_str}
-                 }}
-             }}
-         """
-
-         result = self._request(query, {"id": record_id})
-
-         if result and "data" in result:
-             return result["data"].get(query_name)
-
-         return None
-
-     def get_records(
-         self,
-         model_name: str,
-         primary_field: str = None,
-         is_secondary_index: bool = None,
-         fields: list = None,
-     ) -> list | None:
-         if model_name in self.records_cache:
-             return self.records_cache[model_name]
-
-         if not primary_field:
-             return None
-         if is_secondary_index:
-             records = self.list_records_by_secondary_index(model_name, primary_field, fields=fields)
-         else:
-             records = self.list_records_by_field(model_name, primary_field, fields=fields)
-
-         if records:
-             self.records_cache[model_name] = records
-             logger.debug(f"💾 Cached {len(records)} records for {model_name}")
-         return records
-
-     def get_record(
-         self,
-         model_name: str,
-         parsed_model_structure: Dict[str, Any] = None,
-         value: str = None,
-         record_id: str = None,
-         primary_field: str = None,
-         is_secondary_index: bool = None,
-         fields: list = None,
-     ) -> Dict | None:
-         if record_id:
-             return self.get_record_by_id(model_name, record_id)
-
-         if not primary_field:
-             if not parsed_model_structure:
-                 logger.error("Parsed model structure required if primary_field not provided")
-                 return None
-             primary_field, is_secondary_index, _ = self.get_primary_field_name(model_name, parsed_model_structure)
-         records = self.get_records(model_name, primary_field, is_secondary_index, fields)
-         if not records:
-             return None
-         return next((record for record in records if record.get(primary_field) == value), None)
-
-     def build_foreign_key_lookups(self, df, parsed_model_structure: Dict[str, Any]) -> Dict[str, Dict[str, str]]:
-         """
-         Build a cache of foreign key lookups for all ID fields in the DataFrame.
-
-         This pre-fetches all related records to avoid N+1 query problems during row processing.
-
-         Args:
-             df: pandas DataFrame containing the data to be processed
-             parsed_model_structure: Parsed model structure containing field information
-
-         Returns:
-             Dictionary mapping model names to lookup dictionaries and primary fields
-         """
-
-         fk_lookup_cache = {}
-
-         for field in parsed_model_structure["fields"]:
-             if not field["is_id"]:
-                 continue
-
-             field_name = field["name"][:-2]
-
-             if field_name not in df.columns:
-                 continue
-
-             if "related_model" in field:
-                 related_model = field["related_model"]
-             else:
-                 related_model = field_name[0].upper() + field_name[1:]
-
-             if related_model in fk_lookup_cache:
-                 continue
-
-             try:
-                 primary_field, is_secondary_index, _ = self.get_primary_field_name(
-                     related_model, parsed_model_structure
-                 )
-                 records = self.get_records(related_model, primary_field, is_secondary_index)
-
-                 if records:
-                     lookup = {
-                         str(record.get(primary_field)): record.get("id")
-                         for record in records
-                         if record.get(primary_field)
-                     }
-                     fk_lookup_cache[related_model] = {"lookup": lookup, "primary_field": primary_field}
-                     logger.debug(f" 📦 Cached {len(lookup)} {related_model} records")
-             except Exception as e:
-                 logger.warning(f" ⚠️ Could not pre-fetch {related_model}: {e}")
-
-         return fk_lookup_cache