amplify-excel-migrator 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of amplify-excel-migrator might be problematic. Click here for more details.
- amplify_client.py +736 -0
- amplify_excel_migrator-1.0.0.dist-info/METADATA +204 -0
- amplify_excel_migrator-1.0.0.dist-info/RECORD +9 -0
- amplify_excel_migrator-1.0.0.dist-info/WHEEL +5 -0
- amplify_excel_migrator-1.0.0.dist-info/entry_points.txt +2 -0
- amplify_excel_migrator-1.0.0.dist-info/licenses/LICENSE +21 -0
- amplify_excel_migrator-1.0.0.dist-info/top_level.txt +3 -0
- migrator.py +301 -0
- model_field_parser.py +134 -0
amplify_client.py
ADDED
|
@@ -0,0 +1,736 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import logging
|
|
3
|
+
import sys
|
|
4
|
+
from getpass import getpass
|
|
5
|
+
from typing import Dict, Any
|
|
6
|
+
|
|
7
|
+
import aiohttp
|
|
8
|
+
import boto3
|
|
9
|
+
import requests
|
|
10
|
+
import jwt
|
|
11
|
+
from botocore.exceptions import NoCredentialsError, ProfileNotFound, NoRegionError, ClientError
|
|
12
|
+
from pycognito import Cognito, MFAChallengeException
|
|
13
|
+
from pycognito.exceptions import ForceChangePasswordException
|
|
14
|
+
|
|
15
|
+
# Configure root logging at import time (INFO level, timestamped format).
# basicConfig is a no-op if the root logger already has handlers.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Module logger; records propagate to the root logger configured above.
logger = logging.getLogger(name=__name__)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class AmplifyClient:
|
|
20
|
+
"""
|
|
21
|
+
Client for Amplify GraphQL using ADMIN_USER_PASSWORD_AUTH flow
|
|
22
|
+
"""
|
|
23
|
+
|
|
24
|
+
def __init__(self,
|
|
25
|
+
api_endpoint: str,
|
|
26
|
+
user_pool_id: str,
|
|
27
|
+
region: str,
|
|
28
|
+
client_id: str):
|
|
29
|
+
"""
|
|
30
|
+
Initialize the client
|
|
31
|
+
|
|
32
|
+
Args:
|
|
33
|
+
api_endpoint: Amplify GraphQL endpoint
|
|
34
|
+
user_pool_id: Cognito User Pool ID
|
|
35
|
+
region: AWS region
|
|
36
|
+
client_id: Cognito App Client ID
|
|
37
|
+
"""
|
|
38
|
+
|
|
39
|
+
self.api_endpoint = api_endpoint
|
|
40
|
+
self.user_pool_id = user_pool_id
|
|
41
|
+
self.region = region
|
|
42
|
+
self.client_id = client_id
|
|
43
|
+
|
|
44
|
+
self.batch_size = 10
|
|
45
|
+
self.cognito_client = None
|
|
46
|
+
self.boto_cognito_admin_client = None
|
|
47
|
+
self.id_token = None
|
|
48
|
+
self.mfa_tokens = None
|
|
49
|
+
self.admin_group_name = 'ADMINS'
|
|
50
|
+
|
|
51
|
+
self.records_cache = {}
|
|
52
|
+
|
|
53
|
+
def init_cognito_client(self, is_aws_admin: bool, username: str = None, aws_profile: str = None):
|
|
54
|
+
try:
|
|
55
|
+
if is_aws_admin:
|
|
56
|
+
if aws_profile:
|
|
57
|
+
session = boto3.Session(profile_name=aws_profile)
|
|
58
|
+
self.boto_cognito_admin_client = session.client('cognito-idp', region_name=self.region)
|
|
59
|
+
else:
|
|
60
|
+
# Use default AWS credentials (from ~/.aws/credentials, env vars, or IAM role)
|
|
61
|
+
self.boto_cognito_admin_client = boto3.client('cognito-idp', region_name=self.region)
|
|
62
|
+
|
|
63
|
+
else:
|
|
64
|
+
self.cognito_client = Cognito(
|
|
65
|
+
user_pool_id=self.user_pool_id,
|
|
66
|
+
client_id=self.client_id,
|
|
67
|
+
user_pool_region=self.region,
|
|
68
|
+
username=username
|
|
69
|
+
)
|
|
70
|
+
|
|
71
|
+
except NoCredentialsError:
|
|
72
|
+
logger.error("AWS credentials not found. Please configure AWS credentials.")
|
|
73
|
+
logger.error("Options: 1) AWS CLI: 'aws configure', 2) Environment variables, 3) Pass credentials directly")
|
|
74
|
+
raise RuntimeError(
|
|
75
|
+
"Failed to initialize client: No AWS credentials found. "
|
|
76
|
+
"Run 'aws configure' or set AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY environment variables."
|
|
77
|
+
)
|
|
78
|
+
|
|
79
|
+
except ProfileNotFound:
|
|
80
|
+
logger.error(f"AWS profile '{aws_profile}' not found")
|
|
81
|
+
raise RuntimeError(
|
|
82
|
+
f"Failed to initialize client: AWS profile '{aws_profile}' not found. "
|
|
83
|
+
f"Available profiles can be found in ~/.aws/credentials"
|
|
84
|
+
)
|
|
85
|
+
|
|
86
|
+
except NoRegionError:
|
|
87
|
+
logger.error("No AWS region specified")
|
|
88
|
+
raise RuntimeError(
|
|
89
|
+
f"Failed to initialize client: No AWS region specified. "
|
|
90
|
+
f"Provide region parameter or set AWS_DEFAULT_REGION environment variable."
|
|
91
|
+
)
|
|
92
|
+
|
|
93
|
+
except ValueError as e:
|
|
94
|
+
logger.error(f"Invalid parameter: {e}")
|
|
95
|
+
raise
|
|
96
|
+
|
|
97
|
+
except ClientError as e:
|
|
98
|
+
error_code = e.response.get('Error', {}).get('Code', 'Unknown')
|
|
99
|
+
error_msg = e.response.get('Error', {}).get('Message', str(e))
|
|
100
|
+
logger.error(f"AWS Client Error [{error_code}]: {error_msg}")
|
|
101
|
+
raise RuntimeError(f"Failed to initialize client: AWS error [{error_code}]: {error_msg}")
|
|
102
|
+
|
|
103
|
+
except Exception as e:
|
|
104
|
+
logger.error(f"Error during client initialization: {e}")
|
|
105
|
+
raise RuntimeError(f"Failed to initialize client: {e}")
|
|
106
|
+
|
|
107
|
+
def authenticate(self, username: str, password: str, mfa_code: str = None) -> bool:
|
|
108
|
+
try:
|
|
109
|
+
if not self.cognito_client:
|
|
110
|
+
self.init_cognito_client(is_aws_admin=False, username=username)
|
|
111
|
+
|
|
112
|
+
if mfa_code and self.mfa_tokens:
|
|
113
|
+
if not self._complete_mfa_challenge(mfa_code):
|
|
114
|
+
return False
|
|
115
|
+
else:
|
|
116
|
+
self.cognito_client.authenticate(password=password)
|
|
117
|
+
|
|
118
|
+
self.id_token = self.cognito_client.id_token
|
|
119
|
+
|
|
120
|
+
self._check_user_in_admins_group(self.id_token)
|
|
121
|
+
|
|
122
|
+
logger.info("✅ Authentication successful")
|
|
123
|
+
return True
|
|
124
|
+
|
|
125
|
+
except MFAChallengeException as e:
|
|
126
|
+
logger.warning("MFA required")
|
|
127
|
+
if hasattr(e, 'get_tokens'):
|
|
128
|
+
self.mfa_tokens = e.get_tokens()
|
|
129
|
+
|
|
130
|
+
mfa_code = input("Enter MFA code: ").strip()
|
|
131
|
+
if mfa_code:
|
|
132
|
+
return self.authenticate(username, password, mfa_code)
|
|
133
|
+
else:
|
|
134
|
+
logger.error("MFA code required but not provided")
|
|
135
|
+
return False
|
|
136
|
+
else:
|
|
137
|
+
logger.error("MFA challenge received but no session tokens available")
|
|
138
|
+
return False
|
|
139
|
+
|
|
140
|
+
except ForceChangePasswordException:
|
|
141
|
+
logger.warning("Password change required")
|
|
142
|
+
new_password = input("Your password has expired. Enter new password: ").strip()
|
|
143
|
+
confirm_password = input("Confirm new password: ").strip()
|
|
144
|
+
if new_password != confirm_password:
|
|
145
|
+
logger.error("Passwords do not match")
|
|
146
|
+
return False
|
|
147
|
+
|
|
148
|
+
try:
|
|
149
|
+
self.cognito_client.new_password_challenge(password, new_password)
|
|
150
|
+
return self.authenticate(username, new_password)
|
|
151
|
+
|
|
152
|
+
except Exception as e:
|
|
153
|
+
logger.error(f"Failed to change password: {e}")
|
|
154
|
+
return False
|
|
155
|
+
|
|
156
|
+
except Exception as e:
|
|
157
|
+
logger.error(f"Authentication failed: {e}")
|
|
158
|
+
return False
|
|
159
|
+
|
|
160
|
+
def aws_admin_authenticate(self, username: str, password: str) -> bool:
|
|
161
|
+
"""
|
|
162
|
+
Requires AWS credentials with cognito-idp:ListUserPoolClients permission
|
|
163
|
+
"""
|
|
164
|
+
try:
|
|
165
|
+
if not self.boto_cognito_admin_client:
|
|
166
|
+
self.init_cognito_client(is_aws_admin=True)
|
|
167
|
+
|
|
168
|
+
print(f"Authenticating {username} using ADMIN_USER_PASSWORD_AUTH flow...")
|
|
169
|
+
|
|
170
|
+
response = self.boto_cognito_admin_client.admin_initiate_auth(
|
|
171
|
+
UserPoolId=self.user_pool_id,
|
|
172
|
+
ClientId=self.client_id,
|
|
173
|
+
AuthFlow='ADMIN_USER_PASSWORD_AUTH',
|
|
174
|
+
AuthParameters={
|
|
175
|
+
'USERNAME': username,
|
|
176
|
+
'PASSWORD': password
|
|
177
|
+
}
|
|
178
|
+
)
|
|
179
|
+
|
|
180
|
+
self._check_for_mfa_challenges(response, username)
|
|
181
|
+
|
|
182
|
+
if 'AuthenticationResult' in response:
|
|
183
|
+
self.id_token = response['AuthenticationResult']['IdToken']
|
|
184
|
+
else:
|
|
185
|
+
logger.error("❌ Authentication failed: No AuthenticationResult in response")
|
|
186
|
+
return False
|
|
187
|
+
|
|
188
|
+
self._check_user_in_admins_group(self.id_token)
|
|
189
|
+
|
|
190
|
+
print(f"✅ Authentication successful")
|
|
191
|
+
return True
|
|
192
|
+
|
|
193
|
+
except self.cognito_client.exceptions.NotAuthorizedException as e:
|
|
194
|
+
logger.error(f"❌ Authentication failed: {e}")
|
|
195
|
+
return False
|
|
196
|
+
|
|
197
|
+
except self.cognito_client.exceptions.UserNotFoundException:
|
|
198
|
+
logger.error(f"❌ User not found: {username}")
|
|
199
|
+
return False
|
|
200
|
+
|
|
201
|
+
except Exception as e:
|
|
202
|
+
logger.error(f"❌ Error during authentication: {e}")
|
|
203
|
+
return False
|
|
204
|
+
|
|
205
|
+
def _complete_mfa_challenge(self, mfa_code: str) -> bool:
|
|
206
|
+
try:
|
|
207
|
+
if not self.mfa_tokens:
|
|
208
|
+
logger.error("No MFA session tokens available")
|
|
209
|
+
return False
|
|
210
|
+
|
|
211
|
+
challenge_name = self.mfa_tokens.get('ChallengeName', 'SMS_MFA')
|
|
212
|
+
|
|
213
|
+
if 'SOFTWARE_TOKEN' in challenge_name:
|
|
214
|
+
self.cognito_client.respond_to_software_token_mfa_challenge(
|
|
215
|
+
code=mfa_code,
|
|
216
|
+
mfa_tokens=self.mfa_tokens
|
|
217
|
+
)
|
|
218
|
+
else:
|
|
219
|
+
self.cognito_client.respond_to_sms_mfa_challenge(
|
|
220
|
+
code=mfa_code,
|
|
221
|
+
mfa_tokens=self.mfa_tokens
|
|
222
|
+
)
|
|
223
|
+
|
|
224
|
+
logger.info("✅ MFA challenge successful")
|
|
225
|
+
return True
|
|
226
|
+
|
|
227
|
+
except Exception as e:
|
|
228
|
+
logger.error(f"MFA challenge failed: {e}")
|
|
229
|
+
return False
|
|
230
|
+
|
|
231
|
+
def _get_client_id(self) -> str:
|
|
232
|
+
if self.client_id:
|
|
233
|
+
return self.client_id
|
|
234
|
+
|
|
235
|
+
try:
|
|
236
|
+
if not self.boto_cognito_admin_client:
|
|
237
|
+
self.boto_cognito_admin_client(is_aws_admin=True)
|
|
238
|
+
response = self.boto_cognito_admin_client.list_user_pool_clients(
|
|
239
|
+
UserPoolId=self.user_pool_id,
|
|
240
|
+
MaxResults=1
|
|
241
|
+
)
|
|
242
|
+
|
|
243
|
+
if response['UserPoolClients']:
|
|
244
|
+
client_id = response['UserPoolClients'][0]['ClientId']
|
|
245
|
+
return client_id
|
|
246
|
+
|
|
247
|
+
raise Exception("No User Pool clients found")
|
|
248
|
+
|
|
249
|
+
except self.cognito_client.exceptions.ResourceNotFoundException:
|
|
250
|
+
raise Exception(f"User Pool not found or AWS credentials lack permission")
|
|
251
|
+
except Exception as e:
|
|
252
|
+
raise Exception(f"Failed to get Client ID: {e}")
|
|
253
|
+
|
|
254
|
+
def _check_for_mfa_challenges(self, response, username: str) -> bool:
|
|
255
|
+
if 'ChallengeName' in response:
|
|
256
|
+
challenge = response['ChallengeName']
|
|
257
|
+
|
|
258
|
+
if challenge == 'MFA_SETUP':
|
|
259
|
+
logger.error("MFA setup required")
|
|
260
|
+
return False
|
|
261
|
+
|
|
262
|
+
elif challenge == 'SMS_MFA' or challenge == 'SOFTWARE_TOKEN_MFA':
|
|
263
|
+
mfa_code = input("Enter MFA code: ")
|
|
264
|
+
_ = self.cognito_client.admin_respond_to_auth_challenge(
|
|
265
|
+
UserPoolId=self.user_pool_id,
|
|
266
|
+
ClientId=self.client_id,
|
|
267
|
+
ChallengeName=challenge,
|
|
268
|
+
Session=response['Session'],
|
|
269
|
+
ChallengeResponses={
|
|
270
|
+
'USERNAME': username,
|
|
271
|
+
'SMS_MFA_CODE' if challenge == 'SMS_MFA' else 'SOFTWARE_TOKEN_MFA_CODE': mfa_code
|
|
272
|
+
}
|
|
273
|
+
)
|
|
274
|
+
|
|
275
|
+
elif challenge == 'NEW_PASSWORD_REQUIRED':
|
|
276
|
+
new_password = getpass("Enter new password: ")
|
|
277
|
+
_ = self.cognito_client.admin_respond_to_auth_challenge(
|
|
278
|
+
UserPoolId=self.user_pool_id,
|
|
279
|
+
ClientId=self.client_id,
|
|
280
|
+
ChallengeName=challenge,
|
|
281
|
+
Session=response['Session'],
|
|
282
|
+
ChallengeResponses={
|
|
283
|
+
'USERNAME': username,
|
|
284
|
+
'NEW_PASSWORD': new_password
|
|
285
|
+
}
|
|
286
|
+
)
|
|
287
|
+
|
|
288
|
+
return False
|
|
289
|
+
|
|
290
|
+
def _check_user_in_admins_group(self, id_token: str):
|
|
291
|
+
print(jwt.__version__)
|
|
292
|
+
|
|
293
|
+
claims = jwt.decode(id_token, options={"verify_signature": False})
|
|
294
|
+
groups = claims.get("cognito:groups", [])
|
|
295
|
+
|
|
296
|
+
if self.admin_group_name not in groups:
|
|
297
|
+
raise PermissionError("User is not in ADMINS group")
|
|
298
|
+
|
|
299
|
+
def _request(self, query: str, variables: Dict = None) -> Any | None:
|
|
300
|
+
"""
|
|
301
|
+
Make a GraphQL request using the ID token
|
|
302
|
+
|
|
303
|
+
Args:
|
|
304
|
+
query: GraphQL query or mutation
|
|
305
|
+
variables: Variables for the query
|
|
306
|
+
|
|
307
|
+
Returns:
|
|
308
|
+
Response data
|
|
309
|
+
"""
|
|
310
|
+
if not self.id_token:
|
|
311
|
+
raise Exception("Not authenticated. Call authenticate() first.")
|
|
312
|
+
|
|
313
|
+
headers = {
|
|
314
|
+
'Authorization': self.id_token,
|
|
315
|
+
'Content-Type': 'application/json'
|
|
316
|
+
}
|
|
317
|
+
|
|
318
|
+
payload = {
|
|
319
|
+
'query': query,
|
|
320
|
+
'variables': variables or {}
|
|
321
|
+
}
|
|
322
|
+
|
|
323
|
+
try:
|
|
324
|
+
response = requests.post(
|
|
325
|
+
self.api_endpoint,
|
|
326
|
+
headers=headers,
|
|
327
|
+
json=payload
|
|
328
|
+
)
|
|
329
|
+
|
|
330
|
+
if response.status_code == 200:
|
|
331
|
+
result = response.json()
|
|
332
|
+
|
|
333
|
+
if 'errors' in result:
|
|
334
|
+
logger.error(f"GraphQL errors: {result['errors']}")
|
|
335
|
+
return None
|
|
336
|
+
|
|
337
|
+
return result
|
|
338
|
+
else:
|
|
339
|
+
logger.error(f"HTTP Error {response.status_code}: {response.text}")
|
|
340
|
+
return None
|
|
341
|
+
|
|
342
|
+
except Exception as e:
|
|
343
|
+
if 'NameResolutionError' in str(e):
|
|
344
|
+
logger.error(
|
|
345
|
+
f"Connection error: Unable to resolve hostname. Check your internet connection or the API endpoint URL.")
|
|
346
|
+
sys.exit(1)
|
|
347
|
+
else:
|
|
348
|
+
logger.error(f"Request error: {e}")
|
|
349
|
+
return None
|
|
350
|
+
|
|
351
|
+
async def _request_async(self, session: aiohttp.ClientSession, query: str, variables: Dict = None) -> Any | None:
|
|
352
|
+
"""
|
|
353
|
+
Async version of _request for parallel GraphQL requests
|
|
354
|
+
|
|
355
|
+
Args:
|
|
356
|
+
session: aiohttp ClientSession
|
|
357
|
+
query: GraphQL query or mutation
|
|
358
|
+
variables: Variables for the query
|
|
359
|
+
|
|
360
|
+
Returns:
|
|
361
|
+
Response data
|
|
362
|
+
"""
|
|
363
|
+
if not self.id_token:
|
|
364
|
+
raise Exception("Not authenticated. Call authenticate() first.")
|
|
365
|
+
|
|
366
|
+
headers = {
|
|
367
|
+
'Authorization': self.id_token,
|
|
368
|
+
'Content-Type': 'application/json'
|
|
369
|
+
}
|
|
370
|
+
|
|
371
|
+
payload = {
|
|
372
|
+
'query': query,
|
|
373
|
+
'variables': variables or {}
|
|
374
|
+
}
|
|
375
|
+
|
|
376
|
+
try:
|
|
377
|
+
async with session.post(self.api_endpoint, headers=headers, json=payload) as response:
|
|
378
|
+
if response.status == 200:
|
|
379
|
+
result = await response.json()
|
|
380
|
+
|
|
381
|
+
if 'errors' in result:
|
|
382
|
+
logger.error(f"GraphQL errors: {result['errors']}")
|
|
383
|
+
return None
|
|
384
|
+
|
|
385
|
+
return result
|
|
386
|
+
else:
|
|
387
|
+
text = await response.text()
|
|
388
|
+
logger.error(f"HTTP Error {response.status}: {text}")
|
|
389
|
+
return None
|
|
390
|
+
except Exception as e:
|
|
391
|
+
logger.error(f"Request error: {e}")
|
|
392
|
+
return None
|
|
393
|
+
|
|
394
|
+
async def create_record_async(self, session: aiohttp.ClientSession, data: Dict, model_name: str,
|
|
395
|
+
primary_field: str) -> Dict | None:
|
|
396
|
+
mutation = f"""
|
|
397
|
+
mutation Create{model_name}($input: Create{model_name}Input!) {{
|
|
398
|
+
create{model_name}(input: $input) {{
|
|
399
|
+
id
|
|
400
|
+
{primary_field}
|
|
401
|
+
}}
|
|
402
|
+
}}
|
|
403
|
+
"""
|
|
404
|
+
|
|
405
|
+
result = await self._request_async(session, mutation, {'input': data})
|
|
406
|
+
|
|
407
|
+
if result and 'data' in result:
|
|
408
|
+
created = result['data'].get(f'create{model_name}')
|
|
409
|
+
if created:
|
|
410
|
+
logger.info(f'Created {model_name} with {primary_field}="{data[primary_field]}" (ID: {created["id"]})')
|
|
411
|
+
return created
|
|
412
|
+
|
|
413
|
+
return None
|
|
414
|
+
|
|
415
|
+
async def check_record_exists_async(self, session: aiohttp.ClientSession, model_name: str,
|
|
416
|
+
primary_field: str, value: str, is_secondary_index: bool,
|
|
417
|
+
record: Dict) -> Dict | None:
|
|
418
|
+
if is_secondary_index:
|
|
419
|
+
query_name = f"list{model_name}By{primary_field.capitalize()}"
|
|
420
|
+
query = f"""
|
|
421
|
+
query {query_name}(${primary_field}: String!) {{
|
|
422
|
+
{query_name}({primary_field}: ${primary_field}) {{
|
|
423
|
+
items {{
|
|
424
|
+
id
|
|
425
|
+
}}
|
|
426
|
+
}}
|
|
427
|
+
}}
|
|
428
|
+
"""
|
|
429
|
+
result = await self._request_async(session, query, {primary_field: value})
|
|
430
|
+
if result and 'data' in result:
|
|
431
|
+
items = result['data'].get(query_name, {}).get('items', [])
|
|
432
|
+
if len(items) > 0:
|
|
433
|
+
logger.error(
|
|
434
|
+
f'Record with {primary_field}="{value}" already exists in {model_name}')
|
|
435
|
+
return None
|
|
436
|
+
else:
|
|
437
|
+
query_name = self._get_list_query_name(model_name)
|
|
438
|
+
query = f"""
|
|
439
|
+
query List{model_name}s($filter: Model{model_name}FilterInput) {{
|
|
440
|
+
{query_name}(filter: $filter) {{
|
|
441
|
+
items {{
|
|
442
|
+
id
|
|
443
|
+
}}
|
|
444
|
+
}}
|
|
445
|
+
}}
|
|
446
|
+
"""
|
|
447
|
+
filter_input = {primary_field: {"eq": value}}
|
|
448
|
+
result = await self._request_async(session, query, {"filter": filter_input})
|
|
449
|
+
if result and 'data' in result:
|
|
450
|
+
items = result['data'].get(query_name, {}).get('items', [])
|
|
451
|
+
if len(items) > 0:
|
|
452
|
+
logger.error(
|
|
453
|
+
f'Record with {primary_field}="{value}" already exists in {model_name}')
|
|
454
|
+
return None
|
|
455
|
+
|
|
456
|
+
return record
|
|
457
|
+
|
|
458
|
+
async def upload_batch_async(self, batch: list, model_name: str, primary_field: str,
|
|
459
|
+
is_secondary_index: bool) -> tuple[int, int]:
|
|
460
|
+
async with aiohttp.ClientSession() as session:
|
|
461
|
+
duplicate_checks = [
|
|
462
|
+
self.check_record_exists_async(session, model_name, primary_field,
|
|
463
|
+
record[primary_field], is_secondary_index, record)
|
|
464
|
+
for record in batch
|
|
465
|
+
]
|
|
466
|
+
check_results = await asyncio.gather(*duplicate_checks, return_exceptions=True)
|
|
467
|
+
|
|
468
|
+
filtered_batch = []
|
|
469
|
+
for result in check_results:
|
|
470
|
+
if isinstance(result, Exception):
|
|
471
|
+
logger.error(f"Error checking duplicate: {result}")
|
|
472
|
+
elif result is not None:
|
|
473
|
+
filtered_batch.append(result)
|
|
474
|
+
|
|
475
|
+
if not filtered_batch:
|
|
476
|
+
return 0, len(batch)
|
|
477
|
+
|
|
478
|
+
create_tasks = [
|
|
479
|
+
self.create_record_async(session, record, model_name, primary_field)
|
|
480
|
+
for record in filtered_batch
|
|
481
|
+
]
|
|
482
|
+
results = await asyncio.gather(*create_tasks, return_exceptions=True)
|
|
483
|
+
|
|
484
|
+
success_count = sum(1 for r in results if r and not isinstance(r, Exception))
|
|
485
|
+
error_count = len(batch) - success_count
|
|
486
|
+
|
|
487
|
+
return success_count, error_count
|
|
488
|
+
|
|
489
|
+
def get_model_structure(self, model_type: str) -> Dict:
|
|
490
|
+
query = f"""
|
|
491
|
+
query GetModelType {{
|
|
492
|
+
__type(name: "{model_type}") {{
|
|
493
|
+
name
|
|
494
|
+
kind
|
|
495
|
+
description
|
|
496
|
+
fields {{
|
|
497
|
+
name
|
|
498
|
+
type {{
|
|
499
|
+
name
|
|
500
|
+
kind
|
|
501
|
+
ofType {{
|
|
502
|
+
name
|
|
503
|
+
kind
|
|
504
|
+
ofType {{
|
|
505
|
+
name
|
|
506
|
+
kind
|
|
507
|
+
}}
|
|
508
|
+
}}
|
|
509
|
+
}}
|
|
510
|
+
description
|
|
511
|
+
}}
|
|
512
|
+
}}
|
|
513
|
+
}}
|
|
514
|
+
"""
|
|
515
|
+
|
|
516
|
+
response = self._request(query)
|
|
517
|
+
if response and 'data' in response and '__type' in response['data']:
|
|
518
|
+
return response['data']['__type']
|
|
519
|
+
|
|
520
|
+
return {}
|
|
521
|
+
|
|
522
|
+
def get_primary_field_name(self, model_name: str, parsed_model_structure: Dict[str, Any]) -> (
|
|
523
|
+
tuple[str, bool]):
|
|
524
|
+
secondary_index = self._get_secondary_index(model_name)
|
|
525
|
+
if secondary_index:
|
|
526
|
+
return secondary_index, True
|
|
527
|
+
|
|
528
|
+
for field in parsed_model_structure['fields']:
|
|
529
|
+
if field['is_required'] and field['is_scalar'] and field['name'] != 'id':
|
|
530
|
+
return field['name'], False
|
|
531
|
+
|
|
532
|
+
logger.error('No suitable primary field found (required scalar field other than id)')
|
|
533
|
+
return '', False
|
|
534
|
+
|
|
535
|
+
def _get_secondary_index(self, model_name: str) -> str:
|
|
536
|
+
query_structure = self.get_model_structure("Query")
|
|
537
|
+
if not query_structure:
|
|
538
|
+
logger.error("Query type not found in schema")
|
|
539
|
+
return ''
|
|
540
|
+
|
|
541
|
+
query_fields = query_structure['fields']
|
|
542
|
+
|
|
543
|
+
pattern = f"{model_name}By"
|
|
544
|
+
|
|
545
|
+
for query in query_fields:
|
|
546
|
+
query_name = query['name']
|
|
547
|
+
if pattern in query_name:
|
|
548
|
+
pattern_index = query_name.index(pattern)
|
|
549
|
+
field_name = query_name[pattern_index + len(pattern):]
|
|
550
|
+
return field_name[0].lower() + field_name[1:] if field_name else ''
|
|
551
|
+
|
|
552
|
+
return ''
|
|
553
|
+
|
|
554
|
+
def _get_list_query_name(self, model_name: str) -> str | None:
|
|
555
|
+
"""Get the correct list query name from the schema (handles pluralization)"""
|
|
556
|
+
query_structure = self.get_model_structure("Query")
|
|
557
|
+
if not query_structure:
|
|
558
|
+
logger.error("Query type not found in schema")
|
|
559
|
+
return f"list{model_name}s"
|
|
560
|
+
|
|
561
|
+
query_fields = query_structure['fields']
|
|
562
|
+
candidates = [
|
|
563
|
+
f"list{model_name}s",
|
|
564
|
+
f"list{model_name}es",
|
|
565
|
+
f"list{model_name[:-1]}ies",
|
|
566
|
+
]
|
|
567
|
+
|
|
568
|
+
for query in query_fields:
|
|
569
|
+
query_name = query['name']
|
|
570
|
+
if query_name in candidates and 'By' not in query_name:
|
|
571
|
+
return query_name
|
|
572
|
+
|
|
573
|
+
logger.error(f"No list query found for model {model_name}, tried: {candidates}")
|
|
574
|
+
return None
|
|
575
|
+
|
|
576
|
+
def upload(self, records: list, model_name: str, parsed_model_structure: Dict[str, Any]) -> tuple[int, int]:
|
|
577
|
+
logger.info("Uploading to Amplify backend...")
|
|
578
|
+
|
|
579
|
+
success_count = 0
|
|
580
|
+
error_count = 0
|
|
581
|
+
|
|
582
|
+
primary_field, is_secondary_index = self.get_primary_field_name(model_name, parsed_model_structure)
|
|
583
|
+
if not primary_field:
|
|
584
|
+
logger.error(f"Aborting upload for model {model_name}")
|
|
585
|
+
return 0, len(records)
|
|
586
|
+
|
|
587
|
+
for i in range(0, len(records), self.batch_size):
|
|
588
|
+
batch = records[i:i + self.batch_size]
|
|
589
|
+
logger.info(f"Uploading batch {i // self.batch_size + 1} ({len(batch)} items)...")
|
|
590
|
+
|
|
591
|
+
batch_success, batch_error = asyncio.run(
|
|
592
|
+
self.upload_batch_async(batch, model_name, primary_field, is_secondary_index)
|
|
593
|
+
)
|
|
594
|
+
success_count += batch_success
|
|
595
|
+
error_count += batch_error
|
|
596
|
+
|
|
597
|
+
logger.info(
|
|
598
|
+
f"Processed batch {i // self.batch_size + 1} of model {model_name}: {success_count} success, {error_count} errors")
|
|
599
|
+
|
|
600
|
+
return success_count, error_count
|
|
601
|
+
|
|
602
|
+
def list_records_by_secondary_index(self, model_name: str, secondary_index: str, value: str = None,
|
|
603
|
+
fields: list = None) -> Dict | None:
|
|
604
|
+
if fields is None:
|
|
605
|
+
fields = ['id', secondary_index]
|
|
606
|
+
|
|
607
|
+
fields_str = '\n'.join(fields)
|
|
608
|
+
|
|
609
|
+
if not value:
|
|
610
|
+
query_name = self._get_list_query_name(model_name)
|
|
611
|
+
query = f"""
|
|
612
|
+
query List{model_name}s {{
|
|
613
|
+
{query_name} {{
|
|
614
|
+
items {{
|
|
615
|
+
{fields_str}
|
|
616
|
+
}}
|
|
617
|
+
}}
|
|
618
|
+
}}
|
|
619
|
+
"""
|
|
620
|
+
result = self._request(query)
|
|
621
|
+
else:
|
|
622
|
+
query_name = f"list{model_name}By{secondary_index.capitalize()}"
|
|
623
|
+
query = f"""
|
|
624
|
+
query {query_name}(${secondary_index}: String!) {{
|
|
625
|
+
{query_name}({secondary_index}: ${secondary_index}) {{
|
|
626
|
+
items {{
|
|
627
|
+
{fields_str}
|
|
628
|
+
}}
|
|
629
|
+
}}
|
|
630
|
+
}}
|
|
631
|
+
"""
|
|
632
|
+
result = self._request(query, {secondary_index: value})
|
|
633
|
+
|
|
634
|
+
if result and 'data' in result:
|
|
635
|
+
items = result['data'].get(query_name, {}).get('items', [])
|
|
636
|
+
return items if items else None
|
|
637
|
+
|
|
638
|
+
return None
|
|
639
|
+
|
|
640
|
+
def get_record_by_id(self, model_name: str, record_id: str, fields: list = None) -> Dict | None:
|
|
641
|
+
if fields is None:
|
|
642
|
+
fields = ['id']
|
|
643
|
+
|
|
644
|
+
fields_str = '\n'.join(fields)
|
|
645
|
+
|
|
646
|
+
query_name = f"get{model_name}"
|
|
647
|
+
query = f"""
|
|
648
|
+
query Get{model_name}($id: ID!) {{
|
|
649
|
+
{query_name}(id: $id) {{
|
|
650
|
+
{fields_str}
|
|
651
|
+
}}
|
|
652
|
+
}}
|
|
653
|
+
"""
|
|
654
|
+
|
|
655
|
+
result = self._request(query, {"id": record_id})
|
|
656
|
+
|
|
657
|
+
if result and 'data' in result:
|
|
658
|
+
return result['data'].get(query_name)
|
|
659
|
+
|
|
660
|
+
return None
|
|
661
|
+
|
|
662
|
+
def get_records_by_field(self, model_name: str, field_name: str, value: str = None,
|
|
663
|
+
fields: list = None) -> Dict | None:
|
|
664
|
+
if fields is None:
|
|
665
|
+
fields = ['id', field_name]
|
|
666
|
+
|
|
667
|
+
fields_str = '\n'.join(fields)
|
|
668
|
+
|
|
669
|
+
query_name = self._get_list_query_name(model_name)
|
|
670
|
+
|
|
671
|
+
if not value:
|
|
672
|
+
query = f"""
|
|
673
|
+
query List{model_name}s {{
|
|
674
|
+
{query_name} {{
|
|
675
|
+
items {{
|
|
676
|
+
{fields_str}
|
|
677
|
+
}}
|
|
678
|
+
}}
|
|
679
|
+
}}
|
|
680
|
+
"""
|
|
681
|
+
result = self._request(query)
|
|
682
|
+
else:
|
|
683
|
+
query = f"""
|
|
684
|
+
query List{model_name}s($filter: Model{model_name}FilterInput) {{
|
|
685
|
+
{query_name}(filter: $filter) {{
|
|
686
|
+
items {{
|
|
687
|
+
{fields_str}
|
|
688
|
+
}}
|
|
689
|
+
}}
|
|
690
|
+
}}
|
|
691
|
+
"""
|
|
692
|
+
filter_input = {
|
|
693
|
+
field_name: {
|
|
694
|
+
"eq": value
|
|
695
|
+
}
|
|
696
|
+
}
|
|
697
|
+
result = self._request(query, {"filter": filter_input})
|
|
698
|
+
|
|
699
|
+
if result and 'data' in result:
|
|
700
|
+
items = result['data'].get(query_name, {}).get('items', [])
|
|
701
|
+
return items if items else None
|
|
702
|
+
|
|
703
|
+
return None
|
|
704
|
+
|
|
705
|
+
def get_records(self, model_name: str, parsed_model_structure: Dict[str, Any] = None, primary_field: str = None,
|
|
706
|
+
is_secondary_index: bool = None, fields: list = None) -> list | None:
|
|
707
|
+
if model_name in self.records_cache:
|
|
708
|
+
return self.records_cache[model_name]
|
|
709
|
+
|
|
710
|
+
if not primary_field:
|
|
711
|
+
if not parsed_model_structure:
|
|
712
|
+
logger.error("Parsed model structure required if primary_field not provided")
|
|
713
|
+
return None
|
|
714
|
+
primary_field, is_secondary_index = self.get_primary_field_name(model_name, parsed_model_structure)
|
|
715
|
+
|
|
716
|
+
if not primary_field:
|
|
717
|
+
return None
|
|
718
|
+
if is_secondary_index:
|
|
719
|
+
records = self.list_records_by_secondary_index(model_name, primary_field, fields=fields)
|
|
720
|
+
else:
|
|
721
|
+
records = self.get_records_by_field(model_name, primary_field, fields=fields)
|
|
722
|
+
|
|
723
|
+
if records:
|
|
724
|
+
self.records_cache[model_name] = records
|
|
725
|
+
return records
|
|
726
|
+
|
|
727
|
+
def get_record(self, model_name: str, parsed_model_structure: Dict[str, Any] = None, value: str = None,
|
|
728
|
+
record_id: str = None, primary_field: str = None, is_secondary_index: bool = None,
|
|
729
|
+
fields: list = None) -> Dict | None:
|
|
730
|
+
if record_id:
|
|
731
|
+
return self.get_record_by_id(model_name, record_id)
|
|
732
|
+
|
|
733
|
+
records = self.get_records(model_name, parsed_model_structure, primary_field, is_secondary_index, fields)
|
|
734
|
+
if not records:
|
|
735
|
+
return None
|
|
736
|
+
return next((record for record in records if record.get(primary_field) == value), None)
|