aimodelshare 0.3.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aimodelshare/README.md +26 -0
- aimodelshare/__init__.py +100 -0
- aimodelshare/aimsonnx.py +2381 -0
- aimodelshare/api.py +836 -0
- aimodelshare/auth.py +163 -0
- aimodelshare/aws.py +511 -0
- aimodelshare/aws_client.py +173 -0
- aimodelshare/base_image.py +154 -0
- aimodelshare/bucketpolicy.py +106 -0
- aimodelshare/color_mappings/color_mapping_keras.csv +121 -0
- aimodelshare/color_mappings/color_mapping_pytorch.csv +117 -0
- aimodelshare/containerisation.py +244 -0
- aimodelshare/containerization.py +712 -0
- aimodelshare/containerization_templates/Dockerfile.txt +8 -0
- aimodelshare/containerization_templates/Dockerfile_PySpark.txt +23 -0
- aimodelshare/containerization_templates/buildspec.txt +14 -0
- aimodelshare/containerization_templates/lambda_function.txt +40 -0
- aimodelshare/custom_approach/__init__.py +1 -0
- aimodelshare/custom_approach/lambda_function.py +17 -0
- aimodelshare/custom_eval_metrics.py +103 -0
- aimodelshare/data_sharing/__init__.py +0 -0
- aimodelshare/data_sharing/data_sharing_templates/Dockerfile.txt +3 -0
- aimodelshare/data_sharing/data_sharing_templates/__init__.py +1 -0
- aimodelshare/data_sharing/data_sharing_templates/buildspec.txt +15 -0
- aimodelshare/data_sharing/data_sharing_templates/codebuild_policies.txt +129 -0
- aimodelshare/data_sharing/data_sharing_templates/codebuild_trust_relationship.txt +12 -0
- aimodelshare/data_sharing/download_data.py +620 -0
- aimodelshare/data_sharing/share_data.py +373 -0
- aimodelshare/data_sharing/utils.py +8 -0
- aimodelshare/deploy_custom_lambda.py +246 -0
- aimodelshare/documentation/Makefile +20 -0
- aimodelshare/documentation/karma_sphinx_theme/__init__.py +28 -0
- aimodelshare/documentation/karma_sphinx_theme/_version.py +2 -0
- aimodelshare/documentation/karma_sphinx_theme/breadcrumbs.html +70 -0
- aimodelshare/documentation/karma_sphinx_theme/layout.html +172 -0
- aimodelshare/documentation/karma_sphinx_theme/search.html +50 -0
- aimodelshare/documentation/karma_sphinx_theme/searchbox.html +14 -0
- aimodelshare/documentation/karma_sphinx_theme/static/css/custom.css +2 -0
- aimodelshare/documentation/karma_sphinx_theme/static/css/custom.css.map +1 -0
- aimodelshare/documentation/karma_sphinx_theme/static/css/theme.css +2751 -0
- aimodelshare/documentation/karma_sphinx_theme/static/css/theme.css.map +1 -0
- aimodelshare/documentation/karma_sphinx_theme/static/css/theme.min.css +2 -0
- aimodelshare/documentation/karma_sphinx_theme/static/css/theme.min.css.map +1 -0
- aimodelshare/documentation/karma_sphinx_theme/static/font/fontello.eot +0 -0
- aimodelshare/documentation/karma_sphinx_theme/static/font/fontello.svg +32 -0
- aimodelshare/documentation/karma_sphinx_theme/static/font/fontello.ttf +0 -0
- aimodelshare/documentation/karma_sphinx_theme/static/font/fontello.woff +0 -0
- aimodelshare/documentation/karma_sphinx_theme/static/font/fontello.woff2 +0 -0
- aimodelshare/documentation/karma_sphinx_theme/static/js/theme.js +68 -0
- aimodelshare/documentation/karma_sphinx_theme/theme.conf +9 -0
- aimodelshare/documentation/make.bat +35 -0
- aimodelshare/documentation/requirements.txt +2 -0
- aimodelshare/documentation/source/about.rst +18 -0
- aimodelshare/documentation/source/advanced_features.rst +137 -0
- aimodelshare/documentation/source/competition.rst +218 -0
- aimodelshare/documentation/source/conf.py +58 -0
- aimodelshare/documentation/source/create_credentials.rst +86 -0
- aimodelshare/documentation/source/example_notebooks.rst +132 -0
- aimodelshare/documentation/source/functions.rst +151 -0
- aimodelshare/documentation/source/gettingstarted.rst +390 -0
- aimodelshare/documentation/source/images/creds1.png +0 -0
- aimodelshare/documentation/source/images/creds2.png +0 -0
- aimodelshare/documentation/source/images/creds3.png +0 -0
- aimodelshare/documentation/source/images/creds4.png +0 -0
- aimodelshare/documentation/source/images/creds5.png +0 -0
- aimodelshare/documentation/source/images/creds_file_example.png +0 -0
- aimodelshare/documentation/source/images/predict_tab.png +0 -0
- aimodelshare/documentation/source/index.rst +110 -0
- aimodelshare/documentation/source/modelplayground.rst +132 -0
- aimodelshare/exceptions.py +11 -0
- aimodelshare/generatemodelapi.py +1270 -0
- aimodelshare/iam/codebuild_policy.txt +129 -0
- aimodelshare/iam/codebuild_trust_relationship.txt +12 -0
- aimodelshare/iam/lambda_policy.txt +15 -0
- aimodelshare/iam/lambda_trust_relationship.txt +12 -0
- aimodelshare/json_templates/__init__.py +1 -0
- aimodelshare/json_templates/api_json.txt +155 -0
- aimodelshare/json_templates/auth/policy.txt +1 -0
- aimodelshare/json_templates/auth/role.txt +1 -0
- aimodelshare/json_templates/eval/policy.txt +1 -0
- aimodelshare/json_templates/eval/role.txt +1 -0
- aimodelshare/json_templates/function/policy.txt +1 -0
- aimodelshare/json_templates/function/role.txt +1 -0
- aimodelshare/json_templates/integration_response.txt +5 -0
- aimodelshare/json_templates/lambda_policy_1.txt +15 -0
- aimodelshare/json_templates/lambda_policy_2.txt +8 -0
- aimodelshare/json_templates/lambda_role_1.txt +12 -0
- aimodelshare/json_templates/lambda_role_2.txt +16 -0
- aimodelshare/leaderboard.py +174 -0
- aimodelshare/main/1.txt +132 -0
- aimodelshare/main/1B.txt +112 -0
- aimodelshare/main/2.txt +153 -0
- aimodelshare/main/3.txt +134 -0
- aimodelshare/main/4.txt +128 -0
- aimodelshare/main/5.txt +109 -0
- aimodelshare/main/6.txt +105 -0
- aimodelshare/main/7.txt +144 -0
- aimodelshare/main/8.txt +142 -0
- aimodelshare/main/__init__.py +1 -0
- aimodelshare/main/authorization.txt +275 -0
- aimodelshare/main/eval_classification.txt +79 -0
- aimodelshare/main/eval_lambda.txt +1709 -0
- aimodelshare/main/eval_regression.txt +80 -0
- aimodelshare/main/lambda_function.txt +8 -0
- aimodelshare/main/nst.txt +149 -0
- aimodelshare/model.py +1543 -0
- aimodelshare/modeluser.py +215 -0
- aimodelshare/moral_compass/README.md +408 -0
- aimodelshare/moral_compass/__init__.py +65 -0
- aimodelshare/moral_compass/_version.py +3 -0
- aimodelshare/moral_compass/api_client.py +601 -0
- aimodelshare/moral_compass/apps/__init__.py +69 -0
- aimodelshare/moral_compass/apps/ai_consequences.py +540 -0
- aimodelshare/moral_compass/apps/bias_detective.py +714 -0
- aimodelshare/moral_compass/apps/ethical_revelation.py +898 -0
- aimodelshare/moral_compass/apps/fairness_fixer.py +889 -0
- aimodelshare/moral_compass/apps/judge.py +888 -0
- aimodelshare/moral_compass/apps/justice_equity_upgrade.py +853 -0
- aimodelshare/moral_compass/apps/mc_integration_helpers.py +820 -0
- aimodelshare/moral_compass/apps/model_building_game.py +1104 -0
- aimodelshare/moral_compass/apps/model_building_game_beginner.py +687 -0
- aimodelshare/moral_compass/apps/moral_compass_challenge.py +858 -0
- aimodelshare/moral_compass/apps/session_auth.py +254 -0
- aimodelshare/moral_compass/apps/shared_activity_styles.css +349 -0
- aimodelshare/moral_compass/apps/tutorial.py +481 -0
- aimodelshare/moral_compass/apps/what_is_ai.py +853 -0
- aimodelshare/moral_compass/challenge.py +365 -0
- aimodelshare/moral_compass/config.py +187 -0
- aimodelshare/placeholders/model.onnx +0 -0
- aimodelshare/placeholders/preprocessor.zip +0 -0
- aimodelshare/playground.py +1968 -0
- aimodelshare/postprocessormodules.py +157 -0
- aimodelshare/preprocessormodules.py +373 -0
- aimodelshare/pyspark/1.txt +195 -0
- aimodelshare/pyspark/1B.txt +181 -0
- aimodelshare/pyspark/2.txt +220 -0
- aimodelshare/pyspark/3.txt +204 -0
- aimodelshare/pyspark/4.txt +187 -0
- aimodelshare/pyspark/5.txt +178 -0
- aimodelshare/pyspark/6.txt +174 -0
- aimodelshare/pyspark/7.txt +211 -0
- aimodelshare/pyspark/8.txt +206 -0
- aimodelshare/pyspark/__init__.py +1 -0
- aimodelshare/pyspark/authorization.txt +258 -0
- aimodelshare/pyspark/eval_classification.txt +79 -0
- aimodelshare/pyspark/eval_lambda.txt +1441 -0
- aimodelshare/pyspark/eval_regression.txt +80 -0
- aimodelshare/pyspark/lambda_function.txt +8 -0
- aimodelshare/pyspark/nst.txt +213 -0
- aimodelshare/python/my_preprocessor.py +58 -0
- aimodelshare/readme.md +26 -0
- aimodelshare/reproducibility.py +181 -0
- aimodelshare/sam/Dockerfile.txt +8 -0
- aimodelshare/sam/Dockerfile_PySpark.txt +24 -0
- aimodelshare/sam/__init__.py +1 -0
- aimodelshare/sam/buildspec.txt +11 -0
- aimodelshare/sam/codebuild_policies.txt +129 -0
- aimodelshare/sam/codebuild_trust_relationship.txt +12 -0
- aimodelshare/sam/codepipeline_policies.txt +173 -0
- aimodelshare/sam/codepipeline_trust_relationship.txt +12 -0
- aimodelshare/sam/spark-class.txt +2 -0
- aimodelshare/sam/template.txt +54 -0
- aimodelshare/tools.py +103 -0
- aimodelshare/utils/__init__.py +78 -0
- aimodelshare/utils/optional_deps.py +38 -0
- aimodelshare/utils.py +57 -0
- aimodelshare-0.3.7.dist-info/METADATA +298 -0
- aimodelshare-0.3.7.dist-info/RECORD +171 -0
- aimodelshare-0.3.7.dist-info/WHEEL +5 -0
- aimodelshare-0.3.7.dist-info/licenses/LICENSE +5 -0
- aimodelshare-0.3.7.dist-info/top_level.txt +1 -0
aimodelshare/model.py
ADDED
|
@@ -0,0 +1,1543 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import boto3
|
|
3
|
+
import json
|
|
4
|
+
import onnx
|
|
5
|
+
import numpy as np
|
|
6
|
+
import pandas as pd
|
|
7
|
+
import requests
|
|
8
|
+
import json
|
|
9
|
+
import ast
|
|
10
|
+
try:
|
|
11
|
+
import tensorflow as tf
|
|
12
|
+
except ImportError:
|
|
13
|
+
pass
|
|
14
|
+
import tempfile as tmp
|
|
15
|
+
from datetime import datetime
|
|
16
|
+
try:
|
|
17
|
+
import torch
|
|
18
|
+
except:
|
|
19
|
+
pass
|
|
20
|
+
|
|
21
|
+
from aimodelshare.leaderboard import get_leaderboard
|
|
22
|
+
from aimodelshare.aws import run_function_on_lambda, get_token, get_aws_token, get_aws_client
|
|
23
|
+
from aimodelshare.aimsonnx import _get_leaderboard_data, inspect_model, _get_metadata, _model_summary, model_from_string, pyspark_model_from_string, _get_layer_names, _get_layer_names_pytorch
|
|
24
|
+
from aimodelshare.aimsonnx import model_to_onnx
|
|
25
|
+
from aimodelshare.utils import ignore_warning
|
|
26
|
+
import warnings
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def _normalize_eval_payload(raw_eval):
|
|
30
|
+
"""
|
|
31
|
+
Normalize the API response eval payload to (public_eval_dict, private_eval_dict).
|
|
32
|
+
|
|
33
|
+
Handles multiple response formats:
|
|
34
|
+
- {"eval": [public_dict, private_dict]} -> extract both dicts
|
|
35
|
+
- {"eval": public_dict} -> public_dict, {}
|
|
36
|
+
- {"eval": None} or missing -> {}, {}
|
|
37
|
+
- Malformed responses -> {}, {} with warning
|
|
38
|
+
|
|
39
|
+
Args:
|
|
40
|
+
raw_eval: The raw API response (expected to be dict with 'eval' key)
|
|
41
|
+
|
|
42
|
+
Returns:
|
|
43
|
+
tuple: (public_eval_dict, private_eval_dict) - both guaranteed to be dicts
|
|
44
|
+
"""
|
|
45
|
+
public_eval = {}
|
|
46
|
+
private_eval = {}
|
|
47
|
+
|
|
48
|
+
if not isinstance(raw_eval, dict):
|
|
49
|
+
print("---------------------------------------------------------------")
|
|
50
|
+
print(f"--- WARNING: API response is not a dict (type={type(raw_eval)}) ---")
|
|
51
|
+
print("Defaulting to empty eval metrics.")
|
|
52
|
+
print("---------------------------------------------------------------")
|
|
53
|
+
return public_eval, private_eval
|
|
54
|
+
|
|
55
|
+
eval_field = raw_eval.get('eval')
|
|
56
|
+
|
|
57
|
+
if eval_field is None:
|
|
58
|
+
# No eval field present
|
|
59
|
+
return public_eval, private_eval
|
|
60
|
+
|
|
61
|
+
if isinstance(eval_field, list):
|
|
62
|
+
# Expected format: [public_dict, private_dict, ...]
|
|
63
|
+
if len(eval_field) >= 1 and isinstance(eval_field[0], dict):
|
|
64
|
+
public_eval = eval_field[0]
|
|
65
|
+
if len(eval_field) >= 2 and isinstance(eval_field[1], dict):
|
|
66
|
+
private_eval = eval_field[1]
|
|
67
|
+
elif len(eval_field) >= 1:
|
|
68
|
+
# Only one dict in list, treat as public
|
|
69
|
+
if not public_eval:
|
|
70
|
+
public_eval = {}
|
|
71
|
+
elif isinstance(eval_field, dict):
|
|
72
|
+
# Single dict, treat as public eval
|
|
73
|
+
public_eval = eval_field
|
|
74
|
+
else:
|
|
75
|
+
print("---------------------------------------------------------------")
|
|
76
|
+
print(f"--- WARNING: 'eval' field has unexpected type: {type(eval_field)} ---")
|
|
77
|
+
print("Defaulting to empty eval metrics.")
|
|
78
|
+
print("---------------------------------------------------------------")
|
|
79
|
+
|
|
80
|
+
return public_eval, private_eval
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def _subset_numeric(metrics_dict, keys_to_extract):
|
|
84
|
+
"""
|
|
85
|
+
Safely extract a subset of numeric metrics from a metrics dictionary.
|
|
86
|
+
|
|
87
|
+
Args:
|
|
88
|
+
metrics_dict: Dictionary containing metric key-value pairs
|
|
89
|
+
keys_to_extract: List of keys to extract from the dictionary
|
|
90
|
+
|
|
91
|
+
Returns:
|
|
92
|
+
dict: Subset of metrics that exist and have numeric (float/int) values
|
|
93
|
+
"""
|
|
94
|
+
if not isinstance(metrics_dict, dict):
|
|
95
|
+
print("---------------------------------------------------------------")
|
|
96
|
+
print(f"--- WARNING: metrics_dict is not a dict (type={type(metrics_dict)}) ---")
|
|
97
|
+
print("Returning empty metrics subset.")
|
|
98
|
+
print("---------------------------------------------------------------")
|
|
99
|
+
return {}
|
|
100
|
+
|
|
101
|
+
subset = {}
|
|
102
|
+
for key in keys_to_extract:
|
|
103
|
+
value = metrics_dict.get(key)
|
|
104
|
+
if value is not None and isinstance(value, (int, float)):
|
|
105
|
+
subset[key] = value
|
|
106
|
+
|
|
107
|
+
return subset
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
def _prepare_preprocessor_if_function(preprocessor, debug_mode=False):
    """Prepare a preprocessor for submission.

    Accepts:
    - None: returns None
    - Path to existing preprocessor zip (.zip)
    - Callable function: exports source or pickled callable with loader
    - Transformer object (e.g., sklearn Pipeline/ColumnTransformer) with .transform: pickles object + loader

    Returns: absolute path to created or existing preprocessor zip, or None.
    Raises: RuntimeError with actionable message on failure.
    """
    import inspect
    import tempfile
    import zipfile
    import pickle
    import textwrap

    if preprocessor is None:
        return None

    # Existing zip path: validate it exists, then pass it through untouched.
    if isinstance(preprocessor, str) and preprocessor.endswith('.zip'):
        if not os.path.exists(preprocessor):
            raise RuntimeError(f"Preprocessor export failed: zip path not found: {preprocessor}")
        if debug_mode:
            print(f"[DEBUG] Using existing preprocessor zip: {preprocessor}")
        return preprocessor

    # Determine if transformer object.
    # Instances (e.g. sklearn pipelines) expose .transform but are not plain
    # functions, so inspect.isfunction distinguishes them from callables.
    is_transformer_obj = hasattr(preprocessor, 'transform') and not inspect.isfunction(preprocessor)

    serialize_object = None  # object that will be pickled into the zip, if any
    export_callable = None   # callable whose source we attempt to export

    if is_transformer_obj:
        if debug_mode:
            print('[DEBUG] Detected transformer object; preparing wrapper.')
        transformer_obj = preprocessor

        # Wrapper gives the transformer a plain-function calling convention.
        def _wrapped_preprocessor(data):
            return transformer_obj.transform(data)
        export_callable = _wrapped_preprocessor
        serialize_object = transformer_obj  # pickle the transformer

    elif callable(preprocessor):
        export_callable = preprocessor
    else:
        raise RuntimeError(
            f"Preprocessor export failed: Unsupported type {type(preprocessor)}. "
            "Provide a callable, transformer with .transform, an existing .zip path, or None."
        )

    tmp_dir = tempfile.mkdtemp()
    py_path = os.path.join(tmp_dir, 'preprocessor.py')
    zip_path = os.path.join(tmp_dir, 'preprocessor.zip')
    pkl_name = 'preprocessor.pkl'

    source_written = False
    # Attempt direct source extraction if not a transformer serialization
    if serialize_object is None:
        try:
            # getsource raises for callables with no retrievable source
            # (REPL/notebook-defined, builtins), triggering the pickle fallback.
            src = inspect.getsource(export_callable)
            with open(py_path, 'w') as f:
                f.write(src)
            source_written = True
            if debug_mode:
                print('[DEBUG] Wrote source for callable preprocessor.')
        except Exception as e:
            if debug_mode:
                print(f'[DEBUG] Source extraction failed; falling back to pickled callable: {e}')
            serialize_object = export_callable  # fallback to pickling callable

    # If transformer or fallback pickled callable: write loader stub
    # The stub lazily unpickles the object on first call and dispatches to
    # .transform (transformer) or __call__ (plain callable).
    if serialize_object is not None and not source_written:
        loader_stub = textwrap.dedent(f"""
        import pickle, os
        _PKL_FILE = '{pkl_name}'
        _loaded_obj = None
        def preprocessor(data):
            global _loaded_obj
            if _loaded_obj is None:
                with open(os.path.join(os.path.dirname(__file__), _PKL_FILE), 'rb') as pf:
                    _loaded_obj = pickle.load(pf)
            # If original object was a transformer it has .transform; else callable
            if hasattr(_loaded_obj, 'transform'):
                return _loaded_obj.transform(data)
            return _loaded_obj(data)
        """)
        with open(py_path, 'w') as f:
            f.write(loader_stub)
        if debug_mode:
            print('[DEBUG] Wrote loader stub for pickled object.')

    # Serialize object if needed
    if serialize_object is not None:
        try:
            with open(os.path.join(tmp_dir, pkl_name), 'wb') as pf:
                pickle.dump(serialize_object, pf)
            if debug_mode:
                print('[DEBUG] Pickled transformer/callable successfully.')
        except Exception as e:
            raise RuntimeError(f'Preprocessor export failed: pickling failed: {e}')

    # Create zip
    try:
        with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zf:
            zf.write(py_path, arcname='preprocessor.py')
            pkl_path = os.path.join(tmp_dir, pkl_name)
            # pickle file only exists on the serialization paths above
            if os.path.exists(pkl_path):
                zf.write(pkl_path, arcname=pkl_name)
    except Exception as e:
        raise RuntimeError(f'Preprocessor export failed: zip creation error: {e}')

    # Final validation
    if not os.path.exists(zip_path) or os.path.getsize(zip_path) == 0:
        raise RuntimeError(f'Preprocessor export failed: zip file not found or empty at {zip_path}')

    if debug_mode:
        print(f'[DEBUG] Preprocessor zip created: {zip_path}')
    return zip_path
