trainml 0.5.16__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53) hide show
  1. examples/local_storage.py +0 -2
  2. tests/integration/test_checkpoints_integration.py +4 -3
  3. tests/integration/test_datasets_integration.py +5 -3
  4. tests/integration/test_jobs_integration.py +33 -27
  5. tests/integration/test_models_integration.py +7 -3
  6. tests/integration/test_volumes_integration.py +2 -2
  7. tests/unit/cli/test_cli_checkpoint_unit.py +312 -1
  8. tests/unit/cloudbender/test_nodes_unit.py +112 -0
  9. tests/unit/cloudbender/test_providers_unit.py +96 -0
  10. tests/unit/cloudbender/test_regions_unit.py +106 -0
  11. tests/unit/cloudbender/test_services_unit.py +141 -0
  12. tests/unit/conftest.py +23 -10
  13. tests/unit/projects/test_project_data_connectors_unit.py +39 -0
  14. tests/unit/projects/test_project_datastores_unit.py +37 -0
  15. tests/unit/projects/test_project_members_unit.py +46 -0
  16. tests/unit/projects/test_project_services_unit.py +65 -0
  17. tests/unit/projects/test_projects_unit.py +17 -1
  18. tests/unit/test_auth_unit.py +17 -2
  19. tests/unit/test_checkpoints_unit.py +256 -71
  20. tests/unit/test_datasets_unit.py +218 -68
  21. tests/unit/test_exceptions.py +133 -0
  22. tests/unit/test_gpu_types_unit.py +11 -1
  23. tests/unit/test_jobs_unit.py +1014 -95
  24. tests/unit/test_main_unit.py +20 -0
  25. tests/unit/test_models_unit.py +218 -70
  26. tests/unit/test_trainml_unit.py +627 -3
  27. tests/unit/test_volumes_unit.py +211 -70
  28. tests/unit/utils/__init__.py +1 -0
  29. tests/unit/utils/test_transfer_unit.py +4260 -0
  30. trainml/__init__.py +1 -1
  31. trainml/checkpoints.py +56 -57
  32. trainml/cli/__init__.py +6 -3
  33. trainml/cli/checkpoint.py +18 -57
  34. trainml/cli/dataset.py +17 -57
  35. trainml/cli/job/__init__.py +11 -53
  36. trainml/cli/job/create.py +51 -24
  37. trainml/cli/model.py +14 -56
  38. trainml/cli/volume.py +18 -57
  39. trainml/datasets.py +50 -55
  40. trainml/jobs.py +239 -68
  41. trainml/models.py +51 -55
  42. trainml/projects/projects.py +2 -2
  43. trainml/trainml.py +50 -16
  44. trainml/utils/__init__.py +1 -0
  45. trainml/utils/auth.py +641 -0
  46. trainml/utils/transfer.py +587 -0
  47. trainml/volumes.py +48 -53
  48. {trainml-0.5.16.dist-info → trainml-1.0.0.dist-info}/METADATA +3 -3
  49. {trainml-0.5.16.dist-info → trainml-1.0.0.dist-info}/RECORD +53 -47
  50. {trainml-0.5.16.dist-info → trainml-1.0.0.dist-info}/LICENSE +0 -0
  51. {trainml-0.5.16.dist-info → trainml-1.0.0.dist-info}/WHEEL +0 -0
  52. {trainml-0.5.16.dist-info → trainml-1.0.0.dist-info}/entry_points.txt +0 -0
  53. {trainml-0.5.16.dist-info → trainml-1.0.0.dist-info}/top_level.txt +0 -0
trainml/utils/auth.py ADDED
@@ -0,0 +1,641 @@
1
+ ##
2
+ ## Source: https://github.com/capless/warrant.git
3
+ ## https://github.com/capless/warrant/blob/master/warrant/aws_srp.py
4
+ ##
5
+ # Apache License
6
+ # Version 2.0, January 2004
7
+ # http://www.apache.org/licenses/
8
+
9
+ # TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
10
+
11
+ # 1. Definitions.
12
+
13
+ # "License" shall mean the terms and conditions for use, reproduction,
14
+ # and distribution as defined by Sections 1 through 9 of this document.
15
+
16
+ # "Licensor" shall mean the copyright owner or entity authorized by
17
+ # the copyright owner that is granting the License.
18
+
19
+ # "Legal Entity" shall mean the union of the acting entity and all
20
+ # other entities that control, are controlled by, or are under common
21
+ # control with that entity. For the purposes of this definition,
22
+ # "control" means (i) the power, direct or indirect, to cause the
23
+ # direction or management of such entity, whether by contract or
24
+ # otherwise, or (ii) ownership of fifty percent (50%) or more of the
25
+ # outstanding shares, or (iii) beneficial ownership of such entity.
26
+
27
+ # "You" (or "Your") shall mean an individual or Legal Entity
28
+ # exercising permissions granted by this License.
29
+
30
+ # "Source" form shall mean the preferred form for making modifications,
31
+ # including but not limited to software source code, documentation
32
+ # source, and configuration files.
33
+
34
+ # "Object" form shall mean any form resulting from mechanical
35
+ # transformation or translation of a Source form, including but
36
+ # not limited to compiled object code, generated documentation,
37
+ # and conversions to other media types.
38
+
39
+ # "Work" shall mean the work of authorship, whether in Source or
40
+ # Object form, made available under the License, as indicated by a
41
+ # copyright notice that is included in or attached to the work
42
+ # (an example is provided in the Appendix below).
43
+
44
+ # "Derivative Works" shall mean any work, whether in Source or Object
45
+ # form, that is based on (or derived from) the Work and for which the
46
+ # editorial revisions, annotations, elaborations, or other modifications
47
+ # represent, as a whole, an original work of authorship. For the purposes
48
+ # of this License, Derivative Works shall not include works that remain
49
+ # separable from, or merely link (or bind by name) to the interfaces of
50
+ # the Work and Derivative Works thereof.
51
+
52
+ # "Contribution" shall mean any work of authorship, including
53
+ # the original version of the Work and any modifications or additions
54
+ # to that Work or Derivative Works thereof, that is intentionally
55
+ # submitted to Licensor for inclusion in the Work by the copyright owner
56
+ # or by an individual or Legal Entity authorized to submit on behalf of
57
+ # the copyright owner. For the purposes of this definition, "submitted"
58
+ # means any form of electronic, verbal, or written communication sent
59
+ # to the Licensor or its representatives, including but not limited to
60
+ # communication on electronic mailing lists, source code control systems,
61
+ # and issue tracking systems that are managed by, or on behalf of, the
62
+ # Licensor for the purpose of discussing and improving the Work, but
63
+ # excluding communication that is conspicuously marked or otherwise
64
+ # designated in writing by the copyright owner as "Not a Contribution."
65
+
66
+ # "Contributor" shall mean Licensor and any individual or Legal Entity
67
+ # on behalf of whom a Contribution has been received by Licensor and
68
+ # subsequently incorporated within the Work.
69
+
70
+ # 2. Grant of Copyright License. Subject to the terms and conditions of
71
+ # this License, each Contributor hereby grants to You a perpetual,
72
+ # worldwide, non-exclusive, no-charge, royalty-free, irrevocable
73
+ # copyright license to reproduce, prepare Derivative Works of,
74
+ # publicly display, publicly perform, sublicense, and distribute the
75
+ # Work and such Derivative Works in Source or Object form.
76
+
77
+ # 3. Grant of Patent License. Subject to the terms and conditions of
78
+ # this License, each Contributor hereby grants to You a perpetual,
79
+ # worldwide, non-exclusive, no-charge, royalty-free, irrevocable
80
+ # (except as stated in this section) patent license to make, have made,
81
+ # use, offer to sell, sell, import, and otherwise transfer the Work,
82
+ # where such license applies only to those patent claims licensable
83
+ # by such Contributor that are necessarily infringed by their
84
+ # Contribution(s) alone or by combination of their Contribution(s)
85
+ # with the Work to which such Contribution(s) was submitted. If You
86
+ # institute patent litigation against any entity (including a
87
+ # cross-claim or counterclaim in a lawsuit) alleging that the Work
88
+ # or a Contribution incorporated within the Work constitutes direct
89
+ # or contributory patent infringement, then any patent licenses
90
+ # granted to You under this License for that Work shall terminate
91
+ # as of the date such litigation is filed.
92
+
93
+ # 4. Redistribution. You may reproduce and distribute copies of the
94
+ # Work or Derivative Works thereof in any medium, with or without
95
+ # modifications, and in Source or Object form, provided that You
96
+ # meet the following conditions:
97
+
98
+ # (a) You must give any other recipients of the Work or
99
+ # Derivative Works a copy of this License; and
100
+
101
+ # (b) You must cause any modified files to carry prominent notices
102
+ # stating that You changed the files; and
103
+
104
+ # (c) You must retain, in the Source form of any Derivative Works
105
+ # that You distribute, all copyright, patent, trademark, and
106
+ # attribution notices from the Source form of the Work,
107
+ # excluding those notices that do not pertain to any part of
108
+ # the Derivative Works; and
109
+
110
+ # (d) If the Work includes a "NOTICE" text file as part of its
111
+ # distribution, then any Derivative Works that You distribute must
112
+ # include a readable copy of the attribution notices contained
113
+ # within such NOTICE file, excluding those notices that do not
114
+ # pertain to any part of the Derivative Works, in at least one
115
+ # of the following places: within a NOTICE text file distributed
116
+ # as part of the Derivative Works; within the Source form or
117
+ # documentation, if provided along with the Derivative Works; or,
118
+ # within a display generated by the Derivative Works, if and
119
+ # wherever such third-party notices normally appear. The contents
120
+ # of the NOTICE file are for informational purposes only and
121
+ # do not modify the License. You may add Your own attribution
122
+ # notices within Derivative Works that You distribute, alongside
123
+ # or as an addendum to the NOTICE text from the Work, provided
124
+ # that such additional attribution notices cannot be construed
125
+ # as modifying the License.
126
+
127
+ # You may add Your own copyright statement to Your modifications and
128
+ # may provide additional or different license terms and conditions
129
+ # for use, reproduction, or distribution of Your modifications, or
130
+ # for any such Derivative Works as a whole, provided Your use,
131
+ # reproduction, and distribution of the Work otherwise complies with
132
+ # the conditions stated in this License.
133
+
134
+ # 5. Submission of Contributions. Unless You explicitly state otherwise,
135
+ # any Contribution intentionally submitted for inclusion in the Work
136
+ # by You to the Licensor shall be under the terms and conditions of
137
+ # this License, without any additional terms or conditions.
138
+ # Notwithstanding the above, nothing herein shall supersede or modify
139
+ # the terms of any separate license agreement you may have executed
140
+ # with Licensor regarding such Contributions.
141
+
142
+ # 6. Trademarks. This License does not grant permission to use the trade
143
+ # names, trademarks, service marks, or product names of the Licensor,
144
+ # except as required for reasonable and customary use in describing the
145
+ # origin of the Work and reproducing the content of the NOTICE file.
146
+
147
+ # 7. Disclaimer of Warranty. Unless required by applicable law or
148
+ # agreed to in writing, Licensor provides the Work (and each
149
+ # Contributor provides its Contributions) on an "AS IS" BASIS,
150
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
151
+ # implied, including, without limitation, any warranties or conditions
152
+ # of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
153
+ # PARTICULAR PURPOSE. You are solely responsible for determining the
154
+ # appropriateness of using or redistributing the Work and assume any
155
+ # risks associated with Your exercise of permissions under this License.
156
+
157
+ # 8. Limitation of Liability. In no event and under no legal theory,
158
+ # whether in tort (including negligence), contract, or otherwise,
159
+ # unless required by applicable law (such as deliberate and grossly
160
+ # negligent acts) or agreed to in writing, shall any Contributor be
161
+ # liable to You for damages, including any direct, indirect, special,
162
+ # incidental, or consequential damages of any character arising as a
163
+ # result of this License or out of the use or inability to use the
164
+ # Work (including but not limited to damages for loss of goodwill,
165
+ # work stoppage, computer failure or malfunction, or any and all
166
+ # other commercial damages or losses), even if such Contributor
167
+ # has been advised of the possibility of such damages.
168
+
169
+ # 9. Accepting Warranty or Additional Liability. While redistributing
170
+ # the Work or Derivative Works thereof, You may choose to offer,
171
+ # and charge a fee for, acceptance of support, warranty, indemnity,
172
+ # or other liability obligations and/or rights consistent with this
173
+ # License. However, in accepting such obligations, You may act only
174
+ # on Your own behalf and on Your sole responsibility, not on behalf
175
+ # of any other Contributor, and only if You agree to indemnify,
176
+ # defend, and hold each Contributor harmless for any liability
177
+ # incurred by, or claims asserted against, such Contributor by reason
178
+ # of your accepting any such warranty or additional liability.
179
+
180
+ # END OF TERMS AND CONDITIONS
181
+ ##
182
+ ##
183
+ import base64
184
+ import binascii
185
+ import datetime
186
+ import hashlib
187
+ import hmac
188
+ import re
189
+ import json
190
+ import requests
191
+ import logging
192
+ import time
193
+ from datetime import datetime
194
+
195
+ import boto3
196
+ import os
197
+ import six
198
+ from jose import jwt
199
+
200
+ from trainml.exceptions import TrainMLException
201
+
202
# SRP group modulus N as a hex string (768 hex digits, i.e. 3072 bits).
# https://github.com/aws/amazon-cognito-identity-js/blob/master/src/AuthenticationHelper.js#L22
n_hex = (
    "FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD1"
    "29024E088A67CC74020BBEA63B139B22514A08798E3404DD"
    "EF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245"
    "E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7ED"
    "EE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3D"
    "C2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F"
    "83655D23DCA3AD961C62F356208552BB9ED529077096966D"
    "670C354E4ABC9804F1746C08CA18217C32905E462E36CE3B"
    "E39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9"
    "DE2BCBF6955817183995497CEA956AE515D2261898FA0510"
    "15728E5A8AAAC42DAD33170D04507A33A85521ABDF1CBA64"
    "ECFB850458DBEF0A8AEA71575D060C7DB3970F85A6E1E4C7"
    "ABF5AE8CDB0933D71E8C94E04A25619DCEE3D2261AD2EE6B"
    "F12FFA06D98A0864D87602733EC86A64521F2B18177B200C"
    "BBE117577A615D6C770988C0BAD946E208E24FA074E5AB31"
    "43DB5BFCE0FD108E4B82D120A93AD2CAFFFFFFFFFFFFFFFF"
)
# SRP generator g as a hex string.
# https://github.com/aws/amazon-cognito-identity-js/blob/master/src/AuthenticationHelper.js#L49
g_hex = "2"
# "info" bytes mixed into the key-derivation step in compute_hkdf().
info_bits = bytearray(b"Caldera Derived Key")
224
+
225
+
226
def hash_sha256(buf):
    """Return the SHA-256 of ``buf`` as a hex string, zero-padded to 64 chars.

    Counterpart of ``AuthenticationHelper.hash`` in the Cognito JS SDK.

    :param buf: bytes-like object to hash.
    :return: 64-character lowercase hex digest.
    """
    digest_hex = hashlib.sha256(buf).hexdigest()
    # Left-pad with zeros so the result is always 64 characters wide.
    return digest_hex.rjust(64, "0")
230
+
231
+
232
def hex_hash(hex_string):
    """Decode ``hex_string`` to raw bytes and return their SHA-256 hex digest.

    :param hex_string: even-length string of hexadecimal digits.
    :return: hex digest as produced by ``hash_sha256``.
    """
    decoded = bytearray.fromhex(hex_string)
    return hash_sha256(decoded)
234
+
235
+
236
def hex_to_long(hex_string):
    """Parse a base-16 string (or bytes) into an integer.

    :param hex_string: hexadecimal digits.
    :return: the integer value.
    """
    value = int(hex_string, 16)
    return value
238
+
239
+
240
def long_to_hex(long_num):
    """Render an integer as lowercase hexadecimal without a ``0x`` prefix.

    :param long_num: integer to convert.
    :return: hex string (no padding).
    """
    hex_digits = "%x" % long_num
    return hex_digits
242
+
243
+
244
def get_random(nbytes):
    """Return a random non-negative integer built from ``nbytes`` OS random bytes.

    :param nbytes: number of bytes to read from ``os.urandom``.
    :return: integer in the range ``[0, 256 ** nbytes)``.
    """
    entropy = os.urandom(nbytes)
    # Hex-encode the raw bytes, then read the digits back as one base-16 number.
    return int(binascii.hexlify(entropy), 16)
247
+
248
+
249
def pad_hex(long_int):
    """
    Converts a Long integer (or hex string) to hex format padded with zeroes for hashing

    The result always has an even number of digits.  When the digit count is
    already even and the leading digit is 8 or above, an extra ``00`` byte is
    prepended.

    :param {Long integer|String} long_int Number or string to pad.
    :return {String} Padded hex string.
    """
    # The module already requires Python 3 (it uses f-strings), so the builtin
    # ``str`` check is equivalent to the former ``six.string_types`` test.
    if isinstance(long_int, str):
        hash_str = long_int
    else:
        hash_str = long_to_hex(long_int)

    if len(hash_str) % 2 == 1:
        # Odd digit count: complete the leading byte.
        hash_str = "0%s" % hash_str
    elif hash_str[0] in "89ABCDEFabcdef":
        # Leading digit >= 8: prepend a zero byte.
        hash_str = "00%s" % hash_str
    return hash_str
264
+
265
+
266
def compute_hkdf(ikm, salt):
    """
    Derive a 16-byte key with HMAC-SHA256 (extract, then a single expand round).

    :param {Buffer} ikm Input key material.
    :param {Buffer} salt Salt value, used as the HMAC key of the extract step.
    :return {Buffer} First 16 bytes of the expanded key material.
    @private
    """
    # Extract: HMAC the input key material under the salt.
    pseudo_random_key = hmac.new(salt, ikm, hashlib.sha256).digest()
    # Expand: one block, i.e. HMAC over the info bytes followed by 0x01.
    expand_input = info_bits + bytearray([1])
    expanded = hmac.new(pseudo_random_key, expand_input, hashlib.sha256).digest()
    return expanded[:16]
278
+
279
+
280
def calculate_u(big_a, big_b):
    """
    Calculate the client's value U which is the hash of A and B

    Both values are padded with ``pad_hex`` and concatenated before hashing.

    :param {Long integer} big_a Large A value.
    :param {Long integer} big_b Server B value.
    :return {Long integer} Computed U value.
    """
    padded_pair = pad_hex(big_a) + pad_hex(big_b)
    return hex_to_long(hex_hash(padded_pair))
289
+
290
+
291
class AWSSRP(object):
    """Client side of the AWS Cognito ``USER_SRP_AUTH`` flow.

    An instance holds the user's credentials, the pool/client identifiers and
    a freshly generated SRP key pair (``small_a_value`` / ``large_a_value``).
    ``authenticate_user`` and ``set_new_password_challenge`` drive the
    exchange through a boto3 ``cognito-idp`` client.
    """

    NEW_PASSWORD_REQUIRED_CHALLENGE = "NEW_PASSWORD_REQUIRED"
    PASSWORD_VERIFIER_CHALLENGE = "PASSWORD_VERIFIER"

    def __init__(
        self,
        username,
        password,
        pool_id,
        client_id,
        pool_region=None,
        client=None,
        client_secret=None,
    ):
        """
        :param username: Cognito username.
        :param password: Cognito password.
        :param pool_id: user pool id; the part after ``_`` is used in the SRP hashes.
        :param client_id: app client id.
        :param pool_region: region used to build a boto3 client when ``client`` is not given.
        :param client: existing boto3 ``cognito-idp`` client (mutually exclusive with ``pool_region``).
        :param client_secret: optional app client secret; enables ``SECRET_HASH``.
        :raises ValueError: if both ``pool_region`` and ``client`` are supplied.
        """
        if pool_region is not None and client is not None:
            raise ValueError(
                "pool_region and client should not both be specified "
                "(region should be passed to the boto3 client instead)"
            )

        self.username = username
        self.password = password
        self.pool_id = pool_id
        self.client_id = client_id
        self.client_secret = client_secret
        self.client = (
            client
            if client
            else boto3.client("cognito-idp", region_name=pool_region)
        )
        self.big_n = hex_to_long(n_hex)
        self.g = hex_to_long(g_hex)
        self.k = hex_to_long(hex_hash("00" + n_hex + "0" + g_hex))
        self.small_a_value = self.generate_random_small_a()
        self.large_a_value = self.calculate_a()

    def generate_random_small_a(self):
        """
        helper function to generate a random big integer
        :return {Long integer} a random value.
        """
        random_long_int = get_random(128)
        return random_long_int % self.big_n

    def calculate_a(self):
        """
        Calculate the client's public value A = g^a%N
        with the generated random number a
        :return {Long integer} Computed large A.
        :raises ValueError: if A is congruent to zero modulo N.
        """
        big_a = pow(self.g, self.small_a_value, self.big_n)
        # safety check
        if (big_a % self.big_n) == 0:
            raise ValueError("Safety check for A failed")
        return big_a

    def get_password_authentication_key(
        self, username, password, server_b_value, salt
    ):
        """
        Calculates the final hkdf based on computed S value, and computed U value and the key
        :param {String} username Username.
        :param {String} password Password.
        :param {Long integer} server_b_value Server B value.
        :param {Long integer} salt Generated salt.
        :return {Buffer} Computed HKDF value.
        :raises ValueError: if B is congruent to zero modulo N, or U is zero.
        """
        # SRP requires the client to abort when the server's public value is
        # a multiple of N; otherwise the shared secret becomes predictable.
        if server_b_value % self.big_n == 0:
            raise ValueError("B cannot be zero.")
        u_value = calculate_u(self.large_a_value, server_b_value)
        if u_value == 0:
            raise ValueError("U cannot be zero.")
        username_password = "%s%s:%s" % (
            self.pool_id.split("_")[1],
            username,
            password,
        )
        username_password_hash = hash_sha256(username_password.encode("utf-8"))

        x_value = hex_to_long(hex_hash(pad_hex(salt) + username_password_hash))
        g_mod_pow_xn = pow(self.g, x_value, self.big_n)
        int_value2 = server_b_value - self.k * g_mod_pow_xn
        s_value = pow(
            int_value2, self.small_a_value + u_value * x_value, self.big_n
        )
        hkdf = compute_hkdf(
            bytearray.fromhex(pad_hex(s_value)),
            bytearray.fromhex(pad_hex(long_to_hex(u_value))),
        )
        return hkdf

    def get_auth_params(self):
        """Build the ``AuthParameters`` dict sent with ``initiate_auth``.

        ``SECRET_HASH`` is included only when a client secret was configured.
        """
        auth_params = {
            "USERNAME": self.username,
            "SRP_A": long_to_hex(self.large_a_value),
        }
        if self.client_secret is not None:
            auth_params.update(
                {
                    "SECRET_HASH": self.get_secret_hash(
                        self.username, self.client_id, self.client_secret
                    )
                }
            )
        return auth_params

    @staticmethod
    def get_secret_hash(username, client_id, client_secret):
        """Return base64(HMAC-SHA256(client_secret, username + client_id)) as text."""
        message = bytearray(username + client_id, "utf-8")
        hmac_obj = hmac.new(
            bytearray(client_secret, "utf-8"), message, hashlib.sha256
        )
        return base64.standard_b64encode(hmac_obj.digest()).decode("utf-8")

    def process_challenge(self, challenge_parameters):
        """Compute the response to a ``PASSWORD_VERIFIER`` challenge.

        :param challenge_parameters: dict with ``USER_ID_FOR_SRP``, ``SALT``,
            ``SRP_B`` and ``SECRET_BLOCK`` entries.
        :return: dict of challenge responses (timestamp, username, secret
            block, signature and, when configured, the secret hash).
        """
        user_id_for_srp = challenge_parameters["USER_ID_FOR_SRP"]
        salt_hex = challenge_parameters["SALT"]
        srp_b_hex = challenge_parameters["SRP_B"]
        secret_block_b64 = challenge_parameters["SECRET_BLOCK"]
        # re strips leading zero from a day number (required by AWS Cognito)
        # NOTE(review): %a and %b follow the process locale; presumably the
        # service expects English names -- confirm for non-English locales.
        timestamp = re.sub(
            r" 0(\d) ",
            r" \1 ",
            datetime.utcnow().strftime("%a %b %d %H:%M:%S UTC %Y"),
        )
        hkdf = self.get_password_authentication_key(
            user_id_for_srp, self.password, hex_to_long(srp_b_hex), salt_hex
        )
        secret_block_bytes = base64.standard_b64decode(secret_block_b64)
        # Signed message: pool name + user id + secret block + timestamp.
        msg = (
            bytearray(self.pool_id.split("_")[1], "utf-8")
            + bytearray(user_id_for_srp, "utf-8")
            + bytearray(secret_block_bytes)
            + bytearray(timestamp, "utf-8")
        )
        hmac_obj = hmac.new(hkdf, msg, digestmod=hashlib.sha256)
        signature_string = base64.standard_b64encode(hmac_obj.digest())
        response = {
            "TIMESTAMP": timestamp,
            "USERNAME": user_id_for_srp,
            "PASSWORD_CLAIM_SECRET_BLOCK": secret_block_b64,
            "PASSWORD_CLAIM_SIGNATURE": signature_string.decode("utf-8"),
        }
        if self.client_secret is not None:
            response.update(
                {
                    "SECRET_HASH": self.get_secret_hash(
                        self.username, self.client_id, self.client_secret
                    )
                }
            )
        return response

    def _answer_password_verifier(self, boto_client, auth_params):
        """Start USER_SRP_AUTH and answer the PASSWORD_VERIFIER challenge.

        :return: the raw ``respond_to_auth_challenge`` result.
        :raises NotImplementedError: if the first challenge is anything other
            than ``PASSWORD_VERIFIER``.
        """
        response = boto_client.initiate_auth(
            AuthFlow="USER_SRP_AUTH",
            AuthParameters=auth_params,
            ClientId=self.client_id,
        )
        if response["ChallengeName"] != self.PASSWORD_VERIFIER_CHALLENGE:
            raise NotImplementedError(
                "The %s challenge is not supported" % response["ChallengeName"]
            )
        challenge_response = self.process_challenge(
            response["ChallengeParameters"]
        )
        return boto_client.respond_to_auth_challenge(
            ClientId=self.client_id,
            ChallengeName=self.PASSWORD_VERIFIER_CHALLENGE,
            ChallengeResponses=challenge_response,
        )

    def authenticate_user(self, client=None):
        """Run the SRP login and return the final Cognito response.

        :param client: fallback boto3 client, used only when ``self.client`` is falsy.
        :raises Exception: if Cognito demands a new password first.
        :raises NotImplementedError: for an unsupported first challenge.
        """
        boto_client = self.client or client
        auth_params = self.get_auth_params()
        tokens = self._answer_password_verifier(boto_client, auth_params)

        if tokens.get("ChallengeName") == self.NEW_PASSWORD_REQUIRED_CHALLENGE:
            raise Exception("Change password before authenticating")

        return tokens

    def set_new_password_challenge(self, new_password, client=None):
        """Run the SRP login and, if requested by Cognito, set a new password.

        :param new_password: password to submit for ``NEW_PASSWORD_REQUIRED``.
        :param client: fallback boto3 client, used only when ``self.client`` is falsy.
        :return: the new-password challenge result, or the login result when
            no new password was required.
        :raises NotImplementedError: for an unsupported first challenge.
        """
        boto_client = self.client or client
        auth_params = self.get_auth_params()
        tokens = self._answer_password_verifier(boto_client, auth_params)

        # A successful login carries no "ChallengeName" key at all, so use
        # .get() (as authenticate_user does) instead of indexing.
        if tokens.get("ChallengeName") == self.NEW_PASSWORD_REQUIRED_CHALLENGE:
            challenge_response = {
                "USERNAME": auth_params["USERNAME"],
                "NEW_PASSWORD": new_password,
            }
            new_password_response = boto_client.respond_to_auth_challenge(
                ClientId=self.client_id,
                ChallengeName=self.NEW_PASSWORD_REQUIRED_CHALLENGE,
                Session=tokens["Session"],
                ChallengeResponses=challenge_response,
            )
            return new_password_response
        return tokens
508
+
509
+
510
class Auth(object):
    """Resolve trainML credentials and obtain/cache Cognito tokens.

    Settings are resolved, in order of precedence, from keyword arguments,
    ``TRAINML_*`` environment variables, JSON files in ``config_dir`` and
    (for region / client id / pool id) the platform's published
    ``auth-config.json``.
    """

    def __init__(self, config_dir, domain_suffix="proximl.ai", **kwargs):
        """
        :param config_dir: directory holding the optional ``environment.json``
            and ``credentials.json`` files.
        :param domain_suffix: domain used to fetch
            ``https://app.<domain_suffix>/.well-known/auth-config.json``.
        :param kwargs: optional overrides ``region``, ``client_id``,
            ``pool_id``, ``user`` and ``key``.
        :raises TrainMLException: if no username or password can be found.
        """
        env = self._read_json_config(f"{config_dir}/environment.json")

        auth_defaults = requests.get(
            "https://app.{}/.well-known/auth-config.json".format(domain_suffix)
        ).json()

        self.region = (
            kwargs.get("region")
            or os.environ.get("TRAINML_REGION")
            or env.get("region")
            or auth_defaults.get("region")
        )
        self.client_id = (
            kwargs.get("client_id")
            or os.environ.get("TRAINML_CLIENT_ID")
            or env.get("client_id")
            or auth_defaults.get("userPoolSDKClientId")
        )
        self.pool_id = (
            kwargs.get("pool_id")
            or os.environ.get("TRAINML_POOL_ID")
            or env.get("pool_id")
            or auth_defaults.get("userPoolId")
        )

        keys = self._read_json_config(f"{config_dir}/credentials.json")

        self.username = (
            kwargs.get("user")
            or os.environ.get("TRAINML_USER")
            or keys.get("user")
        )
        self.password = (
            kwargs.get("key")
            or os.environ.get("TRAINML_KEY")
            or keys.get("key")
        )
        if not self.username or not self.password:
            raise TrainMLException("trainML credentials not found.")
        self.client = boto3.client("cognito-idp", region_name=self.region)
        self.id_token = None
        self.access_token = None
        self.refresh_token = None
        self.expires = 0

    @staticmethod
    def _read_json_config(path):
        """Load a JSON object from ``path``; return ``{}`` if unavailable.

        Missing/unreadable files and malformed JSON are treated as "no
        configuration".  Only I/O and parse errors are swallowed, so
        KeyboardInterrupt/SystemExit still propagate.
        """
        try:
            with open(path, "r") as file:
                content = file.read().replace("\n", "")
            loaded = json.loads(content)
        except (OSError, ValueError) as err:
            logging.debug(f"Unable to load {path}: {err}")
            return dict()
        # Callers use .get(); anything that is not a JSON object is ignored.
        return loaded if isinstance(loaded, dict) else dict()

    def get_keys(self):
        """Download the user pool's JSON Web Key Set."""
        pool_jwk = requests.get(
            "https://cognito-idp.{}.amazonaws.com/{}/.well-known/jwks.json".format(
                self.region, self.pool_id
            )
        ).json()
        return pool_jwk

    def get_key(self, kid):
        """Return the pool key whose ``kid`` matches.

        :raises IndexError: if no key in the set has that ``kid``.
        """
        keys = self.get_keys().get("keys")
        matching = [entry for entry in keys if entry.get("kid") == kid]
        return matching[0]

    def verify_token(self, token, id_name):
        """Check ``token`` against the pool's signing key.

        :param token: encoded JWT.
        :param id_name: label for the token kind (not used by the check).
        :return: the decoded claims, or ``False`` if decoding/validation fails.
        """
        kid = jwt.get_unverified_header(token).get("kid")
        unverified_claims = jwt.get_unverified_claims(token)
        hmac_key = self.get_key(kid)
        try:
            # NOTE(review): audience and issuer are taken from the token's own
            # unverified claims, so those two checks cannot fail; only the
            # signature (and whatever else jwt.decode validates) is enforced.
            verified = jwt.decode(
                token,
                hmac_key,
                algorithms=["RS256"],
                audience=unverified_claims.get("aud"),
                issuer=unverified_claims.get("iss"),
            )
        except Exception:
            return False
        return verified

    def get_new_tokens(self):
        """Authenticate with SRP and store the verified tokens on ``self``.

        :raises TrainMLException: if the returned ID token fails verification.
        """
        aws = AWSSRP(
            username=self.username,
            password=self.password,
            pool_id=self.pool_id,
            client_id=self.client_id,
            client=self.client,
        )
        tokens = aws.authenticate_user()
        auth_result = tokens["AuthenticationResult"]
        refresh_token = auth_result["RefreshToken"]

        id_verify = self.verify_token(auth_result["IdToken"], "id_token")
        logging.debug(f"ID Token Verification: {id_verify}")
        if not id_verify:
            # verify_token returns False on failure; without this guard the
            # expiry computation below fails with AttributeError on a bool.
            raise TrainMLException("ID token verification failed.")
        self.id_token = auth_result["IdToken"]

        access_verify = self.verify_token(
            auth_result["AccessToken"], "access_token"
        )
        logging.debug(f"Access Token Verification: {access_verify}")
        if access_verify:
            self.access_token = auth_result["AccessToken"]

        self.refresh_token = refresh_token
        self.expires = (
            id_verify.get("exp") - 300
        )  ## prevent just about to expire tokens from being used

    def get_tokens(self):
        """Return the current tokens, fetching new ones if absent or expired.

        :return: dict with ``id_token``, ``access_token``, ``refresh_token``
            and ``expires``.
        """
        logging.debug(f"Token expires: {self.expires}")
        logging.debug(f"Token is expired: {self.expires < time.time()}")
        if not self.id_token:
            self.get_new_tokens()
        elif self.expires < time.time():
            self.get_new_tokens()
            logging.debug(f"New token expires: {self.expires}")

        return dict(
            id_token=self.id_token,
            access_token=self.access_token,
            refresh_token=self.refresh_token,
            expires=self.expires,
        )