aimodelshare 0.1.32__py3-none-any.whl → 0.1.62__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of aimodelshare might be problematic. Click here for more details.

Files changed (38)
  1. aimodelshare/__init__.py +94 -14
  2. aimodelshare/aimsonnx.py +312 -95
  3. aimodelshare/api.py +13 -12
  4. aimodelshare/auth.py +163 -0
  5. aimodelshare/aws.py +4 -4
  6. aimodelshare/base_image.py +1 -1
  7. aimodelshare/containerisation.py +1 -1
  8. aimodelshare/data_sharing/download_data.py +142 -87
  9. aimodelshare/generatemodelapi.py +7 -6
  10. aimodelshare/main/authorization.txt +275 -275
  11. aimodelshare/main/eval_lambda.txt +81 -13
  12. aimodelshare/model.py +493 -197
  13. aimodelshare/modeluser.py +89 -1
  14. aimodelshare/moral_compass/README.md +408 -0
  15. aimodelshare/moral_compass/__init__.py +37 -0
  16. aimodelshare/moral_compass/_version.py +3 -0
  17. aimodelshare/moral_compass/api_client.py +601 -0
  18. aimodelshare/moral_compass/apps/__init__.py +17 -0
  19. aimodelshare/moral_compass/apps/tutorial.py +198 -0
  20. aimodelshare/moral_compass/challenge.py +365 -0
  21. aimodelshare/moral_compass/config.py +187 -0
  22. aimodelshare/playground.py +26 -14
  23. aimodelshare/preprocessormodules.py +60 -6
  24. aimodelshare/pyspark/authorization.txt +258 -258
  25. aimodelshare/pyspark/eval_lambda.txt +1 -1
  26. aimodelshare/reproducibility.py +20 -5
  27. aimodelshare/utils/__init__.py +78 -0
  28. aimodelshare/utils/optional_deps.py +38 -0
  29. aimodelshare-0.1.62.dist-info/METADATA +298 -0
  30. {aimodelshare-0.1.32.dist-info → aimodelshare-0.1.62.dist-info}/RECORD +33 -25
  31. {aimodelshare-0.1.32.dist-info → aimodelshare-0.1.62.dist-info}/WHEEL +1 -1
  32. aimodelshare-0.1.62.dist-info/licenses/LICENSE +5 -0
  33. {aimodelshare-0.1.32.dist-info → aimodelshare-0.1.62.dist-info}/top_level.txt +0 -1
  34. aimodelshare-0.1.32.dist-info/METADATA +0 -78
  35. aimodelshare-0.1.32.dist-info/licenses/LICENSE +0 -22
  36. tests/__init__.py +0 -0
  37. tests/test_aimsonnx.py +0 -135
  38. tests/test_playground.py +0 -721