def _diagnose_closure_variables(preprocessor_fxn):
|
|
232
|
+
"""
|
|
233
|
+
Diagnose closure variables for serialization issues.
|
|
234
|
+
|
|
235
|
+
Args:
|
|
236
|
+
preprocessor_fxn: Function to diagnose
|
|
237
|
+
|
|
238
|
+
Logs:
|
|
239
|
+
INFO for successful serialization of each closure object
|
|
240
|
+
WARNING for failed serialization attempts
|
|
241
|
+
"""
|
|
242
|
+
import inspect
|
|
243
|
+
import pickle
|
|
244
|
+
import logging
|
|
245
|
+
|
|
246
|
+
# Get closure variables
|
|
247
|
+
closure_vars = inspect.getclosurevars(preprocessor_fxn)
|
|
248
|
+
all_globals = closure_vars.globals
|
|
249
|
+
|
|
250
|
+
if not all_globals:
|
|
251
|
+
logging.info("No closure variables detected in preprocessor function")
|
|
252
|
+
return
|
|
253
|
+
|
|
254
|
+
logging.info(f"Analyzing {len(all_globals)} closure variables...")
|
|
255
|
+
|
|
256
|
+
successful = []
|
|
257
|
+
failed = []
|
|
258
|
+
|
|
259
|
+
for var_name, var_value in all_globals.items():
|
|
260
|
+
try:
|
|
261
|
+
# Attempt to pickle the object
|
|
262
|
+
pickle.dumps(var_value)
|
|
263
|
+
successful.append(var_name)
|
|
264
|
+
logging.info(f"✓ Closure variable '{var_name}' (type: {type(var_value).__name__}) is serializable")
|
|
265
|
+
except Exception as e:
|
|
266
|
+
failed.append((var_name, type(var_value).__name__, str(e)))
|
|
267
|
+
logging.warning(f"✗ Closure variable '{var_name}' (type: {type(var_value).__name__}) failed serialization: {e}")
|
|
268
|
+
|
|
269
|
+
# Summary
|
|
270
|
+
if failed:
|
|
271
|
+
failure_summary = "; ".join([f"{name} ({vtype})" for name, vtype, _ in failed])
|
|
272
|
+
logging.warning(f"Serialization failures detected: {failure_summary}")
|
|
273
|
+
else:
|
|
274
|
+
logging.info(f"All {len(successful)} closure variables are serializable")
|
|
275
|
+
|
|
276
|
+
return successful, failed
|
|
277
|
+
|
|
278
|
+
|
|
279
|
+
def _get_file_list(client, bucket,keysubfolderid):
|
|
280
|
+
# Reading file list {{{
|
|
281
|
+
try:
|
|
282
|
+
objectlist=[]
|
|
283
|
+
paginator = client.get_paginator('list_objects')
|
|
284
|
+
pages = paginator.paginate(Bucket=bucket, Prefix=keysubfolderid)
|
|
285
|
+
|
|
286
|
+
for page in pages:
|
|
287
|
+
for obj in page['Contents']:
|
|
288
|
+
objectlist.append(obj['Key'])
|
|
289
|
+
|
|
290
|
+
except Exception as err:
|
|
291
|
+
return None, err
|
|
292
|
+
|
|
293
|
+
file_list = []
|
|
294
|
+
for key in objectlist:
|
|
295
|
+
file_list.append(key.split("/")[-1])
|
|
296
|
+
# }}}
|
|
297
|
+
|
|
298
|
+
return file_list, None
|
|
299
|
+
|
|
300
|
+
|
|
301
|
+
def _delete_s3_object(client, bucket, model_id, filename):
|
|
302
|
+
deletionobject = client["resource"].Object(bucket, model_id + "/" + filename)
|
|
303
|
+
deletionobject.delete()
|
|
304
|
+
|
|
305
|
+
def _get_predictionmodel_key(unique_model_id,file_extension):
|
|
306
|
+
if file_extension==".pkl":
|
|
307
|
+
file_key = unique_model_id + "/runtime_model" + file_extension
|
|
308
|
+
versionfile_key = unique_model_id + "/predictionmodel_1" + file_extension
|
|
309
|
+
else:
|
|
310
|
+
file_key = unique_model_id + "/runtime_model" + file_extension
|
|
311
|
+
versionfile_key = unique_model_id + "/predictionmodel_1" + file_extension
|
|
312
|
+
return file_key,versionfile_key
|
|
313
|
+
|
|
314
|
+
|
|
315
|
+
def _upload_onnx_model(modelpath, client, bucket, model_id, model_version):
|
|
316
|
+
# Check the model {{{
|
|
317
|
+
if not os.path.exists(modelpath):
|
|
318
|
+
raise FileNotFoundError(f"The model file at {modelpath} does not exist")
|
|
319
|
+
|
|
320
|
+
file_name = os.path.basename(modelpath)
|
|
321
|
+
file_name, file_ext = os.path.splitext(file_name)
|
|
322
|
+
|
|
323
|
+
assert (
|
|
324
|
+
file_ext == ".onnx"
|
|
325
|
+
), "modelshareai api only supports .onnx models at the moment"
|
|
326
|
+
# }}}
|
|
327
|
+
|
|
328
|
+
# Upload the model {{{
|
|
329
|
+
try:
|
|
330
|
+
client["client"].upload_file(
|
|
331
|
+
modelpath, bucket, model_id + "/onnx_model_mostrecent.onnx"
|
|
332
|
+
)
|
|
333
|
+
client["client"].upload_file(
|
|
334
|
+
modelpath,
|
|
335
|
+
bucket,
|
|
336
|
+
model_id + "/onnx_model_v" + str(model_version) + file_ext,
|
|
337
|
+
)
|
|
338
|
+
except Exception as err:
|
|
339
|
+
return err
|
|
340
|
+
# }}}
|
|
341
|
+
|
|
342
|
+
def _upload_native_model(modelpath, client, bucket, model_id, model_version):
|
|
343
|
+
# Check the model {{{
|
|
344
|
+
if not os.path.exists(modelpath):
|
|
345
|
+
raise FileNotFoundError(f"The model file at {modelpath} does not exist")
|
|
346
|
+
|
|
347
|
+
file_name = os.path.basename(modelpath)
|
|
348
|
+
file_name, file_ext = os.path.splitext(file_name)
|
|
349
|
+
|
|
350
|
+
assert (
|
|
351
|
+
file_ext == ".onnx"
|
|
352
|
+
), "modelshareai api only supports .onnx models at the moment"
|
|
353
|
+
# }}}
|
|
354
|
+
|
|
355
|
+
# Upload the model {{{
|
|
356
|
+
try:
|
|
357
|
+
client["client"].upload_file(
|
|
358
|
+
modelpath, bucket, model_id + "/onnx_model_mostrecent.onnx"
|
|
359
|
+
)
|
|
360
|
+
client["client"].upload_file(
|
|
361
|
+
modelpath,
|
|
362
|
+
bucket,
|
|
363
|
+
model_id + "/onnx_model_v" + str(model_version) + file_ext,
|
|
364
|
+
)
|
|
365
|
+
except Exception as err:
|
|
366
|
+
return err
|
|
367
|
+
# }}}
|
|
368
|
+
|
|
369
|
+
|
|
370
|
+
def _upload_preprocessor(preprocessor, client, bucket, model_id, model_version):
|
|
371
|
+
|
|
372
|
+
try:
|
|
373
|
+
|
|
374
|
+
|
|
375
|
+
# Check the preprocessor {{{
|
|
376
|
+
if not os.path.exists(preprocessor):
|
|
377
|
+
raise FileNotFoundError(
|
|
378
|
+
f"The preprocessor file at {preprocessor} does not exist"
|
|
379
|
+
)
|
|
380
|
+
|
|
381
|
+
|
|
382
|
+
file_name = os.path.basename(preprocessor)
|
|
383
|
+
file_name, file_ext = os.path.splitext(file_name)
|
|
384
|
+
|
|
385
|
+
from zipfile import ZipFile
|
|
386
|
+
dir_zip = preprocessor
|
|
387
|
+
|
|
388
|
+
#zipObj = ZipFile(os.path.join("./preprocessor.zip"), 'a')
|
|
389
|
+
#/Users/aishwarya/Downloads/aimodelshare-master
|
|
390
|
+
client["client"].upload_file(dir_zip, bucket, model_id + "/preprocessor_v" + str(model_version)+ ".zip")
|
|
391
|
+
except Exception as e:
|
|
392
|
+
print(e)
|
|
393
|
+
|
|
394
|
+
|
|
395
|
+
def _update_leaderboard_public(
    modelpath,
    eval_metrics,
    s3_presigned_dict,
    username=None,
    custom_metadata=None,
    private=False,
    leaderboard_type="competition",
    onnx_model=None,
):
    """
    Update the public (or private) leaderboard file via presigned URLs.
    Adds new columns if custom_metadata introduces new keys.

    Args:
        modelpath: Optional path to an .onnx file (loaded when onnx_model is None).
        eval_metrics: Metrics dict merged into the leaderboard row.
        s3_presigned_dict: {"get": {name: url}, "put": {name: spec}} presigned
            URL mapping. NOTE(review): "put" specs appear to be stringified
            presigned-POST dicts (parsed with ast.literal_eval below) — confirm
            against the API that generates them.
        username: Overrides the 'username' env var for the row's username field.
        custom_metadata: Extra key/value pairs merged into (and overriding) metadata.
        private: Selects the private mastertable file and the second upload spec.
        leaderboard_type: Unused in this function body.
        onnx_model: Pre-loaded ONNX model; takes precedence over modelpath.

    Returns:
        dict: The metadata row on success, or the caught exception object if
        the upload step fails (legacy return-the-error convention).
    """
    mastertable_path = (
        "model_eval_data_mastertable_private.csv"
        if private
        else "model_eval_data_mastertable.csv"
    )

    # Load or derive metadata
    if modelpath is not None and not os.path.exists(modelpath):
        raise FileNotFoundError(f"The model file at {modelpath} does not exist")

    # Derive the model version from the presigned "put" key names: strip the
    # extension, take the last "_"-separated token, and drop its leading
    # character (e.g. "...model_v3.onnx" -> "3"); keep only numeric results.
    model_versions = [
        os.path.splitext(f)[0].split("_")[-1][1:]
        for f in s3_presigned_dict["put"].keys()
    ]
    model_versions = list(map(int, filter(lambda v: v.isnumeric(), model_versions)))
    model_version = model_versions[0]

    if onnx_model is not None:
        metadata = _get_leaderboard_data(onnx_model, eval_metrics)
    elif modelpath is not None:
        onnx_model = onnx.load(modelpath)
        metadata = _get_leaderboard_data(onnx_model, eval_metrics)
    else:
        metadata = _get_leaderboard_data(None, eval_metrics)

    # custom_metadata wins over derived metadata on key collisions
    if custom_metadata:
        metadata = {**metadata, **custom_metadata}

    metadata["username"] = username if username else os.environ.get("username")
    metadata["timestamp"] = str(datetime.now())
    metadata["version"] = model_version

    temp_dir = tmp.mkdtemp()

    # Read existing leaderboard (if any); any download/parse failure falls
    # back to starting a fresh table with the current metadata's columns.
    try:
        import wget

        wget.download(
            s3_presigned_dict["get"][mastertable_path],
            out=os.path.join(temp_dir, mastertable_path),
        )
        leaderboard = pd.read_csv(
            os.path.join(temp_dir, mastertable_path), sep="\t"
        )
    except Exception:
        leaderboard = pd.DataFrame(columns=list(metadata.keys()))

    # Expand columns for any new metadata keys
    existing_cols = set(leaderboard.columns.tolist())
    new_cols = [c for c in metadata.keys() if c not in existing_cols]
    for c in new_cols:
        leaderboard[c] = None

    # Append row
    row_dict = {col: metadata.get(col, None) for col in leaderboard.columns}
    leaderboard.loc[len(leaderboard)] = row_dict

    # Legacy behavior: remove model_config from metadata dict before returning
    # (must happen AFTER the row append so the table keeps the column value)
    metadata.pop("model_config", None)

    # Write updated leaderboard to temp
    leaderboard.to_csv(
        os.path.join(temp_dir, mastertable_path), index=False, sep="\t"
    )

    # Upload via appropriate presigned POST
    try:
        put_keys = list(s3_presigned_dict["put"].keys())
        csv_put_entries = [k for k in put_keys if "csv" in k]

        file_put_dicts = [
            ast.literal_eval(s3_presigned_dict["put"][k]) for k in csv_put_entries
        ]
        # public uses first, private uses second
        # (assumes the API always orders the csv put specs this way — TODO confirm)
        target_index = 1 if private else 0
        upload_spec = file_put_dicts[target_index]

        with open(os.path.join(temp_dir, mastertable_path), "rb") as f:
            files = {"file": (mastertable_path, f)}
            http_response = requests.post(
                upload_spec["url"], data=upload_spec["fields"], files=files
            )
        if http_response.status_code not in (200, 204):
            raise RuntimeError(
                f"Leaderboard upload failed with status {http_response.status_code}: {http_response.text}"
            )

        return metadata
    except Exception as err:
        return err
def _update_leaderboard(
    modelpath,
    eval_metrics,
    client,
    bucket,
    model_id,
    model_version,
    onnx_model=None,
    custom_metadata=None,
):
    """
    Update the leaderboard directly in S3 using boto3 client/resource (non-presigned path).
    Adds new columns if custom_metadata introduces new keys.

    Args:
        modelpath: Optional path to an .onnx file (loaded when onnx_model is None).
        eval_metrics: Metrics dict merged into the leaderboard row.
        client: dict holding boto3 "client" and "resource" entries.
        bucket: Bucket containing the leaderboard CSV.
        model_id: Key prefix for the leaderboard file.
        model_version: Version number recorded in the row.
        onnx_model: Pre-loaded ONNX model; takes precedence over modelpath.
        custom_metadata: Extra key/value pairs merged into (and overriding) metadata.

    Returns:
        dict: The metadata row on success, or the caught exception object if
        the final S3 put fails (legacy return-the-error convention).
    """
    # Build metadata
    if onnx_model is not None:
        metadata = _get_leaderboard_data(onnx_model, eval_metrics)
    elif modelpath is not None:
        if not os.path.exists(modelpath):
            raise FileNotFoundError(f"The model file at {modelpath} does not exist")
        loaded = onnx.load(modelpath)
        metadata = _get_leaderboard_data(loaded, eval_metrics)
    else:
        metadata = _get_leaderboard_data(None, eval_metrics)

    # custom_metadata wins over derived metadata on key collisions
    if custom_metadata:
        metadata = {**metadata, **custom_metadata}

    metadata["username"] = os.environ.get("username")
    metadata["timestamp"] = str(datetime.now())
    metadata["version"] = model_version

    # Fetch existing leaderboard (if any); a missing key starts a fresh table,
    # any other S3 failure propagates.
    try:
        obj = client["client"].get_object(
            Bucket=bucket, Key=model_id + "/model_eval_data_mastertable.csv"
        )
        leaderboard = pd.read_csv(obj["Body"], sep="\t")
    except client["client"].exceptions.NoSuchKey:
        leaderboard = pd.DataFrame(columns=list(metadata.keys()))
    except Exception as err:
        raise err

    # Expand columns as needed
    existing_cols = set(leaderboard.columns.tolist())
    new_cols = [c for c in metadata.keys() if c not in existing_cols]
    for c in new_cols:
        leaderboard[c] = None

    # Append row
    row_dict = {col: metadata.get(col, None) for col in leaderboard.columns}
    leaderboard.loc[len(leaderboard)] = row_dict

    # Legacy removal (after the append, so the table keeps the column value)
    metadata.pop("model_config", None)

    # Write and upload
    csv_payload = leaderboard.to_csv(index=False, sep="\t")
    try:
        s3_object = client["resource"].Object(
            bucket, model_id + "/model_eval_data_mastertable.csv"
        )
        s3_object.put(Body=csv_payload)
        return metadata
    except Exception as err:
        return err
def _normalize_model_config(model_config, model_type=None):
|
|
571
|
+
"""
|
|
572
|
+
Normalize model_config to a dict, handling various input types.
|
|
573
|
+
|
|
574
|
+
Args:
|
|
575
|
+
model_config: Can be None, dict, or string representation of dict
|
|
576
|
+
model_type: Optional model type for context in warnings
|
|
577
|
+
|
|
578
|
+
Returns:
|
|
579
|
+
dict: Normalized model config, or empty dict if normalization fails
|
|
580
|
+
"""
|
|
581
|
+
import ast
|
|
582
|
+
|
|
583
|
+
# If already a dict, return as-is
|
|
584
|
+
if isinstance(model_config, dict):
|
|
585
|
+
return model_config
|
|
586
|
+
|
|
587
|
+
# If None or other non-string type, return empty dict
|
|
588
|
+
if not isinstance(model_config, str):
|
|
589
|
+
if model_config is not None:
|
|
590
|
+
print(f"Warning: model_config is {type(model_config).__name__}, expected str or dict. Using empty config.")
|
|
591
|
+
return {}
|
|
592
|
+
|
|
593
|
+
# Try to parse string to dict
|
|
594
|
+
try:
|
|
595
|
+
import astunparse
|
|
596
|
+
|
|
597
|
+
tree = ast.parse(model_config)
|
|
598
|
+
stringconfig = model_config
|
|
599
|
+
|
|
600
|
+
# Find and quote callable nodes
|
|
601
|
+
problemnodes = []
|
|
602
|
+
for node in ast.walk(tree):
|
|
603
|
+
if isinstance(node, ast.Call):
|
|
604
|
+
problemnodes.append(astunparse.unparse(node).replace("\n", ""))
|
|
605
|
+
|
|
606
|
+
problemnodesunique = set(problemnodes)
|
|
607
|
+
for i in problemnodesunique:
|
|
608
|
+
stringconfig = stringconfig.replace(i, "'" + i + "'")
|
|
609
|
+
|
|
610
|
+
# Parse the modified string
|
|
611
|
+
model_config_dict = ast.literal_eval(stringconfig)
|
|
612
|
+
return model_config_dict if isinstance(model_config_dict, dict) else {}
|
|
613
|
+
|
|
614
|
+
except Exception as e:
|
|
615
|
+
print(f"Warning: Failed to parse model_config string: {e}. Using empty config.")
|
|
616
|
+
return {}
|
|
617
|
+
|
|
618
|
+
|
|
619
|
+
def _build_sklearn_param_dataframe(model_type, model_config):
|
|
620
|
+
"""
|
|
621
|
+
Build parameter inspection DataFrame for sklearn/xgboost models.
|
|
622
|
+
|
|
623
|
+
Creates a DataFrame with aligned columns by taking the union of default
|
|
624
|
+
parameters and model_config parameters. This ensures equal-length arrays
|
|
625
|
+
even when model_config contains extra parameters or is missing defaults.
|
|
626
|
+
|
|
627
|
+
Args:
|
|
628
|
+
model_type: String name of the sklearn model class
|
|
629
|
+
model_config: Dict of model configuration parameters
|
|
630
|
+
|
|
631
|
+
Returns:
|
|
632
|
+
pd.DataFrame: DataFrame with param_name, default_value, param_value columns,
|
|
633
|
+
or empty DataFrame on error
|
|
634
|
+
"""
|
|
635
|
+
import pandas as pd
|
|
636
|
+
import warnings
|
|
637
|
+
|
|
638
|
+
try:
|
|
639
|
+
model_class = model_from_string(model_type)
|
|
640
|
+
default_instance = model_class()
|
|
641
|
+
defaults_dict = default_instance.get_params()
|
|
642
|
+
|
|
643
|
+
# Take union of keys from both sources to ensure all parameters are included
|
|
644
|
+
# This prevents ValueError: "All arrays must be of the same length"
|
|
645
|
+
# when model_config has different keys than defaults
|
|
646
|
+
param_names = sorted(set(defaults_dict.keys()) | set(model_config.keys()))
|
|
647
|
+
default_values = [defaults_dict.get(k, None) for k in param_names]
|
|
648
|
+
param_values = [model_config.get(k, None) for k in param_names]
|
|
649
|
+
|
|
650
|
+
return pd.DataFrame({
|
|
651
|
+
'param_name': param_names,
|
|
652
|
+
'default_value': default_values,
|
|
653
|
+
'param_value': param_values
|
|
654
|
+
})
|
|
655
|
+
except Exception as e:
|
|
656
|
+
# Log warning and fallback to empty DataFrame
|
|
657
|
+
warnings.warn(f"Failed to instantiate model class for {model_type}: {e}")
|
|
658
|
+
return pd.DataFrame()
|
|
659
|
+
|
|
660
|
+
|
|
661
|
+
def upload_model_dict(modelpath, s3_presigned_dict, bucket, model_id, model_version, placeholder=False, onnx_model=None):
    """
    Build a parameter-inspection record for a model version and upload it
    (as ``inspect_pd_<version>.json``) through the presigned S3 POST URLs.

    Args:
        modelpath: Filesystem path to the ONNX model (ignored when
            ``onnx_model`` is supplied or ``placeholder`` is True).
        s3_presigned_dict: Dict with 'get'/'put' maps of presigned URLs.
        bucket: S3 bucket name (kept for interface compatibility; the
            presigned URLs already encode the destination).
        model_id: Playground/competition identifier (same note as bucket).
        model_version: Version number used in the uploaded file name.
        placeholder: When True, upload an empty inspection record
            (no model artifact was submitted).
        onnx_model: Optional already-loaded onnx.ModelProto; avoids
            re-reading ``modelpath``.

    Returns:
        int: 1 (always; download/upload failures are handled best-effort).
    """
    import wget
    import json
    import ast
    temp = tmp.mkdtemp()

    # NOTE: the old code also imported astunparse here although it was never
    # used, which made the whole upload fail when that package was missing.

    if not placeholder:
        # Extract metadata (framework, type, config) from the ONNX payload.
        if onnx_model is None:
            onnx_model = onnx.load(modelpath)
        meta_dict = _get_metadata(onnx_model)

        if meta_dict['ml_framework'] in ['keras', 'pytorch']:
            inspect_pd = _model_summary(meta_dict)

        elif meta_dict['ml_framework'] in ['sklearn', 'xgboost']:
            # Normalize model_config to dict (handles None, dict, or string)
            model_config = _normalize_model_config(
                meta_dict.get("model_config"),
                meta_dict.get('model_type')
            )
            # Build parameter inspection DataFrame
            inspect_pd = _build_sklearn_param_dataframe(
                meta_dict['model_type'],
                model_config
            )

        elif meta_dict['ml_framework'] in ['pyspark']:
            # Normalize model_config to dict (handles None, dict, or string)
            model_config_temp = _normalize_model_config(
                meta_dict.get("model_config"),
                meta_dict.get('model_type')
            )
            try:
                default = pyspark_model_from_string(meta_dict['model_type'])()

                # Default param values from a freshly constructed model.
                default_config_temp = {param.name: value
                                       for param, value in default.extractParamMap().items()}

                # Align defaults and submitted config over the union of their
                # keys so all three columns have equal length (mirrors
                # _build_sklearn_param_dataframe); the old per-dict sort broke
                # with a ValueError when the key sets differed.
                model_configkeys = sorted(set(default_config_temp) | set(model_config_temp))
                default_config = [default_config_temp.get(k) for k in model_configkeys]
                model_configvalues = [model_config_temp.get(k) for k in model_configkeys]
            except Exception:
                model_class = str(pyspark_model_from_string(meta_dict['model_type']))
                if model_class.find("Voting") > 0:
                    default_config = ["No data available"]
                    model_configkeys = ["No data available"]
                    model_configvalues = ["No data available"]
                else:
                    # Fallback for other exceptions
                    default_config = []
                    model_configkeys = []
                    model_configvalues = []

            inspect_pd = pd.DataFrame({'param_name': model_configkeys,
                                       'default_value': default_config,
                                       'param_value': model_configvalues})

    else:
        # Placeholder submission: record framework/type as "undefined"
        # with an empty inspection table.
        meta_dict = {'ml_framework': "undefined", 'model_type': "undefined"}
        inspect_pd = pd.DataFrame()

    json_name = 'inspect_pd_' + str(model_version) + '.json'
    json_path = temp + "/" + json_name

    # Fetch any existing inspection file so earlier versions are preserved.
    try:
        wget.download(s3_presigned_dict['get'][json_name], out=json_path)
        with open(json_path) as f:
            model_dict = json.load(f)
    except Exception:
        model_dict = {}

    model_dict[str(model_version)] = {'ml_framework': meta_dict['ml_framework'],
                                      'model_type': meta_dict['model_type'],
                                      'model_dict': inspect_pd.to_dict()}

    with open(json_path, 'w') as outfile:
        json.dump(model_dict, outfile)

    # Best-effort upload through the matching presigned POST URL.
    try:
        putfilekeys = list(s3_presigned_dict['put'].keys())
        modelputfiles = [s for s in putfilekeys if json_name in s]

        fileputlistofdicts = []
        for i in modelputfiles:
            filedownload_dict = ast.literal_eval(s3_presigned_dict['put'][i])
            fileputlistofdicts.append(filedownload_dict)

        with open(json_path, 'rb') as f:
            files = {'file': (json_path, f)}
            http_response = requests.post(fileputlistofdicts[0]['url'],
                                          data=fileputlistofdicts[0]['fields'],
                                          files=files)
    except Exception:
        pass

    return 1
|
|
773
|
+
|
|
774
|
+
|
|
775
|
+
def upload_model_graph(modelpath, s3_presigned_dict, bucket, model_id, model_version, onnx_model=None):
    """
    Upload a model-graph record (``model_graph_<version>.json``) through
    the presigned S3 POST URLs.

    Only keras models carry a serialized graph in the ONNX metadata; every
    other framework uploads an empty string so the per-version record still
    exists.

    Args:
        modelpath: Filesystem path to the ONNX model (ignored when
            ``onnx_model`` is supplied).
        s3_presigned_dict: Dict with 'get'/'put' maps of presigned URLs.
        bucket: S3 bucket name (kept for interface compatibility; the
            presigned URLs already encode the destination).
        model_id: Playground/competition identifier (same note as bucket).
        model_version: Version number used in the uploaded file name.
        onnx_model: Optional already-loaded onnx.ModelProto; avoids
            re-reading ``modelpath``.

    Returns:
        int: 1 (always; download/upload failures are handled best-effort).
    """
    import wget
    import json
    import ast
    temp = tmp.mkdtemp()

    if onnx_model is None:
        onnx_model = onnx.load(modelpath)

    meta_dict = _get_metadata(onnx_model)

    # Default to an empty graph so frameworks outside the known set no
    # longer raise UnboundLocalError (old code only assigned model_graph
    # inside framework-specific branches).
    if meta_dict['ml_framework'] == 'keras':
        model_graph = meta_dict['model_graph']
    else:
        model_graph = ''

    json_name = 'model_graph_' + str(model_version) + '.json'
    json_path = temp + "/" + json_name

    # Fetch any existing graph file so earlier versions are preserved.
    try:
        wget.download(s3_presigned_dict['get'][json_name], out=json_path)
        with open(json_path) as f:
            graph_dict = json.load(f)
    except Exception:
        graph_dict = {}

    graph_dict[str(model_version)] = {'ml_framework': meta_dict['ml_framework'],
                                      'model_type': meta_dict['model_type'],
                                      'model_graph': model_graph}

    with open(json_path, 'w') as outfile:
        json.dump(graph_dict, outfile)

    # Best-effort upload through the matching presigned POST URL.
    try:
        putfilekeys = list(s3_presigned_dict['put'].keys())
        modelputfiles = [s for s in putfilekeys if json_name in s]

        fileputlistofdicts = []
        for i in modelputfiles:
            filedownload_dict = ast.literal_eval(s3_presigned_dict['put'][i])
            fileputlistofdicts.append(filedownload_dict)

        with open(json_path, 'rb') as f:
            files = {'file': (json_path, f)}
            http_response = requests.post(fileputlistofdicts[0]['url'],
                                          data=fileputlistofdicts[0]['fields'],
                                          files=files)
    except Exception:
        pass

    return 1
|
|
835
|
+
|
|
836
|
+
|
|
837
|
+
def submit_model(
|
|
838
|
+
model_filepath=None,
|
|
839
|
+
apiurl=None,
|
|
840
|
+
prediction_submission=None,
|
|
841
|
+
preprocessor=None,
|
|
842
|
+
reproducibility_env_filepath=None,
|
|
843
|
+
custom_metadata=None,
|
|
844
|
+
submission_type="competition",
|
|
845
|
+
input_dict = None,
|
|
846
|
+
print_output=True,
|
|
847
|
+
debug_preprocessor=False,
|
|
848
|
+
token=None,
|
|
849
|
+
return_metrics=None # <--- NEW ARGUMENT
|
|
850
|
+
):
|
|
851
|
+
"""
|
|
852
|
+
Submits model/preprocessor to machine learning competition using live prediction API url.
|
|
853
|
+
The submitted model gets evaluated and compared with all existing models and a leaderboard can be generated
|
|
854
|
+
"""
|
|
855
|
+
|
|
856
|
+
# catch missing model_input for pytorch
|
|
857
|
+
try:
|
|
858
|
+
import torch
|
|
859
|
+
if isinstance(model_filepath, torch.nn.Module) and model_input==None:
|
|
860
|
+
raise ValueError("Please submit valid model_input for pytorch model.")
|
|
861
|
+
except:
|
|
862
|
+
pass
|
|
863
|
+
|
|
864
|
+
# check whether preprocessor is function and validate export
|
|
865
|
+
preprocessor = _prepare_preprocessor_if_function(preprocessor, debug_mode=debug_preprocessor)
|
|
866
|
+
|
|
867
|
+
import os
|
|
868
|
+
from aimodelshare.aws import get_aws_token
|
|
869
|
+
from aimodelshare.modeluser import get_jwt_token
|
|
870
|
+
import ast
|
|
871
|
+
|
|
872
|
+
# Confirm that creds are loaded, raise error if not
|
|
873
|
+
if token==None:
|
|
874
|
+
if not all(["username" in os.environ,
|
|
875
|
+
"password" in os.environ]):
|
|
876
|
+
raise RuntimeError("'Submit Model' unsuccessful. Please provide username and password using set_credentials() function.")
|
|
877
|
+
else:
|
|
878
|
+
pass
|
|
879
|
+
|
|
880
|
+
##---Step 2: Get bucket and model_id for playground and check prediction submission structure
|
|
881
|
+
apiurl=apiurl.replace('"','')
|
|
882
|
+
|
|
883
|
+
# Get bucket and model_id for user
|
|
884
|
+
if token==None:
|
|
885
|
+
response, error = run_function_on_lambda(
|
|
886
|
+
apiurl, **{"delete": "FALSE", "versionupdateget": "TRUE"}
|
|
887
|
+
)
|
|
888
|
+
username = os.environ.get("username")
|
|
889
|
+
else:
|
|
890
|
+
from aimodelshare.aws import get_token_from_session, _get_username_from_token
|
|
891
|
+
username=_get_username_from_token(token)
|
|
892
|
+
response, error = run_function_on_lambda(
|
|
893
|
+
apiurl, username=username, token=token,**{"delete": "FALSE", "versionupdateget": "TRUE"}
|
|
894
|
+
)
|
|
895
|
+
if error is not None:
|
|
896
|
+
raise error
|
|
897
|
+
|
|
898
|
+
_, bucket, model_id = json.loads(response.content.decode("utf-8"))
|
|
899
|
+
|
|
900
|
+
# Add call to eval lambda here to retrieve presigned urls and eval metrics
|
|
901
|
+
if prediction_submission is not None:
|
|
902
|
+
if type(prediction_submission) is not list:
|
|
903
|
+
prediction_submission=prediction_submission.tolist()
|
|
904
|
+
else:
|
|
905
|
+
pass
|
|
906
|
+
|
|
907
|
+
if all(isinstance(x, (np.float64)) for x in prediction_submission):
|
|
908
|
+
prediction_submission = [float(i) for i in prediction_submission]
|
|
909
|
+
else:
|
|
910
|
+
pass
|
|
911
|
+
|
|
912
|
+
##---Step 3: Attempt to get eval metrics and file access dict for model leaderboard submission
|
|
913
|
+
import os
|
|
914
|
+
import pickle
|
|
915
|
+
temp = tmp.mkdtemp()
|
|
916
|
+
predictions_path = temp + "/" + 'predictions.pkl'
|
|
917
|
+
|
|
918
|
+
fileObject = open(predictions_path, 'wb')
|
|
919
|
+
pickle.dump(prediction_submission, fileObject)
|
|
920
|
+
predfilesize=os.path.getsize(predictions_path)
|
|
921
|
+
fileObject.close()
|
|
922
|
+
|
|
923
|
+
if predfilesize>3555000:
|
|
924
|
+
post_dict = {"y_pred": [],
|
|
925
|
+
"return_eval_files": "True",
|
|
926
|
+
"submission_type": submission_type,
|
|
927
|
+
"return_y": "False"}
|
|
928
|
+
if token==None:
|
|
929
|
+
headers = { 'Content-Type':'application/json', 'authorizationToken': json.dumps({"token":os.environ.get("AWS_TOKEN"),"eval":"TEST"}), }
|
|
930
|
+
else:
|
|
931
|
+
headers = { 'Content-Type':'application/json', 'authorizationToken': json.dumps({"token":token,"eval":"TEST"}), }
|
|
932
|
+
|
|
933
|
+
apiurl_eval=apiurl[:-1]+"eval"
|
|
934
|
+
predictionfiles = requests.post(apiurl_eval,headers=headers,data=json.dumps(post_dict))
|
|
935
|
+
eval_metrics=json.loads(predictionfiles.text)
|
|
936
|
+
|
|
937
|
+
s3_presigned_dict = {key:val for key, val in eval_metrics.items() if key != 'eval'}
|
|
938
|
+
|
|
939
|
+
idempotentmodel_version=s3_presigned_dict['idempotentmodel_version']
|
|
940
|
+
s3_presigned_dict.pop('idempotentmodel_version')
|
|
941
|
+
|
|
942
|
+
# Upload preprocessor (1s for small upload vs 21 for 306 mbs)
|
|
943
|
+
putfilekeys=list(s3_presigned_dict['put'].keys())
|
|
944
|
+
modelputfiles = [s for s in putfilekeys if str("pkl") in s]
|
|
945
|
+
|
|
946
|
+
fileputlistofdicts=[]
|
|
947
|
+
for i in modelputfiles:
|
|
948
|
+
filedownload_dict=ast.literal_eval(s3_presigned_dict ['put'][i])
|
|
949
|
+
fileputlistofdicts.append(filedownload_dict)
|
|
950
|
+
|
|
951
|
+
with open(predictions_path , 'rb') as f:
|
|
952
|
+
files = {'file': (predictions_path , f)}
|
|
953
|
+
http_response = requests.post(fileputlistofdicts[0]['url'], data=fileputlistofdicts[0]['fields'], files=files)
|
|
954
|
+
f.close()
|
|
955
|
+
|
|
956
|
+
post_dict = {"y_pred": [],
|
|
957
|
+
"predictionpklname":fileputlistofdicts[0]['fields']['key'].split("/")[2],
|
|
958
|
+
"submission_type": submission_type,
|
|
959
|
+
"return_y": "False",
|
|
960
|
+
"return_eval": "True"}
|
|
961
|
+
|
|
962
|
+
apiurl_eval=apiurl[:-1]+"eval"
|
|
963
|
+
prediction = requests.post(apiurl_eval,headers=headers,data=json.dumps(post_dict))
|
|
964
|
+
|
|
965
|
+
else:
|
|
966
|
+
post_dict = {"y_pred": prediction_submission,
|
|
967
|
+
"return_eval": "True",
|
|
968
|
+
"submission_type": submission_type,
|
|
969
|
+
"return_y": "False"}
|
|
970
|
+
|
|
971
|
+
if token==None:
|
|
972
|
+
headers = { 'Content-Type':'application/json', 'authorizationToken': json.dumps({"token":os.environ.get("AWS_TOKEN"),"eval":"TEST"}), }
|
|
973
|
+
else:
|
|
974
|
+
headers = { 'Content-Type':'application/json', 'authorizationToken': json.dumps({"token":token,"eval":"TEST"}), }
|
|
975
|
+
apiurl_eval=apiurl[:-1]+"eval"
|
|
976
|
+
import requests
|
|
977
|
+
prediction = requests.post(apiurl_eval,headers=headers,data=json.dumps(post_dict))
|
|
978
|
+
|
|
979
|
+
# Parse the raw API response
|
|
980
|
+
eval_metrics_raw = json.loads(prediction.text)
|
|
981
|
+
|
|
982
|
+
# Validate API response structure
|
|
983
|
+
if not isinstance(eval_metrics_raw, dict):
|
|
984
|
+
if isinstance(eval_metrics_raw, list):
|
|
985
|
+
error_msg = str(eval_metrics_raw[0]) if eval_metrics_raw else "Empty list response"
|
|
986
|
+
raise RuntimeError(f'Unauthorized user: {error_msg}')
|
|
987
|
+
else:
|
|
988
|
+
raise RuntimeError('Unauthorized user: You do not have access to submit models to, or request data from, this competition.')
|
|
989
|
+
|
|
990
|
+
if "message" in eval_metrics_raw:
|
|
991
|
+
raise RuntimeError(f'Unauthorized user: {eval_metrics_raw.get("message", "You do not have access to submit models to, or request data from, this competition.")}')
|
|
992
|
+
|
|
993
|
+
# Extract S3 presigned URL structure separately (before normalizing eval metrics)
|
|
994
|
+
s3_presigned_dict = {key: val for key, val in eval_metrics_raw.items() if key != 'eval'}
|
|
995
|
+
|
|
996
|
+
if 'idempotentmodel_version' not in s3_presigned_dict:
|
|
997
|
+
raise RuntimeError("Failed to get model version from API. Please check the API response.")
|
|
998
|
+
|
|
999
|
+
idempotentmodel_version = s3_presigned_dict['idempotentmodel_version']
|
|
1000
|
+
s3_presigned_dict.pop('idempotentmodel_version')
|
|
1001
|
+
|
|
1002
|
+
# Normalize eval metrics
|
|
1003
|
+
eval_metrics, eval_metrics_private = _normalize_eval_payload(eval_metrics_raw)
|
|
1004
|
+
|
|
1005
|
+
# Check if we got any valid metrics
|
|
1006
|
+
if not eval_metrics and not eval_metrics_private:
|
|
1007
|
+
print("---------------------------------------------------------------")
|
|
1008
|
+
print("--- WARNING: No evaluation metrics returned from API ---")
|
|
1009
|
+
print("Proceeding with empty metrics. Model will be submitted without eval data.")
|
|
1010
|
+
print("---------------------------------------------------------------")
|
|
1011
|
+
|
|
1012
|
+
# Upload preprocessor
|
|
1013
|
+
putfilekeys=list(s3_presigned_dict['put'].keys())
|
|
1014
|
+
|
|
1015
|
+
# Find preprocessor upload key using explicit pattern matching
|
|
1016
|
+
preprocessor_key = None
|
|
1017
|
+
for key in putfilekeys:
|
|
1018
|
+
if 'preprocessor_v' in key and key.endswith('.zip'):
|
|
1019
|
+
preprocessor_key = key
|
|
1020
|
+
break
|
|
1021
|
+
elif 'preprocessor' in key and key.endswith('.zip'):
|
|
1022
|
+
preprocessor_key = key
|
|
1023
|
+
|
|
1024
|
+
if preprocessor_key is None and preprocessor is not None:
|
|
1025
|
+
# Fallback to original logic if no explicit match
|
|
1026
|
+
modelputfiles = [s for s in putfilekeys if str("zip") in s]
|
|
1027
|
+
if modelputfiles:
|
|
1028
|
+
preprocessor_key = modelputfiles[0]
|
|
1029
|
+
|
|
1030
|
+
if preprocessor is not None:
|
|
1031
|
+
if preprocessor_key is None:
|
|
1032
|
+
raise RuntimeError("Failed to find preprocessor upload URL in presigned URLs")
|
|
1033
|
+
|
|
1034
|
+
filedownload_dict = ast.literal_eval(s3_presigned_dict['put'][preprocessor_key])
|
|
1035
|
+
|
|
1036
|
+
with open(preprocessor, 'rb') as f:
|
|
1037
|
+
files = {'file': (preprocessor, f)}
|
|
1038
|
+
http_response = requests.post(filedownload_dict['url'], data=filedownload_dict['fields'], files=files)
|
|
1039
|
+
|
|
1040
|
+
if http_response.status_code not in [200, 204]:
|
|
1041
|
+
raise RuntimeError(
|
|
1042
|
+
f"Preprocessor upload failed with status {http_response.status_code}: {http_response.text}"
|
|
1043
|
+
)
|
|
1044
|
+
|
|
1045
|
+
putfilekeys=list(s3_presigned_dict['put'].keys())
|
|
1046
|
+
modelputfiles = [s for s in putfilekeys if str("onnx") in s]
|
|
1047
|
+
|
|
1048
|
+
fileputlistofdicts=[]
|
|
1049
|
+
for i in modelputfiles:
|
|
1050
|
+
filedownload_dict=ast.literal_eval(s3_presigned_dict ['put'][i])
|
|
1051
|
+
fileputlistofdicts.append(filedownload_dict)
|
|
1052
|
+
|
|
1053
|
+
if not (model_filepath == None or isinstance(model_filepath, str)):
|
|
1054
|
+
if isinstance(model_filepath, onnx.ModelProto):
|
|
1055
|
+
onnx_model = model_filepath
|
|
1056
|
+
else:
|
|
1057
|
+
print("Transform model object to onnx.")
|
|
1058
|
+
try:
|
|
1059
|
+
import torch
|
|
1060
|
+
if isinstance(model_filepath, torch.nn.Module) and model_input==None:
|
|
1061
|
+
onnx_model = model_to_onnx(model_filepath, model_input=model_input)
|
|
1062
|
+
except:
|
|
1063
|
+
onnx_model = model_to_onnx(model_filepath)
|
|
1064
|
+
pass
|
|
1065
|
+
|
|
1066
|
+
temp_prep=tmp.mkdtemp()
|
|
1067
|
+
model_filepath = temp_prep+"/model.onnx"
|
|
1068
|
+
with open(model_filepath, "wb") as f:
|
|
1069
|
+
f.write(onnx_model.SerializeToString())
|
|
1070
|
+
|
|
1071
|
+
load_onnx_from_path = False
|
|
1072
|
+
else:
|
|
1073
|
+
load_onnx_from_path = True
|
|
1074
|
+
|
|
1075
|
+
if model_filepath is not None:
|
|
1076
|
+
with open(model_filepath, 'rb') as f:
|
|
1077
|
+
files = {'file': (model_filepath, f)}
|
|
1078
|
+
http_response = requests.post(fileputlistofdicts[1]['url'], data=fileputlistofdicts[1]['fields'], files=files)
|
|
1079
|
+
|
|
1080
|
+
putfilekeys=list(s3_presigned_dict['put'].keys())
|
|
1081
|
+
modelputfiles = [s for s in putfilekeys if str("reproducibility") in s]
|
|
1082
|
+
|
|
1083
|
+
fileputlistofdicts=[]
|
|
1084
|
+
for i in modelputfiles:
|
|
1085
|
+
filedownload_dict=ast.literal_eval(s3_presigned_dict ['put'][i])
|
|
1086
|
+
fileputlistofdicts.append(filedownload_dict)
|
|
1087
|
+
|
|
1088
|
+
if reproducibility_env_filepath:
|
|
1089
|
+
with open(reproducibility_env_filepath, 'rb') as f:
|
|
1090
|
+
files = {'file': (reproducibility_env_filepath, f)}
|
|
1091
|
+
http_response = requests.post(fileputlistofdicts[0]['url'], data=fileputlistofdicts[0]['fields'], files=files)
|
|
1092
|
+
|
|
1093
|
+
# Model metadata upload
|
|
1094
|
+
if model_filepath:
|
|
1095
|
+
putfilekeys=list(s3_presigned_dict['put'].keys())
|
|
1096
|
+
modelputfiles = [s for s in putfilekeys if str("model_metadata") in s]
|
|
1097
|
+
|
|
1098
|
+
fileputlistofdicts=[]
|
|
1099
|
+
for i in modelputfiles:
|
|
1100
|
+
filedownload_dict=ast.literal_eval(s3_presigned_dict ['put'][i])
|
|
1101
|
+
fileputlistofdicts.append(filedownload_dict)
|
|
1102
|
+
|
|
1103
|
+
if load_onnx_from_path:
|
|
1104
|
+
onnx_model = onnx.load(model_filepath)
|
|
1105
|
+
|
|
1106
|
+
meta_dict = _get_metadata(onnx_model)
|
|
1107
|
+
model_metadata = {
|
|
1108
|
+
"model_config": meta_dict["model_config"],
|
|
1109
|
+
"ml_framework": meta_dict["ml_framework"],
|
|
1110
|
+
"model_type": meta_dict["model_type"]
|
|
1111
|
+
}
|
|
1112
|
+
|
|
1113
|
+
temp = tmp.mkdtemp()
|
|
1114
|
+
model_metadata_path = temp + "/" + 'model_metadata.json'
|
|
1115
|
+
with open(model_metadata_path, 'w') as outfile:
|
|
1116
|
+
json.dump(model_metadata, outfile)
|
|
1117
|
+
|
|
1118
|
+
with open(model_metadata_path, 'rb') as f:
|
|
1119
|
+
files = {'file': (model_metadata_path, f)}
|
|
1120
|
+
http_response = requests.post(fileputlistofdicts[0]['url'], data=fileputlistofdicts[0]['fields'], files=files)
|
|
1121
|
+
|
|
1122
|
+
# Upload model metrics and metadata
|
|
1123
|
+
if load_onnx_from_path:
|
|
1124
|
+
modelleaderboarddata = _update_leaderboard_public(
|
|
1125
|
+
modelpath, eval_metrics, s3_presigned_dict,
|
|
1126
|
+
username=username, # Explicit keyword argument
|
|
1127
|
+
custom_metadata=custom_metadata
|
|
1128
|
+
)
|
|
1129
|
+
modelleaderboarddata_private = _update_leaderboard_public(
|
|
1130
|
+
modelpath, eval_metrics_private, s3_presigned_dict,
|
|
1131
|
+
username=username, # Explicit keyword argument
|
|
1132
|
+
custom_metadata=custom_metadata,
|
|
1133
|
+
private=True
|
|
1134
|
+
)
|
|
1135
|
+
else:
|
|
1136
|
+
modelleaderboarddata = _update_leaderboard_public(
|
|
1137
|
+
None, eval_metrics, s3_presigned_dict,
|
|
1138
|
+
username=username, # FIX: Explicitly map username
|
|
1139
|
+
custom_metadata=custom_metadata, # FIX: Explicitly map metadata
|
|
1140
|
+
onnx_model=onnx_model
|
|
1141
|
+
)
|
|
1142
|
+
modelleaderboarddata_private = _update_leaderboard_public(
|
|
1143
|
+
None, eval_metrics_private, s3_presigned_dict,
|
|
1144
|
+
username=username, # FIX: Explicitly map username
|
|
1145
|
+
custom_metadata=custom_metadata, # FIX: Explicitly map metadata
|
|
1146
|
+
private=True,
|
|
1147
|
+
onnx_model=onnx_model
|
|
1148
|
+
)
|
|
1149
|
+
|
|
1150
|
+
model_versions = [os.path.splitext(f)[0].split("_")[-1][1:] for f in s3_presigned_dict['put'].keys()]
|
|
1151
|
+
model_versions = filter(lambda v: v.isnumeric(), model_versions)
|
|
1152
|
+
model_versions = list(map(int, model_versions))
|
|
1153
|
+
model_version=model_versions[0]
|
|
1154
|
+
|
|
1155
|
+
if load_onnx_from_path:
|
|
1156
|
+
if model_filepath is not None:
|
|
1157
|
+
upload_model_dict(model_filepath, s3_presigned_dict, bucket, model_id, model_version)
|
|
1158
|
+
upload_model_graph(model_filepath, s3_presigned_dict, bucket, model_id, model_version)
|
|
1159
|
+
else:
|
|
1160
|
+
upload_model_dict(model_filepath, s3_presigned_dict, bucket, model_id, model_version, placeholder=True)
|
|
1161
|
+
else:
|
|
1162
|
+
upload_model_dict(None, s3_presigned_dict, bucket, model_id, model_version, onnx_model=onnx_model)
|
|
1163
|
+
upload_model_graph(None, s3_presigned_dict, bucket, model_id, model_version, onnx_model=onnx_model)
|
|
1164
|
+
|
|
1165
|
+
modelpath=model_filepath
|
|
1166
|
+
|
|
1167
|
+
def dict_clean(items):
|
|
1168
|
+
result = {}
|
|
1169
|
+
for key, value in items:
|
|
1170
|
+
if value is None:
|
|
1171
|
+
value = '0'
|
|
1172
|
+
result[key] = value
|
|
1173
|
+
return result
|
|
1174
|
+
|
|
1175
|
+
if isinstance(modelleaderboarddata, Exception):
|
|
1176
|
+
raise err
|
|
1177
|
+
else:
|
|
1178
|
+
dict_str = json.dumps(modelleaderboarddata)
|
|
1179
|
+
modelleaderboarddata_cleaned = json.loads(dict_str, object_pairs_hook=dict_clean)
|
|
1180
|
+
|
|
1181
|
+
if isinstance(modelleaderboarddata_private, Exception):
|
|
1182
|
+
raise err
|
|
1183
|
+
else:
|
|
1184
|
+
dict_str = json.dumps(modelleaderboarddata_private)
|
|
1185
|
+
modelleaderboarddata_private_cleaned = json.loads(dict_str, object_pairs_hook=dict_clean)
|
|
1186
|
+
|
|
1187
|
+
if input_dict == None:
|
|
1188
|
+
modelsubmissiontags=input("Insert search tags to help users find your model (optional): ")
|
|
1189
|
+
modelsubmissiondescription=input("Provide any useful notes about your model (optional): ")
|
|
1190
|
+
else:
|
|
1191
|
+
modelsubmissiontags = input_dict["tags"]
|
|
1192
|
+
modelsubmissiondescription = input_dict["description"]
|
|
1193
|
+
|
|
1194
|
+
if submission_type=="competition":
|
|
1195
|
+
experimenttruefalse="FALSE"
|
|
1196
|
+
else:
|
|
1197
|
+
experimenttruefalse="TRUE"
|
|
1198
|
+
|
|
1199
|
+
#Update competition or experiment data
|
|
1200
|
+
bodydata = {"apiurl": apiurl,
|
|
1201
|
+
"submissions": model_version,
|
|
1202
|
+
"contributoruniquenames":os.environ.get('username'),
|
|
1203
|
+
"versionupdateputsubmit":"TRUE",
|
|
1204
|
+
"experiment":experimenttruefalse
|
|
1205
|
+
}
|
|
1206
|
+
|
|
1207
|
+
# Get the response
|
|
1208
|
+
if token==None:
|
|
1209
|
+
headers_with_authentication = {'Content-Type': 'application/json', 'authorizationToken': os.environ.get("AWS_TOKEN"), 'Access-Control-Allow-Headers':
|
|
1210
|
+
'Content-Type,X-Amz-Date,authorizationToken,Access-Control-Allow-Origin,X-Api-Key,X-Amz-Security-Token,Authorization', 'Access-Control-Allow-Origin': '*'}
|
|
1211
|
+
else:
|
|
1212
|
+
headers_with_authentication = {'Content-Type': 'application/json', 'authorizationToken': token, 'Access-Control-Allow-Headers':
|
|
1213
|
+
'Content-Type,X-Amz-Date,authorizationToken,Access-Control-Allow-Origin,X-Api-Key,X-Amz-Security-Token,Authorization', 'Access-Control-Allow-Origin': '*'}
|
|
1214
|
+
|
|
1215
|
+
# --------------------------------------------------------------------------------
|
|
1216
|
+
# BACKEND UPDATE 1: Updates submission counts and contributor names
|
|
1217
|
+
# --------------------------------------------------------------------------------
|
|
1218
|
+
requests.post("https://o35jwfakca.execute-api.us-east-1.amazonaws.com/dev/modeldata",
|
|
1219
|
+
json=bodydata, headers=headers_with_authentication)
|
|
1220
|
+
|
|
1221
|
+
|
|
1222
|
+
if modelpath is not None:
|
|
1223
|
+
# get model summary from onnx
|
|
1224
|
+
if load_onnx_from_path:
|
|
1225
|
+
onnx_model = onnx.load(modelpath)
|
|
1226
|
+
meta_dict = _get_metadata(onnx_model)
|
|
1227
|
+
|
|
1228
|
+
if meta_dict['ml_framework'] == 'keras':
|
|
1229
|
+
inspect_pd = _model_summary(meta_dict)
|
|
1230
|
+
model_graph = ""
|
|
1231
|
+
if meta_dict['ml_framework'] == 'pytorch':
|
|
1232
|
+
inspect_pd = _model_summary(meta_dict)
|
|
1233
|
+
model_graph = ""
|
|
1234
|
+
elif meta_dict['ml_framework'] in ['sklearn', 'xgboost']:
|
|
1235
|
+
model_config = _normalize_model_config(
|
|
1236
|
+
meta_dict.get("model_config"),
|
|
1237
|
+
meta_dict.get('model_type')
|
|
1238
|
+
)
|
|
1239
|
+
inspect_pd = _build_sklearn_param_dataframe(
|
|
1240
|
+
meta_dict['model_type'],
|
|
1241
|
+
model_config
|
|
1242
|
+
)
|
|
1243
|
+
model_graph = ''
|
|
1244
|
+
elif meta_dict['ml_framework'] in ['pyspark']:
|
|
1245
|
+
model_config_temp = _normalize_model_config(
|
|
1246
|
+
meta_dict.get("model_config"),
|
|
1247
|
+
meta_dict.get('model_type')
|
|
1248
|
+
)
|
|
1249
|
+
try:
|
|
1250
|
+
model_class = pyspark_model_from_string(meta_dict['model_type'])
|
|
1251
|
+
default = model_class()
|
|
1252
|
+
default_config_temp = {}
|
|
1253
|
+
for key, value in default.extractParamMap().items():
|
|
1254
|
+
default_config_temp[key.name] = value
|
|
1255
|
+
|
|
1256
|
+
model_config = dict(sorted(model_config_temp.items()))
|
|
1257
|
+
default_config = dict(sorted(default_config_temp.items()))
|
|
1258
|
+
|
|
1259
|
+
model_configkeys = model_config.keys()
|
|
1260
|
+
model_configvalues = model_config.values()
|
|
1261
|
+
default_config = default_config.values()
|
|
1262
|
+
except:
|
|
1263
|
+
model_class = str(pyspark_model_from_string(meta_dict['model_type']))
|
|
1264
|
+
if model_class.find("Voting") > 0:
|
|
1265
|
+
default_config = ["No data available"]
|
|
1266
|
+
model_configkeys = ["No data available"]
|
|
1267
|
+
model_configvalues = ["No data available"]
|
|
1268
|
+
else:
|
|
1269
|
+
default_config = []
|
|
1270
|
+
model_configkeys = []
|
|
1271
|
+
model_configvalues = []
|
|
1272
|
+
|
|
1273
|
+
inspect_pd = pd.DataFrame({'param_name': model_configkeys,
|
|
1274
|
+
'default_value': default_config,
|
|
1275
|
+
'param_value': model_configvalues})
|
|
1276
|
+
model_graph = ""
|
|
1277
|
+
else:
|
|
1278
|
+
inspect_pd = pd.DataFrame()
|
|
1279
|
+
model_graph = ''
|
|
1280
|
+
|
|
1281
|
+
keys_to_extract = [ "accuracy", "f1_score", "precision", "recall", "mse", "rmse", "mae", "r2"]
|
|
1282
|
+
|
|
1283
|
+
# Safely extract metric subsets using helper function
|
|
1284
|
+
eval_metrics_subset = _subset_numeric(eval_metrics, keys_to_extract)
|
|
1285
|
+
eval_metrics_private_subset = _subset_numeric(eval_metrics_private, keys_to_extract)
|
|
1286
|
+
|
|
1287
|
+
# Keep only numeric values
|
|
1288
|
+
eval_metrics_subset_nonulls = {key: value for key, value in eval_metrics_subset.items() if isinstance(value, (int, float))}
|
|
1289
|
+
eval_metrics_private_subset_nonulls = {key: value for key, value in eval_metrics_private_subset.items() if isinstance(value, (int, float))}
|
|
1290
|
+
|
|
1291
|
+
# Update model architecture data
|
|
1292
|
+
bodydatamodels = {
|
|
1293
|
+
"apiurl": apiurl,
|
|
1294
|
+
"modelsummary":json.dumps(inspect_pd.to_json()),
|
|
1295
|
+
"model_graph": model_graph,
|
|
1296
|
+
"Private":"FALSE",
|
|
1297
|
+
"modelsubmissiondescription": modelsubmissiondescription,
|
|
1298
|
+
"modelsubmissiontags":modelsubmissiontags,
|
|
1299
|
+
"eval_metrics":json.dumps(eval_metrics_subset_nonulls),
|
|
1300
|
+
"eval_metrics_private":json.dumps(eval_metrics_private_subset_nonulls),
|
|
1301
|
+
"submission_type": submission_type
|
|
1302
|
+
}
|
|
1303
|
+
|
|
1304
|
+
bodydatamodels.update(modelleaderboarddata_cleaned)
|
|
1305
|
+
bodydatamodels.update(modelleaderboarddata_private_cleaned)
|
|
1306
|
+
|
|
1307
|
+
d = bodydatamodels
|
|
1308
|
+
keys_values = d.items()
|
|
1309
|
+
bodydatamodels_allstrings = {str(key): str(value) for key, value in keys_values}
|
|
1310
|
+
|
|
1311
|
+
if token==None:
|
|
1312
|
+
headers_with_authentication = {'Content-Type': 'application/json', 'authorizationToken': os.environ.get("AWS_TOKEN"), 'Access-Control-Allow-Headers':
|
|
1313
|
+
'Content-Type,X-Amz-Date,authorizationToken,Access-Control-Allow-Origin,X-Api-Key,X-Amz-Security-Token,Authorization', 'Access-Control-Allow-Origin': '*'}
|
|
1314
|
+
else:
|
|
1315
|
+
headers_with_authentication = {'Content-Type': 'application/json', 'authorizationToken': token, 'Access-Control-Allow-Headers':
|
|
1316
|
+
'Content-Type,X-Amz-Date,authorizationToken,Access-Control-Allow-Origin,X-Api-Key,X-Amz-Security-Token,Authorization', 'Access-Control-Allow-Origin': '*'}
|
|
1317
|
+
|
|
1318
|
+
# --------------------------------------------------------------------------------
|
|
1319
|
+
# BACKEND UPDATE 2: (CRITICAL) This updates the leaderboard database
|
|
1320
|
+
# --------------------------------------------------------------------------------
|
|
1321
|
+
response=requests.post("https://eeqq8zuo9j.execute-api.us-east-1.amazonaws.com/dev/modeldata",
|
|
1322
|
+
json=bodydatamodels_allstrings, headers=headers_with_authentication)
|
|
1323
|
+
|
|
1324
|
+
if str(response.status_code)=="200":
|
|
1325
|
+
code_comp_result="To submit code used to create this model or to view current leaderboard navigate to Model Playground: \n\n https://www.modelshare.ai/detail/model:"+response.text.split(":")[1]
|
|
1326
|
+
else:
|
|
1327
|
+
code_comp_result=""
|
|
1328
|
+
|
|
1329
|
+
model_page_url = "https://www.modelshare.ai/detail/model:"+response.text.split(":")[1]
|
|
1330
|
+
|
|
1331
|
+
if print_output:
|
|
1332
|
+
print("\nYour model has been submitted as model version "+str(model_version)+ "\n\n"+code_comp_result)
|
|
1333
|
+
|
|
1334
|
+
# --------------------------------------------------------------------------
|
|
1335
|
+
# NEW LOGIC: Return metrics ONLY after all backend updates are complete
|
|
1336
|
+
# --------------------------------------------------------------------------
|
|
1337
|
+
if return_metrics:
|
|
1338
|
+
# Determine source of metrics: prefer public, fallback to private, or empty
|
|
1339
|
+
source_metrics = eval_metrics if eval_metrics else (eval_metrics_private if eval_metrics_private else {})
|
|
1340
|
+
|
|
1341
|
+
# Determine keys to extract
|
|
1342
|
+
keys_to_fetch = []
|
|
1343
|
+
if isinstance(return_metrics, str):
|
|
1344
|
+
keys_to_fetch = [return_metrics]
|
|
1345
|
+
elif isinstance(return_metrics, list):
|
|
1346
|
+
keys_to_fetch = return_metrics
|
|
1347
|
+
elif return_metrics is True:
|
|
1348
|
+
# Return all keys available in the source
|
|
1349
|
+
keys_to_fetch = list(source_metrics.keys())
|
|
1350
|
+
|
|
1351
|
+
# Extract specific metrics into new dict
|
|
1352
|
+
returned_metrics_dict = {}
|
|
1353
|
+
for key in keys_to_fetch:
|
|
1354
|
+
val = source_metrics.get(key)
|
|
1355
|
+
# Unpack single-item lists if present (common pattern in Lambda response)
|
|
1356
|
+
if isinstance(val, list) and len(val) > 0:
|
|
1357
|
+
returned_metrics_dict[key] = val[0]
|
|
1358
|
+
else:
|
|
1359
|
+
returned_metrics_dict[key] = val
|
|
1360
|
+
|
|
1361
|
+
# Return extended tuple
|
|
1362
|
+
return str(model_version), model_page_url, returned_metrics_dict
|
|
1363
|
+
|
|
1364
|
+
# Default backward-compatible return
|
|
1365
|
+
return str(model_version), model_page_url
|
|
1366
|
+
|
|
1367
|
+
def update_runtime_model(apiurl, model_version=None, submission_type="competition"):
    """
    Promote a previously submitted model version to be the live runtime model.

    apiurl: string of API URL that the user wishes to edit
    new_model_version: string of model version number (from leaderboard) to replace original model

    When the environment variable ``cloud_location`` equals ``"model_share"``,
    the call is proxied: this function's own invocation is serialized into a
    code string and POSTed to a remote Lambda URL for execution there.
    Otherwise it runs locally against the user's AWS credentials: it copies
    ``onnx_model_v<version>.onnx`` and ``preprocessor_v<version>.zip`` over the
    ``runtime_model.onnx`` / ``runtime_preprocessor.zip`` objects in the
    playground's S3 bucket and pushes the version's leaderboard metrics to the
    front-end metadata endpoint.
    """
    import os
    # Decide between remote (proxied) and local execution based on env config.
    if os.environ.get("cloud_location") is not None:
        cloudlocation=os.environ.get("cloud_location")
    else:
        cloudlocation="not set"
    if "model_share"==cloudlocation:
        # NOTE(review): nonecheck is defined here but never called in this
        # function body — possibly dead code kept for parity with siblings.
        def nonecheck(objinput=""):
            if objinput==None:
                objinput="None"
            else:
                objinput="'/tmp/"+objinput+"'"
            return objinput

        # Serialize this very call so the remote side can re-invoke it.
        runtimemodstring="update_runtime_model('"+apiurl+"',"+str(model_version)+",submission_type='"+str(submission_type)+"')"
        import base64
        import requests
        import json

        api_url = "https://z4kvag4sxdnv2mvs2b6c4thzj40bxnuw.lambda-url.us-east-2.on.aws/"

        # Credentials are forwarded from the environment; the remote Lambda
        # executes the code string under the caller's account.
        data = json.dumps({"code": """from aimodelshare.model import update_runtime_model;"""+runtimemodstring, "zipfilename": "","username":os.environ.get("username"), "password":os.environ.get("password"),"token":os.environ.get("JWT_AUTHORIZATION_TOKEN"),"s3keyid":"xrjpv1i7xe"})

        headers = {"Content-Type": "application/json"}

        response = requests.request("POST", api_url, headers = headers, data=data)
        # Print response
        result=json.loads(response.text)

        for i in json.loads(result['body']):
            print(i)

    else:
        # Confirm that creds are loaded, print warning if not
        if all(["AWS_ACCESS_KEY_ID_AIMS" in os.environ,
                "AWS_SECRET_ACCESS_KEY_AIMS" in os.environ,
                "AWS_REGION_AIMS" in os.environ,
                "username" in os.environ,
                "password" in os.environ]):
            pass
        else:
            return print("'Update Runtime Model' unsuccessful. Please provide credentials with set_credentials().")

        # Create user session
        aws_client_and_resource=get_aws_client(aws_key=os.environ.get('AWS_ACCESS_KEY_ID_AIMS'),
                                aws_secret=os.environ.get('AWS_SECRET_ACCESS_KEY_AIMS'),
                                aws_region=os.environ.get('AWS_REGION_AIMS'))
        aws_client = aws_client_and_resource['client']

        user_sess = boto3.session.Session(aws_access_key_id=os.environ.get('AWS_ACCESS_KEY_ID_AIMS'),
                                          aws_secret_access_key=os.environ.get('AWS_SECRET_ACCESS_KEY_AIMS'),
                                          region_name=os.environ.get('AWS_REGION_AIMS'))

        s3 = user_sess.resource('s3')
        model_version=str(model_version)
        # Get bucket and model_id for user based on apiurl {{{
        response, error = run_function_on_lambda(
            apiurl, **{"delete": "FALSE", "versionupdateget": "TRUE"}
        )
        if error is not None:
            raise error
        import json
        _, api_bucket, model_id = json.loads(response.content.decode("utf-8"))
        # }}}

        try:
            leaderboard = get_leaderboard(apiurl=apiurl, submission_type=submission_type)

            columns = leaderboard.columns
            # Select the requested version's row; drop columns that are NaN
            # for that row so only populated metrics are forwarded.
            leaderboardversion=leaderboard[leaderboard['version']==int(model_version)]
            leaderboardversion=leaderboardversion.dropna(axis=1)

            # NOTE(review): assumes the first four leaderboard columns are the
            # metric columns — verify against get_leaderboard's column order.
            metric_names_subset=list(columns[0:4])
            leaderboardversiondict=leaderboardversion.loc[:,metric_names_subset].to_dict('records')[0]

        except Exception as err:
            raise err

        # Get file list for current bucket {{{
        model_files, err = _get_file_list(aws_client, api_bucket, model_id+"/"+submission_type)
        if err is not None:
            raise err
        # }}}

        # extract subfolder objects specific to the model id
        folder = s3.meta.client.list_objects(Bucket=api_bucket, Prefix=model_id+"/"+submission_type+"/")
        bucket = s3.Bucket(api_bucket)
        file_list = [file['Key'] for file in folder['Contents']]
        # NOTE(review): rebinds s3 to a default-credential resource after
        # `bucket` was already created from the user session — the later copy
        # calls go through `bucket`, so this rebinding appears unused.
        s3 = boto3.resource('s3')
        model_source_key = model_id+"/"+submission_type+"/onnx_model_v"+str(model_version)+".onnx"
        preprocesor_source_key = model_id+"/"+submission_type+"/preprocessor_v"+str(model_version)+".zip"
        model_copy_source = {
            'Bucket': api_bucket,
            'Key': model_source_key
        }
        preprocessor_copy_source = {
            'Bucket': api_bucket,
            'Key': preprocesor_source_key
        }
        # Sending correct model metrics to front end
        bodydatamodelmetrics={"apiurl":apiurl,
                  "versionupdateput":"TRUE",
                  "verified_metrics":"TRUE",
                  "eval_metrics":json.dumps(leaderboardversiondict)}
        import requests
        headers = { 'Content-Type':'application/json', 'authorizationToken': os.environ.get("AWS_TOKEN"), }
        prediction = requests.post("https://bhrdesksak.execute-api.us-east-1.amazonaws.com/dev/modeldata",headers=headers,data=json.dumps(bodydatamodelmetrics))

        # overwrite runtime_model.onnx file & runtime_preprocessor.zip files:
        # Only promote when BOTH the model and its preprocessor exist for the
        # requested version; otherwise report the version as not found.
        if (model_source_key in file_list) & (preprocesor_source_key in file_list):
            response = bucket.copy(model_copy_source, model_id+"/"+'runtime_model.onnx')
            response = bucket.copy(preprocessor_copy_source, model_id+"/"+'runtime_preprocessor.zip')
            return print('Runtime model & preprocessor for api: '+apiurl+" updated to model version "+model_version+".\n\nModel metrics are now updated and verified for this model playground.")
        else:
            # the file resource to be the new runtime_model is not available
            return print('New Runtime Model version ' + model_version + ' not found.')
|
|
1487
|
+
|
|
1488
|
+
|
|
1489
|
+
def _extract_model_metadata(model, eval_metrics=None):
    """Extract summary metadata from an ONNX model's graph.

    Parameters
    ----------
    model : object exposing an ONNX-style ``graph`` attribute
        The graph must provide ``node`` (each with ``op_type``),
        ``initializer`` (each with ``dims``), and ``input``/``output``
        value infos with ``type.tensor_type.shape.dim`` entries.
    eval_metrics : dict, optional
        If provided, metadata keys are added to THIS dict in place (the
        caller's dict is mutated) and the same dict is returned.

    Returns
    -------
    dict
        Keys: ``num_nodes``, ``depth_test`` (initializer count),
        ``num_params``, ``layers``, ``input_shape``, ``inputs``,
        ``outputs``. ``input_shape`` is only set when the graph has at
        least one input.
    """
    # Getting the model metadata {{{
    graph = model.graph

    # Reuse the caller-supplied metrics dict (mutated in place) or start fresh.
    metadata = eval_metrics if eval_metrics is not None else dict()

    metadata["num_nodes"] = len(graph.node)
    # NOTE(review): named "depth_test" but holds the initializer (weight
    # tensor) count, not a graph depth.
    metadata["depth_test"] = len(graph.initializer)
    # np.prod replaces np.product, which was deprecated and removed in
    # NumPy 2.0; the computed value is identical.
    metadata["num_params"] = sum(np.prod(node.dims) for node in graph.initializer)

    metadata["layers"] = "; ".join(node.op_type for node in graph.node)

    inputs = ""
    for inp in graph.input:
        # Symbolic dims carry a name in dim_param; concrete dims use dim_value.
        dims = [
            d.dim_param if d.dim_param != "" else str(d.dim_value)
            for d in inp.type.tensor_type.shape.dim
        ]
        # NOTE: overwritten each iteration — keeps the LAST input's shape only
        # (preserved behavior).
        metadata["input_shape"] = dims
        # NOTE: entries are concatenated with no separator between inputs
        # (preserved behavior).
        inputs += f"{inp.name} ({'x'.join(dims)})"
    metadata["inputs"] = inputs

    outputs = ""
    for out in graph.output:
        dims = [
            d.dim_param if d.dim_param != "" else str(d.dim_value)
            for d in out.type.tensor_type.shape.dim
        ]
        outputs += f"{out.name} ({'x'.join(dims)})"
    metadata["outputs"] = outputs
    # }}}

    return metadata
|
|
1538
|
+
|
|
1539
|
+
# __all__ must contain *strings* naming the public API. Listing the function
# objects themselves (as before) makes `from aimodelshare.model import *`
# raise TypeError, because the import machinery requires string entries.
__all__ = [
    "submit_model",
    "_extract_model_metadata",
    "update_runtime_model",
]
|