aimodelshare/auth.py ADDED
@@ -0,0 +1,163 @@
1
+ """
2
+ Authentication and identity management helpers for aimodelshare.
3
+
4
+ Provides unified authentication around Cognito IdToken (JWT_AUTHORIZATION_TOKEN),
5
+ with backward compatibility for legacy AWS_TOKEN.
6
+ """
7
+
8
+ import os
9
+ import warnings
10
+ import logging
11
+ from typing import Optional, Dict, Any
12
+ import json
13
+
14
+ logger = logging.getLogger("aimodelshare.auth")
15
+
16
+ try:
17
+ import jwt
18
+ except ImportError:
19
+ jwt = None
20
+ logger.warning("PyJWT not installed. JWT decode functionality will be limited.")
21
+
22
+
23
+ def get_primary_token() -> Optional[str]:
24
+ """
25
+ Get the primary authentication token from environment variables.
26
+
27
+ Prefers JWT_AUTHORIZATION_TOKEN over legacy AWS_TOKEN.
28
+ Issues a deprecation warning if only AWS_TOKEN is present.
29
+
30
+ Returns:
31
+ Optional[str]: The authentication token, or None if not found
32
+ """
33
+ jwt_token = os.getenv('JWT_AUTHORIZATION_TOKEN')
34
+ if jwt_token:
35
+ return jwt_token
36
+
37
+ aws_token = os.getenv('AWS_TOKEN')
38
+ if aws_token:
39
+ warnings.warn(
40
+ "Using legacy AWS_TOKEN environment variable. "
41
+ "Please migrate to JWT_AUTHORIZATION_TOKEN. "
42
+ "AWS_TOKEN support will be deprecated in a future release.",
43
+ DeprecationWarning,
44
+ stacklevel=2
45
+ )
46
+ return aws_token
47
+
48
+ return None
49
+
50
+
51
+ def get_identity_claims(token: Optional[str] = None, verify: bool = False) -> Dict[str, Any]:
52
+ """
53
+ Extract identity claims from a JWT token.
54
+
55
+ Args:
56
+ token: JWT token string. If None, uses get_primary_token()
57
+ verify: If True, performs signature verification (requires JWKS endpoint)
58
+ Currently defaults to False as JWKS verification is future work
59
+
60
+ Returns:
61
+ Dict containing identity claims:
62
+ - sub: Subject (user ID)
63
+ - email: User email
64
+ - cognito:username: Username (if present)
65
+ - iss: Issuer
66
+ - principal: Derived principal identifier
67
+
68
+ Raises:
69
+ ValueError: If token is invalid or missing
70
+ RuntimeError: If PyJWT is not installed
71
+
72
+ Note:
73
+ This currently performs unverified decode as JWKS signature verification
74
+ is planned for future work. Do not use in production security-critical
75
+ contexts without implementing signature verification.
76
+ """
77
+ if token is None:
78
+ token = get_primary_token()
79
+
80
+ if not token:
81
+ raise ValueError("No authentication token available")
82
+
83
+ if jwt is None:
84
+ raise RuntimeError("PyJWT not installed. Install with: pip install PyJWT>=2.4.0")
85
+
86
+ # TODO: Implement JWKS signature verification (future work)
87
+ # For now, perform unverified decode
88
+ if verify:
89
+ warnings.warn(
90
+ "JWT signature verification requested but not yet implemented. "
91
+ "Using unverified decode. This should not be used in production "
92
+ "for security-critical operations.",
93
+ UserWarning,
94
+ stacklevel=2
95
+ )
96
+
97
+ try:
98
+ # Unverified decode - JWKS verification is future work
99
+ claims = jwt.decode(token, options={"verify_signature": False})
100
+
101
+ # Derive principal from claims
102
+ # Priority: cognito:username > email > sub
103
+ principal = (
104
+ claims.get('cognito:username') or
105
+ claims.get('email') or
106
+ claims.get('sub')
107
+ )
108
+
109
+ if principal:
110
+ claims['principal'] = principal
111
+
112
+ return claims
113
+
114
+ except jwt.DecodeError as e:
115
+ raise ValueError(f"Invalid JWT token: {e}")
116
+ except Exception as e:
117
+ raise ValueError(f"Failed to decode JWT token: {e}")
118
+
119
+
120
+ def derive_principal(claims: Dict[str, Any]) -> str:
121
+ """
122
+ Derive a principal identifier from identity claims.
123
+
124
+ Args:
125
+ claims: Identity claims dictionary
126
+
127
+ Returns:
128
+ str: Principal identifier
129
+
130
+ Raises:
131
+ ValueError: If no suitable principal identifier found
132
+ """
133
+ principal = (
134
+ claims.get('principal') or
135
+ claims.get('cognito:username') or
136
+ claims.get('email') or
137
+ claims.get('sub')
138
+ )
139
+
140
+ if not principal:
141
+ raise ValueError("No principal identifier found in claims")
142
+
143
+ return str(principal)
144
+
145
+
146
+ def is_admin(claims: Dict[str, Any]) -> bool:
147
+ """
148
+ Check if the identity has admin privileges.
149
+
150
+ Args:
151
+ claims: Identity claims dictionary
152
+
153
+ Returns:
154
+ bool: True if user has admin privileges
155
+
156
+ Note:
157
+ Currently checks for 'cognito:groups' containing 'admin'.
158
+ Extend this logic as needed for your authorization model.
159
+ """
160
+ groups = claims.get('cognito:groups', [])
161
+ if isinstance(groups, list):
162
+ return 'admin' in groups
163
+ return False
aimodelshare/aws.py CHANGED
@@ -10,7 +10,7 @@ def set_credentials(credential_file=None, type="submit_model", apiurl="apiurl",
10
10
  import os
11
11
  import getpass
12
12
  from aimodelshare.aws import get_aws_token
13
- from aimodelshare.modeluser import get_jwt_token, create_user_getkeyandpassword
13
+ from aimodelshare.modeluser import get_jwt_token, setup_bucket_only
14
14
  if all([credential_file==None, type=="submit_model"]):
15
15
  set_credentials_public(type="submit_model", apiurl=apiurl)
16
16
  os.environ["AWS_TOKEN"]=get_aws_token()
@@ -131,7 +131,7 @@ def set_credentials(credential_file=None, type="submit_model", apiurl="apiurl",
131
131
  # Set Environment Variables for deploy models
132
132
  if type == "deploy_model":
133
133
  get_jwt_token(os.environ.get("username"), os.environ.get("password"))
134
- create_user_getkeyandpassword()
134
+ setup_bucket_only() # Use new function that doesn't create IAM users
135
135
 
136
136
  if not flag:
137
137
  print("Error: apiurl or type not found in"+str(credential_file)+". Please correct entries and resubmit.")
@@ -147,7 +147,7 @@ def set_credentials_public(credential_file=None, type="submit_model", apiurl="ap
147
147
  import os
148
148
  import getpass
149
149
  from aimodelshare.aws import get_aws_token
150
- from aimodelshare.modeluser import get_jwt_token, create_user_getkeyandpassword
150
+ from aimodelshare.modeluser import get_jwt_token, setup_bucket_only
151
151
 
152
152
  ##TODO: Require that "type" is provided, to ensure correct env vars get loaded
153
153
  flag = False
@@ -211,7 +211,7 @@ def set_credentials_public_aimscloud(credential_file=None, type="deploy_model",
211
211
  import os
212
212
  import getpass
213
213
  from aimodelshare.aws import get_aws_token
214
- from aimodelshare.modeluser import get_jwt_token, create_user_getkeyandpassword
214
+ from aimodelshare.modeluser import get_jwt_token, setup_bucket_only
215
215
 
216
216
  ##TODO: Require that "type" is provided, to ensure correct env vars get loaded
217
217
  flag = False
@@ -8,7 +8,7 @@ import zipfile
8
8
  import importlib.resources as pkg_resources
9
9
  from string import Template
10
10
 
11
- def lambda_using_base_image(account_id, region, session, project_name, model_dir, requirements_file_path, apiid, memory_size='3000', timeout='90', python_version='3.10'):
11
+ def lambda_using_base_image(account_id, region, session, project_name, model_dir, requirements_file_path, apiid, memory_size='3000', timeout='90', python_version='3.7'):
12
12
 
13
13
  codebuild_bucket_name=os.environ.get("BUCKET_NAME") # s3 bucket name to create #TODO: use same bucket and subfolder we used previously to store this data
14
14
  #Why? AWS limits users to 100 total buckets! Our old code only creates one per user per acct.
@@ -27,7 +27,7 @@ def create_bucket(s3_client, bucket_name, region):
27
27
  )
28
28
  return response
29
29
 
30
- def deploy_container(account_id, region, session, project_name, model_dir, requirements_file_path, apiid, memory_size='1024', timeout='120', python_version='3.10', pyspark_support=False):
30
+ def deploy_container(account_id, region, session, project_name, model_dir, requirements_file_path, apiid, memory_size='1024', timeout='120', python_version='3.7', pyspark_support=False):
31
31
 
32
32
  codebuild_bucket_name=os.environ.get("BUCKET_NAME") # s3 bucket name to create #TODO: use same bucket and subfolder we used previously to store this data
33
33
  # Why? AWS limits users to 100 total buckets! Our old code only creates one per user per acct.
@@ -34,8 +34,30 @@ def progress_bar(layer_label, nb_traits):
34
34
  def get_auth_url(registry): # to do with auth
35
35
  return 'https://' + registry + '/token/' # no aws auth
36
36
 
37
- def get_auth_head(auth_url, registry, repository): # to do with auth
38
- return get_auth_head_no_aws_auth(auth_url, registry, repository, 'application/vnd.docker.distribution.manifest.v2+json') # no aws auth
37
+ def get_auth_head(auth_url, registry, repository):
38
+ # Broaden Accept header to allow manifest list / OCI fallbacks
39
+ return get_auth_head_no_aws_auth(
40
+ auth_url,
41
+ registry,
42
+ repository,
43
+ ('application/vnd.docker.distribution.manifest.v2+json,'
44
+ 'application/vnd.docker.distribution.manifest.list.v2+json,'
45
+ 'application/vnd.oci.image.manifest.v1+json')
46
+ )
47
+
48
+ def _fetch_concrete_manifest(registry, repository, tag_or_digest, auth_head):
49
+ """Fetch a concrete image manifest (not a list)."""
50
+ resp = requests.get(
51
+ f'https://{registry}/v2/{repository}/manifests/{tag_or_digest}',
52
+ headers=auth_head,
53
+ verify=False
54
+ )
55
+ if not resp.ok:
56
+ raise RuntimeError(
57
+ f"Failed to fetch manifest {tag_or_digest} (status {resp.status_code}): {resp.text[:300]}"
58
+ )
59
+ return resp
60
+
39
61
 
40
62
  def download_layer(layer, layer_count, tmp_img_dir, blobs_resp):
41
63
 
@@ -76,87 +98,115 @@ def download_layer(layer, layer_count, tmp_img_dir, blobs_resp):
76
98
  return layer_id, layer_dir
77
99
 
78
100
  def pull_image(image_uri):
79
-
80
- image_uri_parts = image_uri.split('/')
81
-
82
- registry = image_uri_parts[0]
83
- image, tag = image_uri_parts[2].split(':')
84
- repository = '/'.join([image_uri_parts[1], image])
85
-
86
- auth_url = get_auth_url(registry)
87
-
88
- auth_head = get_auth_head(auth_url, registry, repository)
89
-
90
- resp = requests.get('https://{}/v2/{}/manifests/{}'.format(registry, repository, tag), headers=auth_head, verify=False)
91
-
92
- config = resp.json()['config']['digest']
93
- config_resp = requests.get('https://{}/v2/{}/blobs/{}'.format(registry, repository, config), headers=auth_head, verify=False)
94
-
95
- tmp_img_dir = tempfile.gettempdir() + '/' + 'tmp_{}_{}'.format(image, tag)
96
- os.mkdir(tmp_img_dir)
97
-
98
- file = open('{}/{}.json'.format(tmp_img_dir, config[7:]), 'wb')
99
- file.write(config_resp.content)
100
- file.close()
101
-
102
- content = [{
103
- 'Config': config[7:] + '.json',
104
- 'RepoTags': [],
105
- 'Layers': []
106
- }]
107
- content[0]['RepoTags'].append(image_uri)
108
-
109
- layer_count=0
110
- layers = resp.json()['layers'][6:]
111
-
112
- for layer in layers:
113
-
114
- layer_count += 1
115
-
116
- auth_head = get_auth_head(auth_url, registry, repository) # done to keep from expiring
117
- blobs_resp = requests.get('https://{}/v2/{}/blobs/{}'.format(registry, repository, layer['digest']), headers=auth_head, stream=True, verify=False)
118
-
119
- layer_id, layer_dir = download_layer(layer, layer_count, tmp_img_dir, blobs_resp)
120
- content[0]['Layers'].append(layer_id + '/layer.tar')
121
-
122
- # Creating json file
123
- file = open(layer_dir + '/json', 'w')
124
-
125
- # last layer = config manifest - history - rootfs
126
- if layers[-1]['digest'] == layer['digest']:
127
- json_obj = json.loads(config_resp.content)
128
- del json_obj['history']
129
- del json_obj['rootfs']
130
- else: # other layers json are empty
131
- json_obj = json.loads('{}')
132
-
133
- json_obj['id'] = layer_id
134
- file.write(json.dumps(json_obj))
135
- file.close()
136
-
137
- file = open(tmp_img_dir + '/manifest.json', 'w')
138
- file.write(json.dumps(content))
139
- file.close()
140
-
141
- content = {
142
- '/'.join(image_uri_parts[:-1]) + '/' + image : { tag : layer_id }
143
- }
144
-
145
- file = open(tmp_img_dir + '/repositories', 'w')
146
- file.write(json.dumps(content))
147
- file.close()
148
-
149
- # Create image tar and clean tmp folder
150
- docker_tar = tempfile.gettempdir() + '/' + '_'.join([repository.replace('/', '_'), tag]) + '.tar'
151
- sys.stdout.flush()
152
-
153
- tar = tarfile.open(docker_tar, "w")
154
- tar.add(tmp_img_dir, arcname=os.path.sep)
155
- tar.close()
156
-
157
- shutil.rmtree(tmp_img_dir, onerror=redo_with_write)
158
-
159
- return docker_tar
101
+ image_uri_parts = image_uri.split('/')
102
+ registry = image_uri_parts[0]
103
+ image, tag = image_uri_parts[2].split(':')
104
+ repository = '/'.join([image_uri_parts[1], image])
105
+
106
+ auth_url = get_auth_url(registry)
107
+ auth_head = get_auth_head(auth_url, registry, repository)
108
+
109
+ # 1. Fetch initial manifest (may be list or concrete)
110
+ resp = _fetch_concrete_manifest(registry, repository, tag, auth_head)
111
+ manifest_json = resp.json()
112
+
113
+ # 2. Handle manifest list fallback
114
+ if 'config' not in manifest_json:
115
+ if 'manifests' in manifest_json:
116
+ # Choose amd64 if available, else first
117
+ chosen = None
118
+ for m in manifest_json['manifests']:
119
+ arch = (m.get('platform') or {}).get('architecture')
120
+ if arch in ('amd64', 'x86_64'):
121
+ chosen = m
122
+ break
123
+ if chosen is None:
124
+ chosen = manifest_json['manifests'][0]
125
+ digest = chosen['digest']
126
+ # Re-auth to avoid token expiry
127
+ auth_head = get_auth_head(auth_url, registry, repository)
128
+ resp = _fetch_concrete_manifest(registry, repository, digest, auth_head)
129
+ manifest_json = resp.json()
130
+ else:
131
+ raise KeyError(
132
+ f"Manifest does not contain 'config' or 'manifests'. Keys: {list(manifest_json.keys())}"
133
+ )
134
+
135
+ if 'config' not in manifest_json or 'layers' not in manifest_json:
136
+ raise KeyError(
137
+ f"Unexpected manifest shape. Keys: {list(manifest_json.keys())}"
138
+ )
139
+
140
+ config = manifest_json['config']['digest']
141
+ config_resp = requests.get(
142
+ f'https://{registry}/v2/{repository}/blobs/{config}',
143
+ headers=auth_head,
144
+ verify=False
145
+ )
146
+ if not config_resp.ok:
147
+ raise RuntimeError(
148
+ f"Failed to fetch config blob {config} (status {config_resp.status_code}): {config_resp.text[:300]}"
149
+ )
150
+
151
+ tmp_img_dir = tempfile.gettempdir() + '/' + f'tmp_{image}_{tag}'
152
+ os.mkdir(tmp_img_dir)
153
+
154
+ with open(f'{tmp_img_dir}/{config[7:]}.json', 'wb') as f:
155
+ f.write(config_resp.content)
156
+
157
+ content = [{
158
+ 'Config': config[7:] + '.json',
159
+ 'RepoTags': [image_uri],
160
+ 'Layers': []
161
+ }]
162
+
163
+ layer_count = 0
164
+ layers = manifest_json['layers'] # removed [6:] slicing
165
+
166
+ for layer in layers:
167
+ layer_count += 1
168
+ # Refresh auth (avoid expiry)
169
+ auth_head = get_auth_head(auth_url, registry, repository)
170
+ blobs_resp = requests.get(
171
+ f'https://{registry}/v2/{repository}/blobs/{layer["digest"]}',
172
+ headers=auth_head,
173
+ stream=True,
174
+ verify=False
175
+ )
176
+ if not blobs_resp.ok:
177
+ raise RuntimeError(
178
+ f"Failed to stream layer {layer['digest']} status {blobs_resp.status_code}: {blobs_resp.text[:200]}"
179
+ )
180
+ layer_id, layer_dir = download_layer(layer, layer_count, tmp_img_dir, blobs_resp)
181
+ content[0]['Layers'].append(layer_id + '/layer.tar')
182
+
183
+ # Create layer json
184
+ with open(layer_dir + '/json', 'w') as fjson:
185
+ if layers[-1]['digest'] == layer['digest']:
186
+ json_obj = json.loads(config_resp.content)
187
+ json_obj.pop('history', None)
188
+ json_obj.pop('rootfs', None)
189
+ else:
190
+ json_obj = {}
191
+ json_obj['id'] = layer_id
192
+ fjson.write(json.dumps(json_obj))
193
+
194
+ with open(tmp_img_dir + '/manifest.json', 'w') as mf:
195
+ mf.write(json.dumps(content))
196
+
197
+ # repositories file
198
+ repositories_json = {
199
+ '/'.join(image_uri_parts[:-1]) + '/' + image: {tag: layer_id}
200
+ }
201
+ with open(tmp_img_dir + '/repositories', 'w') as rf:
202
+ rf.write(json.dumps(repositories_json))
203
+
204
+ docker_tar = tempfile.gettempdir() + '/' + '_'.join([repository.replace('/', '_'), tag]) + '.tar'
205
+ tar = tarfile.open(docker_tar, "w")
206
+ tar.add(tmp_img_dir, arcname=os.path.sep)
207
+ tar.close()
208
+ shutil.rmtree(tmp_img_dir, onerror=redo_with_write)
209
+ return docker_tar
160
210
 
161
211
 
162
212
  def extract_data_from_image(image_name, file_name, location):
@@ -215,7 +265,7 @@ def import_quickstart_data(tutorial, section="modelplayground"):
215
265
  existing_folder = 'titanic_competition_data'
216
266
 
217
267
  if all([tutorial == "cars", section == "modelplayground"]):
218
- quickstart_repository = "public.ecr.aws/y2e2a1d6/quickstart_car_sales_competition-repository:latest"
268
+ quickstart_repository = "public.ecr.aws/z5w0c9e9/quickstart_car_sales_competition-repository:latest"
219
269
  existing_folder = 'used_car_competition_data'
220
270
 
221
271
  if all([tutorial == "clickbait", section == "modelplayground"]):
@@ -241,7 +291,7 @@ def import_quickstart_data(tutorial, section="modelplayground"):
241
291
  existing_folder = 'dog_competition_data'
242
292
 
243
293
  if all([tutorial == "imdb", section == "modelplayground"]):
244
- quickstart_repository = "public.ecr.aws/y2e2a1d6/imdb_quickstart_materials-repository:latest"
294
+ quickstart_repository = "public.ecr.aws/z5w0c9e9/imdb_quickstart_materials-repository:latest"
245
295
  existing_folder = 'imdb_competition_data'
246
296
 
247
297
  download_data(quickstart_repository)
@@ -284,8 +334,13 @@ def import_quickstart_data(tutorial, section="modelplayground"):
284
334
  #unpack data
285
335
  X_train = pd.read_csv("imdb_quickstart_materials/X_train.csv").squeeze("columns")
286
336
  X_test = pd.read_csv("imdb_quickstart_materials/X_test.csv").squeeze("columns")
287
- y_test_labels = pd.read_csv("imdb_quickstart_materials/y_test_labels.csv").squeeze("columns")
288
- y_train_labels = pd.read_csv("imdb_quickstart_materials/y_train_labels.csv").squeeze("columns")
337
+ with open("imdb_quickstart_materials/y_train_labels.json", "r") as f:
338
+ y_train_labels = json.load(f)
339
+ with open("imdb_quickstart_materials/y_test_labels.json", "r") as f:
340
+ y_test_labels = json.load(f)
341
+ import pandas as pd
342
+ y_train_labels=pd.Series(y_train_labels)
343
+ y_test_labels=pd.Series(y_test_labels)
289
344
  # example data
290
345
  example_data = X_train[50:55]
291
346
 
@@ -23,7 +23,7 @@ from aimodelshare.aws import get_s3_iam_client, run_function_on_lambda, get_toke
23
23
  from aimodelshare.bucketpolicy import _custom_upload_policy
24
24
  from aimodelshare.exceptions import AuthorizationError, AWSAccessError, AWSUploadError
25
25
  from aimodelshare.api import get_api_json
26
- from aimodelshare.modeluser import create_user_getkeyandpassword
26
+ from aimodelshare.modeluser import decode_token_unverified
27
27
  from aimodelshare.preprocessormodules import upload_preprocessor
28
28
  from aimodelshare.model import _get_predictionmodel_key, _extract_model_metadata
29
29
  from aimodelshare.data_sharing.share_data import share_data_codebuild
@@ -447,7 +447,8 @@ def model_to_api(model_filepath, model_type, private, categorical, y_train, prep
447
447
 
448
448
  if all([isinstance(email_list, list)]):
449
449
  idtoken = get_aws_token()
450
- decoded = jwt.decode(idtoken, options={"verify_signature": False}) # works in PyJWT < v2.0
450
+ decoded = decode_token_unverified(idtoken)
451
+
451
452
  email=None
452
453
  email = decoded['email']
453
454
  # Owner has to be the first on the list
@@ -592,9 +593,9 @@ def create_competition(apiurl, data_directory, y_test, eval_metric_filepath=None
592
593
  """
593
594
  if all([isinstance(email_list, list)]):
594
595
  if any([len(email_list)>0, public=="True",public=="TRUE",public==True]):
595
- import jwt
596
596
  idtoken=get_aws_token()
597
- decoded = jwt.decode(idtoken, options={"verify_signature": False}) # works in PyJWT < v2.0
597
+ decoded = decode_token_unverified(idtoken)
598
+
598
599
  email=decoded['email']
599
600
  email_list.append(email)
600
601
  else:
@@ -757,9 +758,9 @@ def create_experiment(apiurl, data_directory, y_test, eval_metric_filepath=None,
757
758
  """
758
759
  if all([isinstance(email_list, list)]):
759
760
  if any([len(email_list)>0, public=="True",public=="TRUE",public==True]):
760
- import jwt
761
761
  idtoken=get_aws_token()
762
- decoded = jwt.decode(idtoken, options={"verify_signature": False}) # works in PyJWT < v2.0
762
+ decoded = decode_token_unverified(idtoken)
763
+
763
764
  email=decoded['email']
764
765
  email_list.append(email)
765
766
  else: