aimodelshare-0.3.7-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (171)
  1. aimodelshare/README.md +26 -0
  2. aimodelshare/__init__.py +100 -0
  3. aimodelshare/aimsonnx.py +2381 -0
  4. aimodelshare/api.py +836 -0
  5. aimodelshare/auth.py +163 -0
  6. aimodelshare/aws.py +511 -0
  7. aimodelshare/aws_client.py +173 -0
  8. aimodelshare/base_image.py +154 -0
  9. aimodelshare/bucketpolicy.py +106 -0
  10. aimodelshare/color_mappings/color_mapping_keras.csv +121 -0
  11. aimodelshare/color_mappings/color_mapping_pytorch.csv +117 -0
  12. aimodelshare/containerisation.py +244 -0
  13. aimodelshare/containerization.py +712 -0
  14. aimodelshare/containerization_templates/Dockerfile.txt +8 -0
  15. aimodelshare/containerization_templates/Dockerfile_PySpark.txt +23 -0
  16. aimodelshare/containerization_templates/buildspec.txt +14 -0
  17. aimodelshare/containerization_templates/lambda_function.txt +40 -0
  18. aimodelshare/custom_approach/__init__.py +1 -0
  19. aimodelshare/custom_approach/lambda_function.py +17 -0
  20. aimodelshare/custom_eval_metrics.py +103 -0
  21. aimodelshare/data_sharing/__init__.py +0 -0
  22. aimodelshare/data_sharing/data_sharing_templates/Dockerfile.txt +3 -0
  23. aimodelshare/data_sharing/data_sharing_templates/__init__.py +1 -0
  24. aimodelshare/data_sharing/data_sharing_templates/buildspec.txt +15 -0
  25. aimodelshare/data_sharing/data_sharing_templates/codebuild_policies.txt +129 -0
  26. aimodelshare/data_sharing/data_sharing_templates/codebuild_trust_relationship.txt +12 -0
  27. aimodelshare/data_sharing/download_data.py +620 -0
  28. aimodelshare/data_sharing/share_data.py +373 -0
  29. aimodelshare/data_sharing/utils.py +8 -0
  30. aimodelshare/deploy_custom_lambda.py +246 -0
  31. aimodelshare/documentation/Makefile +20 -0
  32. aimodelshare/documentation/karma_sphinx_theme/__init__.py +28 -0
  33. aimodelshare/documentation/karma_sphinx_theme/_version.py +2 -0
  34. aimodelshare/documentation/karma_sphinx_theme/breadcrumbs.html +70 -0
  35. aimodelshare/documentation/karma_sphinx_theme/layout.html +172 -0
  36. aimodelshare/documentation/karma_sphinx_theme/search.html +50 -0
  37. aimodelshare/documentation/karma_sphinx_theme/searchbox.html +14 -0
  38. aimodelshare/documentation/karma_sphinx_theme/static/css/custom.css +2 -0
  39. aimodelshare/documentation/karma_sphinx_theme/static/css/custom.css.map +1 -0
  40. aimodelshare/documentation/karma_sphinx_theme/static/css/theme.css +2751 -0
  41. aimodelshare/documentation/karma_sphinx_theme/static/css/theme.css.map +1 -0
  42. aimodelshare/documentation/karma_sphinx_theme/static/css/theme.min.css +2 -0
  43. aimodelshare/documentation/karma_sphinx_theme/static/css/theme.min.css.map +1 -0
  44. aimodelshare/documentation/karma_sphinx_theme/static/font/fontello.eot +0 -0
  45. aimodelshare/documentation/karma_sphinx_theme/static/font/fontello.svg +32 -0
  46. aimodelshare/documentation/karma_sphinx_theme/static/font/fontello.ttf +0 -0
  47. aimodelshare/documentation/karma_sphinx_theme/static/font/fontello.woff +0 -0
  48. aimodelshare/documentation/karma_sphinx_theme/static/font/fontello.woff2 +0 -0
  49. aimodelshare/documentation/karma_sphinx_theme/static/js/theme.js +68 -0
  50. aimodelshare/documentation/karma_sphinx_theme/theme.conf +9 -0
  51. aimodelshare/documentation/make.bat +35 -0
  52. aimodelshare/documentation/requirements.txt +2 -0
  53. aimodelshare/documentation/source/about.rst +18 -0
  54. aimodelshare/documentation/source/advanced_features.rst +137 -0
  55. aimodelshare/documentation/source/competition.rst +218 -0
  56. aimodelshare/documentation/source/conf.py +58 -0
  57. aimodelshare/documentation/source/create_credentials.rst +86 -0
  58. aimodelshare/documentation/source/example_notebooks.rst +132 -0
  59. aimodelshare/documentation/source/functions.rst +151 -0
  60. aimodelshare/documentation/source/gettingstarted.rst +390 -0
  61. aimodelshare/documentation/source/images/creds1.png +0 -0
  62. aimodelshare/documentation/source/images/creds2.png +0 -0
  63. aimodelshare/documentation/source/images/creds3.png +0 -0
  64. aimodelshare/documentation/source/images/creds4.png +0 -0
  65. aimodelshare/documentation/source/images/creds5.png +0 -0
  66. aimodelshare/documentation/source/images/creds_file_example.png +0 -0
  67. aimodelshare/documentation/source/images/predict_tab.png +0 -0
  68. aimodelshare/documentation/source/index.rst +110 -0
  69. aimodelshare/documentation/source/modelplayground.rst +132 -0
  70. aimodelshare/exceptions.py +11 -0
  71. aimodelshare/generatemodelapi.py +1270 -0
  72. aimodelshare/iam/codebuild_policy.txt +129 -0
  73. aimodelshare/iam/codebuild_trust_relationship.txt +12 -0
  74. aimodelshare/iam/lambda_policy.txt +15 -0
  75. aimodelshare/iam/lambda_trust_relationship.txt +12 -0
  76. aimodelshare/json_templates/__init__.py +1 -0
  77. aimodelshare/json_templates/api_json.txt +155 -0
  78. aimodelshare/json_templates/auth/policy.txt +1 -0
  79. aimodelshare/json_templates/auth/role.txt +1 -0
  80. aimodelshare/json_templates/eval/policy.txt +1 -0
  81. aimodelshare/json_templates/eval/role.txt +1 -0
  82. aimodelshare/json_templates/function/policy.txt +1 -0
  83. aimodelshare/json_templates/function/role.txt +1 -0
  84. aimodelshare/json_templates/integration_response.txt +5 -0
  85. aimodelshare/json_templates/lambda_policy_1.txt +15 -0
  86. aimodelshare/json_templates/lambda_policy_2.txt +8 -0
  87. aimodelshare/json_templates/lambda_role_1.txt +12 -0
  88. aimodelshare/json_templates/lambda_role_2.txt +16 -0
  89. aimodelshare/leaderboard.py +174 -0
  90. aimodelshare/main/1.txt +132 -0
  91. aimodelshare/main/1B.txt +112 -0
  92. aimodelshare/main/2.txt +153 -0
  93. aimodelshare/main/3.txt +134 -0
  94. aimodelshare/main/4.txt +128 -0
  95. aimodelshare/main/5.txt +109 -0
  96. aimodelshare/main/6.txt +105 -0
  97. aimodelshare/main/7.txt +144 -0
  98. aimodelshare/main/8.txt +142 -0
  99. aimodelshare/main/__init__.py +1 -0
  100. aimodelshare/main/authorization.txt +275 -0
  101. aimodelshare/main/eval_classification.txt +79 -0
  102. aimodelshare/main/eval_lambda.txt +1709 -0
  103. aimodelshare/main/eval_regression.txt +80 -0
  104. aimodelshare/main/lambda_function.txt +8 -0
  105. aimodelshare/main/nst.txt +149 -0
  106. aimodelshare/model.py +1543 -0
  107. aimodelshare/modeluser.py +215 -0
  108. aimodelshare/moral_compass/README.md +408 -0
  109. aimodelshare/moral_compass/__init__.py +65 -0
  110. aimodelshare/moral_compass/_version.py +3 -0
  111. aimodelshare/moral_compass/api_client.py +601 -0
  112. aimodelshare/moral_compass/apps/__init__.py +69 -0
  113. aimodelshare/moral_compass/apps/ai_consequences.py +540 -0
  114. aimodelshare/moral_compass/apps/bias_detective.py +714 -0
  115. aimodelshare/moral_compass/apps/ethical_revelation.py +898 -0
  116. aimodelshare/moral_compass/apps/fairness_fixer.py +889 -0
  117. aimodelshare/moral_compass/apps/judge.py +888 -0
  118. aimodelshare/moral_compass/apps/justice_equity_upgrade.py +853 -0
  119. aimodelshare/moral_compass/apps/mc_integration_helpers.py +820 -0
  120. aimodelshare/moral_compass/apps/model_building_game.py +1104 -0
  121. aimodelshare/moral_compass/apps/model_building_game_beginner.py +687 -0
  122. aimodelshare/moral_compass/apps/moral_compass_challenge.py +858 -0
  123. aimodelshare/moral_compass/apps/session_auth.py +254 -0
  124. aimodelshare/moral_compass/apps/shared_activity_styles.css +349 -0
  125. aimodelshare/moral_compass/apps/tutorial.py +481 -0
  126. aimodelshare/moral_compass/apps/what_is_ai.py +853 -0
  127. aimodelshare/moral_compass/challenge.py +365 -0
  128. aimodelshare/moral_compass/config.py +187 -0
  129. aimodelshare/placeholders/model.onnx +0 -0
  130. aimodelshare/placeholders/preprocessor.zip +0 -0
  131. aimodelshare/playground.py +1968 -0
  132. aimodelshare/postprocessormodules.py +157 -0
  133. aimodelshare/preprocessormodules.py +373 -0
  134. aimodelshare/pyspark/1.txt +195 -0
  135. aimodelshare/pyspark/1B.txt +181 -0
  136. aimodelshare/pyspark/2.txt +220 -0
  137. aimodelshare/pyspark/3.txt +204 -0
  138. aimodelshare/pyspark/4.txt +187 -0
  139. aimodelshare/pyspark/5.txt +178 -0
  140. aimodelshare/pyspark/6.txt +174 -0
  141. aimodelshare/pyspark/7.txt +211 -0
  142. aimodelshare/pyspark/8.txt +206 -0
  143. aimodelshare/pyspark/__init__.py +1 -0
  144. aimodelshare/pyspark/authorization.txt +258 -0
  145. aimodelshare/pyspark/eval_classification.txt +79 -0
  146. aimodelshare/pyspark/eval_lambda.txt +1441 -0
  147. aimodelshare/pyspark/eval_regression.txt +80 -0
  148. aimodelshare/pyspark/lambda_function.txt +8 -0
  149. aimodelshare/pyspark/nst.txt +213 -0
  150. aimodelshare/python/my_preprocessor.py +58 -0
  151. aimodelshare/readme.md +26 -0
  152. aimodelshare/reproducibility.py +181 -0
  153. aimodelshare/sam/Dockerfile.txt +8 -0
  154. aimodelshare/sam/Dockerfile_PySpark.txt +24 -0
  155. aimodelshare/sam/__init__.py +1 -0
  156. aimodelshare/sam/buildspec.txt +11 -0
  157. aimodelshare/sam/codebuild_policies.txt +129 -0
  158. aimodelshare/sam/codebuild_trust_relationship.txt +12 -0
  159. aimodelshare/sam/codepipeline_policies.txt +173 -0
  160. aimodelshare/sam/codepipeline_trust_relationship.txt +12 -0
  161. aimodelshare/sam/spark-class.txt +2 -0
  162. aimodelshare/sam/template.txt +54 -0
  163. aimodelshare/tools.py +103 -0
  164. aimodelshare/utils/__init__.py +78 -0
  165. aimodelshare/utils/optional_deps.py +38 -0
  166. aimodelshare/utils.py +57 -0
  167. aimodelshare-0.3.7.dist-info/METADATA +298 -0
  168. aimodelshare-0.3.7.dist-info/RECORD +171 -0
  169. aimodelshare-0.3.7.dist-info/WHEEL +5 -0
  170. aimodelshare-0.3.7.dist-info/licenses/LICENSE +5 -0
  171. aimodelshare-0.3.7.dist-info/top_level.txt +1 -0
@@ -0,0 +1,2381 @@
+ # data wrangling
+ import pandas as pd
+ import numpy as np
+
+ # Import optional dependency checker
+ from aimodelshare.utils.optional_deps import check_optional
+
+ # --- ML FRAMEWORKS IMPORT SECTION ---
+ # Initialize to None to prevent NameErrors later if imports fail
+ sklearn = None
+ torch = None
+ xgboost = None
+ tf = None
+ keras = None
+ pyspark = None
+
+ # ml frameworks
+ try:
+     import sklearn
+     from sklearn.model_selection import GridSearchCV, RandomizedSearchCV
+ except ImportError:
+     check_optional("sklearn", "Scikit-learn")
+
+ try:
+     import torch
+ except ImportError:
+     check_optional("torch", "PyTorch")
+
+ try:
+     import xgboost
+ except ImportError:
+     check_optional("xgboost", "XGBoost")
+
+ try:
+     import tensorflow as tf
+     import keras
+ except ImportError:
+     check_optional("tensorflow", "TensorFlow/Keras")
+
+ try:
+     import pyspark
+     from pyspark.sql import SparkSession
+     from pyspark.ml import PipelineModel, Model
+     from pyspark.ml.tuning import CrossValidatorModel, TrainValidationSplitModel
+     from onnxmltools import convert_sparkml
+ except ImportError:
+     check_optional("pyspark", "PySpark")
+
+
+ # onnx modules
+ import onnx
+ import skl2onnx
+ from skl2onnx import convert_sklearn
+ # tf2onnx import is lazy-loaded to avoid requiring TensorFlow for non-TF workflows
+ _TF2ONNX_AVAILABLE = None
+ _tf2onnx_module = None
+ _tensorflow_module = None
+ try:
+     from torch.onnx import export
+ except ImportError:
+     pass
+ from onnx.tools.net_drawer import GetPydotGraph, GetOpNodeProducer
+ import importlib
+ import onnxmltools
+ import onnxruntime as rt
+ from skl2onnx.common.data_types import FloatTensorType
+
+ # aims modules
+ from aimodelshare.aws import run_function_on_lambda, get_aws_client
+ from aimodelshare.reproducibility import set_reproducibility_env
+ from pandas.io.formats.style import Styler
+
+ # os etc
+ import os
+ import ast
+ import tempfile
+ import json
+ import re
+ import pickle
+ import requests
+ import sys
+ import shutil
+ from pathlib import Path
+ from zipfile import ZipFile
+ import wget
+ from copy import copy
+ import psutil
+ from pympler import asizeof
+ from IPython.display import display, HTML, SVG
+ import logging
+
+
+ import networkx as nx
+ import warnings
+ import time
+ import signal
+
+ # scikeras imports keras, which requires TensorFlow - lazy load it
+ try:
+     from scikeras.wrappers import KerasClassifier, KerasRegressor
+     _SCIKERAS_AVAILABLE = True
+ except ImportError:
+     _SCIKERAS_AVAILABLE = False
+     KerasClassifier = None
+     KerasRegressor = None
+
+
+ logging.getLogger().setLevel(logging.ERROR)
+
+ def _check_tf2onnx_available():
+     """Check if tf2onnx and TensorFlow are available, and load them if needed.
+
+     Returns:
+         tuple: (tf2onnx_module, tensorflow_module) on success
+
+     Raises:
+         RuntimeError: If TensorFlow or tf2onnx are not installed
+     """
+     global _TF2ONNX_AVAILABLE, _tf2onnx_module, _tensorflow_module
+
+     if _TF2ONNX_AVAILABLE is None:
+         try:
+             import tf2onnx as tf2onnx_temp
+             import tensorflow as tf_temp
+             _tf2onnx_module = tf2onnx_temp
+             _tensorflow_module = tf_temp
+             _TF2ONNX_AVAILABLE = True
+         except ImportError as e:
+             _TF2ONNX_AVAILABLE = False
+             raise RuntimeError(
+                 "TensorFlow and tf2onnx are required for Keras model conversion to ONNX. "
+                 "Please install them with: pip install tensorflow tf2onnx"
+             ) from e
+
+     if not _TF2ONNX_AVAILABLE:
+         raise RuntimeError(
+             "TensorFlow and tf2onnx are required for Keras model conversion to ONNX. "
+             "Please install them with: pip install tensorflow tf2onnx"
+         )
+
+     return _tf2onnx_module, _tensorflow_module
+
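Because the loader caches its result in module globals, repeated calls are cheap and the error message stays consistent. A minimal sketch of a call site (the wrapper function below is hypothetical, not part of this module):

# Hypothetical call site: TensorFlow is only imported once a Keras model
# actually needs converting, so sklearn-only users never pay the TF import cost.
def convert_keras_sketch(keras_model):
    tf2onnx, tf = _check_tf2onnx_available()  # raises RuntimeError if TF/tf2onnx missing
    spec = [tf.TensorSpec(keras_model.input_shape, tf.float32, name="input")]
    onx, _ = tf2onnx.convert.from_keras(keras_model, input_signature=spec, opset=13)
    return onx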
+ def _extract_onnx_metadata(onnx_model, framework):
+     '''Extracts model metadata from an ONNX file.'''
+
+     # get model graph
+     graph = onnx_model.graph
+
+     # initialize metadata dict
+     metadata_onnx = {}
+
+     def _get_shape(dims):
+         return [d.dim_value if d.HasField("dim_value") else None for d in dims]
+
+     input_dims = graph.input[0].type.tensor_type.shape.dim
+     output_dims = graph.output[0].type.tensor_type.shape.dim
+
+     metadata_onnx["input_shape"] = _get_shape(input_dims)
+     metadata_onnx["output_shape"] = _get_shape(output_dims)
+
+     # TODO: match layers, nodes, and initializers in a single object
+
+     # get layers & activations
+     layer_nodes = ['MatMul', 'Gemm', 'Conv']  # consider adding MaxPool, Transpose, Flatten
+     activation_nodes = ['Relu', 'Softmax']
+
+     layers = []
+     activations = []
+
+     for op_id, op in enumerate(graph.node):
+
+         if op.op_type in layer_nodes:
+             layers.append(op.op_type)
+             # if op.op_type == 'MaxPool':
+             #     activations.append(None)
+
+         if op.op_type in activation_nodes:
+             activations.append(op.op_type)
+
+     # get shapes and parameter counts (weight params plus bias params per layer)
+     layers_shapes = []
+     layers_n_params = []
+
+     if framework == 'keras':
+         try:
+             initializer = list(reversed(graph.initializer))
+             for layer_id, layer in enumerate(initializer):
+                 if len(layer.dims) >= 2:
+                     layers_shapes.append(layer.dims[1])
+
+                     try:
+                         n_params = int(np.prod(layer.dims) + np.prod(initializer[layer_id - 1].dims))
+                     except Exception:
+                         n_params = None
+
+                     layers_n_params.append(n_params)
+         except Exception:
+             pass
+
+     elif framework == 'pytorch':
+         try:
+             initializer = graph.initializer
+             for layer_id, layer in enumerate(initializer):
+                 if len(layer.dims) >= 2:
+                     layers_shapes.append(layer.dims[0])
+                     n_params = int(np.prod(layer.dims) + np.prod(initializer[layer_id - 1].dims))
+                     layers_n_params.append(n_params)
+         except Exception:
+             pass
+
+     # get model architecture stats
+     model_architecture = {'layers_number': len(layers),
+                           'layers_sequence': layers,
+                           'layers_summary': {i: layers.count(i) for i in set(layers)},
+                           'layers_n_params': layers_n_params,
+                           'layers_shapes': layers_shapes,
+                           'activations_sequence': activations,
+                           'activations_summary': {i: activations.count(i) for i in set(activations)}
+                           }
+
+     metadata_onnx["model_architecture"] = model_architecture
+
+     return metadata_onnx
+
+
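For reference, a short sketch of exercising this helper on a saved ONNX file (the file name is invented):

import onnx

onnx_model = onnx.load("model.onnx")  # any graph with at least one input and output
meta = _extract_onnx_metadata(onnx_model, framework="keras")
print(meta["input_shape"], meta["model_architecture"]["layers_summary"])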
+ def _misc_to_onnx(model, initial_types, transfer_learning=None,
+                   deep_learning=None, task_type=None):
+
+     # generate metadata dict
+     metadata = {}
+
+     # placeholders, need to be generated elsewhere
+     metadata['model_id'] = None
+     metadata['data_id'] = None
+     metadata['preprocessor_id'] = None
+
+     # infer ml framework from the model object
+     onx = None
+     try:
+         if xgboost is not None and isinstance(model, (xgboost.XGBClassifier, xgboost.XGBRegressor)):
+             metadata['ml_framework'] = 'xgboost'
+             onx = onnxmltools.convert.convert_xgboost(model, initial_types=initial_types)
+     except Exception:
+         pass
+
+     # TODO: also integrate LightGBM
+
+     if onx is None:
+         raise ValueError("Unsupported model or failed conversion: "
+                          "_misc_to_onnx currently handles xgboost models only.")
+
+     # get model type from model object
+     model_type = str(model).split('(')[0]
+     metadata['model_type'] = model_type
+
+     # get transfer learning bool from user input
+     metadata['transfer_learning'] = transfer_learning
+
+     # get deep learning bool from user input
+     metadata['deep_learning'] = deep_learning
+
+     # get task type from user input
+     metadata['task_type'] = task_type
+
+     # placeholders, need to be inferred from data
+     metadata['target_distribution'] = None
+     metadata['input_type'] = None
+     metadata['input_shape'] = None
+     metadata['input_dtypes'] = None
+     metadata['input_distribution'] = None
+
+     # get model config dict from the model object (sklearn-style API)
+     metadata['model_config'] = str(model.get_params())
+
+     # model state is not applicable here
+     metadata['model_state'] = None
+
+     model_architecture = {}
+
+     if hasattr(model, 'coef_'):
+         model_architecture['layers_n_params'] = [len(model.coef_.flatten())]
+     if hasattr(model, 'solver'):
+         model_architecture['optimizer'] = model.solver
+
+     metadata['model_architecture'] = str(model_architecture)
+
+     metadata['memory_size'] = asizeof.asizeof(model)
+
+     # placeholder, needs evaluation engine
+     metadata['eval_metrics'] = None
+
+     # add metadata from onnx object
+     # metadata['metadata_onnx'] = str(_extract_onnx_metadata(onx, framework='sklearn'))
+     metadata['metadata_onnx'] = None
+
+     meta = onx.metadata_props.add()
+     meta.key = 'model_metadata'
+     meta.value = str(metadata)
+
+     return onx
+
+
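A minimal sketch of the xgboost path (training data and shapes are invented):

import numpy as np
import xgboost
from skl2onnx.common.data_types import FloatTensorType

X = np.random.rand(100, 4).astype(np.float32)
y = np.random.randint(0, 2, 100)
clf = xgboost.XGBClassifier(n_estimators=10).fit(X, y)

onx = _misc_to_onnx(clf,
                    initial_types=[('float_input', FloatTensorType([None, 4]))],
                    task_type='classification')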
+ def _sklearn_to_onnx(model, initial_types=None, transfer_learning=None,
+                      deep_learning=None, task_type=None):
+     '''Converts an sklearn model to ONNX and extracts model metadata.'''
+
+     # check whether this is a fitted sklearn model
+     # sklearn.utils.validation.check_is_fitted(model)
+
+     # deal with pipelines and parameter search
+     try:
+         if sklearn is not None:
+             if isinstance(model, (GridSearchCV, RandomizedSearchCV)):
+                 model = model.best_estimator_
+
+             if isinstance(model, sklearn.pipeline.Pipeline):
+                 model = model.steps[-1][1]
+     except Exception:
+         pass
+
+     # fix ensemble voting models
+     if hasattr(model, 'flatten_transform') and hasattr(model, 'voting'):
+         model.flatten_transform = False
+
+     # convert to onnx
+     if initial_types is None:
+         feature_count = model.n_features_in_
+         initial_types = [('float_input', FloatTensorType([None, feature_count]))]
+
+     onx = convert_sklearn(model, initial_types=initial_types, target_opset={'': 15, 'ai.onnx.ml': 2})
+
+     # set model ir_version to ensure sklearn opsets work properly
+     onx.ir_version = 8
+
+     # generate metadata dict
+     metadata = {}
+
+     # placeholders, need to be generated elsewhere
+     metadata['model_id'] = None
+     metadata['data_id'] = None
+     metadata['preprocessor_id'] = None
+
+     # infer ml framework from function call
+     metadata['ml_framework'] = 'sklearn'
+
+     # get model type from model object
+     model_type = str(model).split('(')[0]
+     metadata['model_type'] = model_type
+
+     # get transfer learning bool from user input
+     metadata['transfer_learning'] = transfer_learning
+
+     # get deep learning bool from user input
+     metadata['deep_learning'] = deep_learning
+
+     # get task type from user input
+     metadata['task_type'] = task_type
+
+     # placeholders, need to be inferred from data
+     metadata['target_distribution'] = None
+     metadata['input_type'] = None
+     metadata['input_shape'] = None
+     metadata['input_dtypes'] = None
+     metadata['input_distribution'] = None
+
+     # get model config dict from sklearn model object
+     metadata['model_config'] = str(model.get_params())
+
+     # get weights for pretrained models
+     temp_dir = tempfile.gettempdir()
+     temp_path = os.path.join(temp_dir, 'temp_file_name')
+
+     with open(temp_path, 'wb') as f:
+         pickle.dump(model, f)
+
+     with open(temp_path, "rb") as f:
+         pkl_str = f.read()
+
+     metadata['model_weights'] = pkl_str
+
+     # get model state from sklearn model object
+     metadata['model_state'] = None
+
+     # get model architecture
+     if model_type in ('MLPClassifier', 'MLPRegressor'):
+
+         loss = 'log-loss' if model_type == 'MLPClassifier' else 'squared-loss'
+
+         n_params = []
+         hidden_layer_sizes = list(model.hidden_layer_sizes)  # may be a tuple
+         layer_dims = [model.n_features_in_] + hidden_layer_sizes + [model.n_outputs_]
+         for i in range(len(layer_dims) - 1):
+             n_params.append(layer_dims[i] * layer_dims[i + 1] + layer_dims[i + 1])
+
+         # insert data into model architecture dict
+         model_architecture = {'layers_number': len(hidden_layer_sizes),
+                               'layers_sequence': ['Dense'] * len(hidden_layer_sizes),
+                               'layers_summary': {'Dense': len(hidden_layer_sizes)},
+                               'layers_n_params': n_params,  # double check
+                               'layers_shapes': hidden_layer_sizes,
+                               'activations_sequence': [model.activation] * len(hidden_layer_sizes),
+                               'activations_summary': {model.activation: len(hidden_layer_sizes)},
+                               'loss': loss,
+                               'optimizer': model.solver
+                               }
+
+         metadata['model_architecture'] = str(model_architecture)
+
+     else:
+         model_architecture = {}
+
+         if hasattr(model, 'coef_'):
+             model_architecture['layers_n_params'] = [len(model.coef_.flatten())]
+         if hasattr(model, 'solver'):
+             model_architecture['optimizer'] = model.solver
+
+         metadata['model_architecture'] = str(model_architecture)
+
+     metadata['memory_size'] = asizeof.asizeof(model)
+
+     # placeholder, needs evaluation engine
+     metadata['eval_metrics'] = None
+
+     # add metadata from onnx object
+     # metadata['metadata_onnx'] = str(_extract_onnx_metadata(onx, framework='sklearn'))
+     metadata['metadata_onnx'] = None
+
+     meta = onx.metadata_props.add()
+     meta.key = 'model_metadata'
+     meta.value = str(metadata)
+
+     return onx
+
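A minimal end-to-end sketch for the sklearn path (data invented; initial_types is inferred from n_features_in_ when omitted):

import numpy as np
from sklearn.linear_model import LogisticRegression

X = np.random.rand(50, 4).astype(np.float32)
y = np.random.randint(0, 2, 50)

onx = _sklearn_to_onnx(LogisticRegression().fit(X, y), task_type='classification')
onnx.save(onx, "logreg.onnx")  # the metadata dict travels inside metadata_props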
+ def get_pyspark_model_files_paths(directory):
+     '''Returns the relative paths of all model files under a directory.'''
+
+     # initializing empty file paths list
+     file_paths = []
+
+     for path in Path(directory).rglob('*'):
+         if not path.is_dir():
+             file_paths.append(path.relative_to(directory))
+
+     # returning all file paths
+     return file_paths
+
+ def _pyspark_to_onnx(model, initial_types, spark_session,
+                      transfer_learning=None, deep_learning=None,
+                      task_type=None):
+     '''Converts a pyspark model to ONNX and extracts model metadata.'''
+
+     if pyspark is None:
+         raise ImportError("Please install pyspark to enable pyspark features")
+
+     # deal with pipelines and parameter search
+     whole_model = copy(model)
+     try:
+         if isinstance(model, (TrainValidationSplitModel, CrossValidatorModel)):
+             model = model.bestModel
+
+         whole_model = copy(model)
+
+         # Look for the last model in the pipeline
+         if isinstance(model, PipelineModel):
+             for t in model.stages:
+                 if isinstance(t, Model):
+                     model = t
+     except Exception:
+         pass
+
+     # convert to onnx
+     onx = convert_sparkml(whole_model, 'Pyspark model', initial_types,
+                           spark_session=spark_session)
+
+     # generate metadata dict
+     metadata = {}
+
+     # placeholders, need to be generated elsewhere
+     metadata['model_id'] = None
+     metadata['data_id'] = None
+     metadata['preprocessor_id'] = None
+
+     # infer ml framework from function call
+     metadata['ml_framework'] = 'pyspark'
+
+     # get model type from model object
+     model_type = str(model).split(':')[0]
+     metadata['model_type'] = model_type
+
+     # get transfer learning bool from user input
+     metadata['transfer_learning'] = transfer_learning
+
+     # get deep learning bool from user input
+     metadata['deep_learning'] = deep_learning
+
+     # get task type from user input
+     metadata['task_type'] = task_type
+
+     # placeholders, need to be inferred from data
+     metadata['target_distribution'] = None
+     metadata['input_type'] = None
+     metadata['input_shape'] = None
+     metadata['input_dtypes'] = None
+     metadata['input_distribution'] = None
+
+     # get model config dict from pyspark model object
+     model_config = {}
+     try:
+         for key, value in model.extractParamMap().items():
+             model_config[key.name] = value
+     except Exception:
+         pass
+     metadata['model_config'] = str(model_config)
+
+     # get weights for pretrained models
+     temp_dir = tempfile.gettempdir()
+     temp_path = os.path.join(temp_dir, 'temp_pyspark_model')
+
+     model.write().overwrite().save(temp_path)
+
+     # collect all file paths in the saved model directory
+     file_paths = get_pyspark_model_files_paths(temp_path)
+
+     temp_path_zip = os.path.join(temp_dir, 'temp_pyspark_model.zip')
+     with ZipFile(temp_path_zip, 'w') as zf:
+         # writing each file one by one
+         for file in file_paths:
+             zf.write(os.path.join(temp_path, file), file)
+
+     with open(temp_path_zip, "rb") as f:
+         pkl_str = f.read()
+
+     metadata['model_weights'] = pkl_str
+
+     # clean up temp directory files for future runs
+     try:
+         shutil.rmtree(temp_path)
+         os.remove(temp_path_zip)
+     except OSError:
+         pass
+
+     # get model state from pyspark model object
+     metadata['model_state'] = None
+
+     # get model architecture
+     if model_type == 'MultilayerPerceptronClassificationModel':
+
+         # https://spark.apache.org/docs/latest/ml-classification-regression.html#multilayer-perceptron-classifier
+         loss = 'log-loss'
+         hidden_layer_activation = 'sigmoid'
+         output_layer_activation = 'softmax'
+
+         n_params = []
+         layer_dims = model.getLayers()
+         hidden_layers = layer_dims[1:-1]
+         for i in range(len(layer_dims) - 1):
+             n_params.append(layer_dims[i] * layer_dims[i + 1] + layer_dims[i + 1])
+
+         # insert data into model architecture dict
+         model_architecture = {'layers_number': len(hidden_layers),
+                               'layers_sequence': ['Dense'] * len(hidden_layers),
+                               'layers_summary': {'Dense': len(hidden_layers)},
+                               'layers_n_params': n_params,  # double check
+                               'layers_shapes': hidden_layers,
+                               'activations_sequence': [hidden_layer_activation] * len(hidden_layers) + [output_layer_activation],
+                               'activations_summary': {hidden_layer_activation: len(hidden_layers), output_layer_activation: 1},
+                               'loss': loss,
+                               'optimizer': model.getSolver()
+                               }
+
+         metadata['model_architecture'] = str(model_architecture)
+
+     else:
+         model_architecture = {}
+
+         if hasattr(model, 'coefficients'):
+             model_architecture['layers_n_params'] = [model.coefficients.size]
+         if hasattr(model, 'getSolver') and callable(model.getSolver):
+             model_architecture['optimizer'] = model.getSolver()
+
+         metadata['model_architecture'] = str(model_architecture)
+
+     metadata['memory_size'] = asizeof.asizeof(model)
+
+     # placeholder, needs evaluation engine
+     metadata['eval_metrics'] = None
+
+     # add metadata from onnx object
+     # metadata['metadata_onnx'] = str(_extract_onnx_metadata(onx, framework='pyspark'))
+     metadata['metadata_onnx'] = None
+
+     meta = onx.metadata_props.add()
+     meta.key = 'model_metadata'
+     meta.value = str(metadata)
+
+     return onx
+
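A minimal sketch of the pyspark path (the session, schema, and data below are invented for illustration):

from pyspark.sql import SparkSession
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.feature import VectorAssembler

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(1.0, 2.0, 0.0), (3.0, 4.0, 1.0)], ["x1", "x2", "label"])
features = VectorAssembler(inputCols=["x1", "x2"], outputCol="features").transform(df)
lr_model = LogisticRegression().fit(features)

initial_types = [("features", FloatTensorType([None, 2]))]
onx = _pyspark_to_onnx(lr_model, initial_types, spark_session=spark)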
+ def _keras_to_onnx(model, transfer_learning=None,
+                    deep_learning=None, task_type=None, epochs=None):
+     '''Converts a Keras model to ONNX and extracts metadata.'''
+
+     # Check and load tf2onnx and TensorFlow lazily (only when needed)
+     tf2onnx, tf = _check_tf2onnx_available()
+
+     import numpy as np
+     import onnx
+     import pickle
+     import psutil
+     import warnings
+     from pympler import asizeof
+     import logging
+     import os
+     import sys
+     from contextlib import contextmanager
+
+     # -- Helper to suppress tf2onnx stderr (NumPy warnings etc.)
+     @contextmanager
+     def suppress_stderr():
+         with open(os.devnull, "w") as devnull:
+             old_stderr = sys.stderr
+             sys.stderr = devnull
+             try:
+                 yield
+             finally:
+                 sys.stderr = old_stderr
+
+     # Reduce logging output
+     tf2onnx_logger = logging.getLogger("tf2onnx")
+     tf2onnx_logger.setLevel(logging.CRITICAL)
+
+     # Unwrap scikeras wrappers, sklearn pipelines, etc.
+     try:
+         from sklearn.model_selection import GridSearchCV, RandomizedSearchCV
+         from sklearn.pipeline import Pipeline
+
+         if isinstance(model, (GridSearchCV, RandomizedSearchCV)):
+             model = model.best_estimator_
+         if isinstance(model, Pipeline):
+             model = model.steps[-1][1]
+     except ImportError:
+         pass
+
+     try:
+         from scikeras.wrappers import KerasClassifier, KerasRegressor
+         if isinstance(model, (KerasClassifier, KerasRegressor)):
+             model = model.model
+     except ImportError:
+         pass
+
+     # Input signature
+     input_shape = model.input_shape
+     if isinstance(input_shape, list):
+         input_shape = input_shape[0]
+     input_signature = [tf.TensorSpec(input_shape, tf.float32, name="input")]
+
+     # Wrap model in tf.function
+     @tf.function(input_signature=input_signature)
+     def model_fn(x):
+         return model(x)
+
+     concrete_func = model_fn
+
+     # Convert to ONNX
+     with suppress_stderr():
+         onx_model, _ = tf2onnx.convert.from_function(
+             concrete_func,
+             input_signature=input_signature,
+             opset=13,
+             output_path=None
+         )
+
+     # Extract metadata
+     metadata = {
+         'model_id': None,
+         'data_id': None,
+         'preprocessor_id': None,
+         'ml_framework': 'keras',
+         'model_type': model.__class__.__name__,
+         'transfer_learning': transfer_learning,
+         'deep_learning': deep_learning,
+         'task_type': task_type,
+         'target_distribution': None,
+         'input_type': None,
+         'input_shape': input_shape,
+         'input_dtypes': None,
+         'input_distribution': None,
+         'model_config': str(model.get_config()),
+         'model_state': None,
+         'eval_metrics': None,
+         'model_graph': "",
+         'metadata_onnx': None,
+         'epochs': epochs
+     }
+
+     model_size = asizeof.asizeof(model.get_weights())
+     mem = psutil.virtual_memory()
+
+     if model_size > mem.available:
+         warnings.warn(f"Model size ({model_size / 1e6} MB) exceeds available memory.")
+         metadata['model_weights'] = None
+     else:
+         metadata['model_weights'] = pickle.dumps(model.get_weights())
+
+     # Extract architecture
+     if not model.built:  # run a forward pass so layer output shapes exist
+         try:
+             model(tf.random.uniform([1] + list(input_shape[1:])))
+         except Exception:
+             pass  # fallback, don't crash conversion
+
+     keras_layers = keras_unpack(model)
+
+     from tensorflow.python.framework import tensor_shape  # local import; TF is guaranteed available here
+
+     layers = []
+     layers_n_params = []
+     layers_shapes = []
+     activations = []
+
+     for layer in keras_layers:
+         # layer name
+         layers.append(layer.__class__.__name__)
+
+         # parameter count
+         try:
+             layers_n_params.append(layer.count_params())
+         except Exception:
+             layers_n_params.append(0)
+
+         # output shape (sanitized for JSON)
+         shape = getattr(layer, 'output_shape', None)
+
+         if isinstance(shape, tensor_shape.TensorShape):
+             shape = shape.as_list()
+         elif shape is not None:
+             try:
+                 shape = list(shape)
+             except TypeError:
+                 shape = str(shape)
+         else:
+             shape = None
+
+         layers_shapes.append(shape)
+
+         # activation
+         if hasattr(layer, 'activation'):
+             act = getattr(layer.activation, '__name__', None)
+             if act:
+                 activations.append(act)
+
+     # optimizer/loss may be absent if the model was never compiled
+     _optimizer = getattr(model, 'optimizer', None)
+     _loss = getattr(model, 'loss', None)
+     optimizer = _optimizer.__class__ if _optimizer is not None else None
+     loss = _loss.__class__ if _loss is not None else None
+
+     model_architecture = {
+         'layers_number': len(layers),
+         'layers_sequence': layers,
+         'layers_summary': {i: layers.count(i) for i in set(layers)},
+         'layers_n_params': layers_n_params,
+         'layers_shapes': layers_shapes,
+         'activations_sequence': activations,
+         'activations_summary': {i: activations.count(i) for i in set(activations)},
+         'loss': loss.__name__ if loss else None,
+         'optimizer': optimizer.__name__ if optimizer else None
+     }
+
+     metadata['model_architecture'] = str(model_architecture)
+     metadata['model_summary'] = model_summary_keras(model).to_json()
+     metadata['memory_size'] = model_size
+
+     # Embed metadata in ONNX
+     meta = onx_model.metadata_props.add()
+     meta.key = 'model_metadata'
+     meta.value = str(metadata)
+
+     return onx_model
+
+
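A minimal sketch of the Keras path (toy architecture invented for illustration):

import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(4,)),
    tf.keras.layers.Dense(8, activation="relu"),
    tf.keras.layers.Dense(1, activation="sigmoid"),
])
model.compile(optimizer="adam", loss="binary_crossentropy")

onx = _keras_to_onnx(model, deep_learning=True, task_type="classification", epochs=10)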
+ def _pytorch_to_onnx(model, model_input, transfer_learning=None,
+                      deep_learning=None, task_type=None,
+                      epochs=None):
+     '''Converts a pytorch model to ONNX and extracts model metadata.'''
+
+     # TODO: check whether this is a fitted pytorch model
+     # isinstance...
+
+     # generate tempfile for onnx object
+     temp_dir = tempfile.gettempdir()
+     temp_path = os.path.join(temp_dir, 'temp_file_name')
+
+     # generate onnx file and save it to the temporary path
+     export(model, model_input, temp_path)
+     # operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK)
+
+     # load onnx file from temporary path
+     onx = onnx.load(temp_path)
+
+     # generate metadata dict
+     metadata = {}
+
+     # placeholders, need to be generated elsewhere
+     metadata['model_id'] = None
+     metadata['data_id'] = None
+     metadata['preprocessor_id'] = None
+
+     # infer ml framework from function call
+     metadata['ml_framework'] = 'pytorch'
+
+     # get model type from model object
+     metadata['model_type'] = str(model.__class__).split('.')[-1].split("'")[0] + '()'
+
+     # get transfer learning bool from user input
+     metadata['transfer_learning'] = transfer_learning
+
+     # get deep learning bool from user input
+     metadata['deep_learning'] = deep_learning
+
+     # get task type from user input
+     metadata['task_type'] = task_type
+
+     # placeholders, need to be inferred from data
+     metadata['target_distribution'] = None
+     metadata['input_type'] = None
+     metadata['input_shape'] = None
+     metadata['input_dtypes'] = None
+     metadata['input_distribution'] = None
+
+     # get model config dict from pytorch model object
+     metadata['model_config'] = str(model.__dict__)
+
+     # get model state from pytorch model object
+     metadata['model_state'] = str(model.state_dict())
+
+     name_list, layer_list, param_list, weight_list, activation_list = torch_metadata(model)
+
+     model_summary_pd = pd.DataFrame({"Name": name_list,
+                                      "Layer": layer_list,
+                                      "Shape": weight_list,
+                                      "Params": param_list,
+                                      "Connect": None,
+                                      "Activation": None})
+
+     model_architecture = {'layers_number': len(layer_list),
+                           'layers_sequence': layer_list,
+                           'layers_summary': {i: layer_list.count(i) for i in set(layer_list)},
+                           'layers_n_params': param_list,
+                           'layers_shapes': weight_list,
+                           'activations_sequence': activation_list,
+                           'activations_summary': {i: activation_list.count(i) for i in set(activation_list)},
+                           'loss': None,
+                           'optimizer': None}
+
+     metadata['model_architecture'] = str(model_architecture)
+
+     metadata['model_summary'] = model_summary_pd.to_json()
+
+     metadata['memory_size'] = asizeof.asizeof(model)
+     metadata['epochs'] = epochs
+
+     # placeholder, needs evaluation engine
+     metadata['eval_metrics'] = None
+
+     # add metadata from onnx object
+     # metadata['metadata_onnx'] = str(_extract_onnx_metadata(onx, framework='pytorch'))
+     metadata['metadata_onnx'] = None
+
+     # add metadata dict to onnx object
+     meta = onx.metadata_props.add()
+     meta.key = 'model_metadata'
+     meta.value = str(metadata)
+
+     return onx
+
+
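A minimal sketch of the PyTorch path; torch.onnx.export needs a sample input tensor to trace the graph (the toy network and shapes are invented):

import torch
import torch.nn as nn

net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
dummy_input = torch.randn(1, 4)  # one example batch with 4 features

onx = _pytorch_to_onnx(net, model_input=dummy_input,
                       deep_learning=True, task_type="classification")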
+ def model_to_onnx(model, framework=None, model_input=None, initial_types=None,
+                   transfer_learning=None, deep_learning=None, task_type=None,
+                   epochs=None, spark_session=None):
+     '''Transforms an sklearn, keras, pytorch, xgboost, or pyspark model object
+     into ONNX format and extracts a model metadata dictionary. The metadata
+     dictionary is saved in the ONNX file's metadata_props.
+
+     Parameters:
+     model: fitted sklearn, keras, pytorch, xgboost, or pyspark model object
+         Specifies the model object that will be converted to ONNX.
+
+     framework: {"sklearn", "keras", "pytorch", "xgboost", "pyspark"}
+         Specifies the machine learning framework of the model object.
+         Inferred from the model object if omitted.
+
+     model_input: array_like, default=None
+         Required when framework="pytorch".
+
+     initial_types: initial types tuple, default=None
+         Required when framework="sklearn".
+
+     transfer_learning: bool, default=None
+         Indicates whether transfer learning was used.
+
+     deep_learning: bool, default=None
+         Indicates whether deep learning was used.
+
+     task_type: {"classification", "regression"}
+         Indicates whether the model is a classification model or
+         a regression model.
+
+     Returns:
+     ONNX object with model metadata saved in metadata_props
+     '''
+
+     # if no framework was passed, infer it from the model object
+     if model is not None and framework is None:
+         framework = model.__module__.split(".")[0]
+         try:
+             import torch
+             if isinstance(model, torch.nn.Module):
+                 framework = "pytorch"
+         except ImportError:
+             pass
+
+     # assert that framework exists
+     frameworks = ['sklearn', 'keras', 'pytorch', 'xgboost', 'pyspark']
+     assert framework in frameworks, \
+         'Please choose "sklearn", "keras", "pytorch", "pyspark" or "xgboost".'
+
+     # assert model input type THIS IS A PLACEHOLDER
+     # if model_input is not None:
+     #     assert isinstance(model_input, (list, pd.core.series.Series, np.ndarray, torch.Tensor)), \
+     #         'Please format model input as XYZ.'
+
+     # assert initial_types
+     if initial_types is not None:
+         assert isinstance(initial_types[0][1], skl2onnx.common.data_types.FloatTensorType), \
+             'Please use FloatTensorType as initial types.'
+
+     # assert transfer_learning
+     if transfer_learning is not None:
+         assert isinstance(transfer_learning, bool), \
+             'Please pass boolean to indicate whether transfer learning was used.'
+
+     # assert deep_learning
+     if deep_learning is not None:
+         assert isinstance(deep_learning, bool), \
+             'Please pass boolean to indicate whether deep learning was used.'
+
+     # assert task_type
+     if task_type is not None:
+         assert task_type in ['classification', 'regression'], \
+             'Please specify task type as "classification" or "regression".'
+
+     # the local name deliberately avoids shadowing the imported onnx module
+     if framework == 'sklearn':
+         onnx_model = _sklearn_to_onnx(model, initial_types=initial_types,
+                                       transfer_learning=transfer_learning,
+                                       deep_learning=deep_learning,
+                                       task_type=task_type)
+
+     elif framework == 'xgboost':
+         onnx_model = _misc_to_onnx(model, initial_types=initial_types,
+                                    transfer_learning=transfer_learning,
+                                    deep_learning=deep_learning,
+                                    task_type=task_type)
+
+     elif framework == 'keras':
+         onnx_model = _keras_to_onnx(model, transfer_learning=transfer_learning,
+                                     deep_learning=deep_learning,
+                                     task_type=task_type,
+                                     epochs=epochs)
+
+     elif framework == 'pytorch':
+         onnx_model = _pytorch_to_onnx(model, model_input=model_input,
+                                       transfer_learning=transfer_learning,
+                                       deep_learning=deep_learning,
+                                       task_type=task_type,
+                                       epochs=epochs)
+
+     elif framework == 'pyspark':
+         try:
+             import pyspark
+             from pyspark.sql import SparkSession
+             from pyspark.ml import PipelineModel, Model
+             from pyspark.ml.tuning import CrossValidatorModel, TrainValidationSplitModel
+             from onnxmltools import convert_sparkml
+         except ImportError:
+             check_optional("pyspark", "PySpark")
+         onnx_model = _pyspark_to_onnx(model, initial_types=initial_types,
+                                       transfer_learning=transfer_learning,
+                                       deep_learning=deep_learning,
+                                       task_type=task_type,
+                                       spark_session=spark_session)
+
+     # sanity check: make sure onnxruntime can load the converted model
+     try:
+         rt.InferenceSession(onnx_model.SerializeToString())
+     except Exception as e:
+         print(e)
+
+     return onnx_model
+
+
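The dispatcher can usually infer the framework from the model object's module, so a typical call is just (sketch; data invented):

import numpy as np
from sklearn.ensemble import RandomForestClassifier

X, y = np.random.rand(60, 3).astype(np.float32), np.random.randint(0, 2, 60)
onnx_model = model_to_onnx(RandomForestClassifier(n_estimators=5).fit(X, y),
                           task_type="classification")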
+ def model_to_onnx_timed(model_filepath, force_onnx=False, timeout=60, model_input=None):
+
+     if not (model_filepath is None or isinstance(model_filepath, (str, onnx.ModelProto))):
+
+         if force_onnx:
+             try:
+                 import torch
+                 if isinstance(model_filepath, torch.nn.Module):
+                     onnx_model = model_to_onnx(model_filepath, model_input=model_input)
+                 else:
+                     onnx_model = model_to_onnx(model_filepath)
+             except ImportError:
+                 onnx_model = model_to_onnx(model_filepath)
+             model_filepath = onnx_model
+
+         else:
+
+             try:  # try needed because signal.alarm is not available on Windows
+                 # interrupt if onnx conversion is taking too long
+                 def timeout_handler(num, stack):
+                     raise Exception("timeout")
+
+                 signal.signal(signal.SIGALRM, timeout_handler)
+                 signal.alarm(timeout)
+             except AttributeError:
+                 pass
+
+             try:
+                 try:
+                     import torch
+                     if isinstance(model_filepath, torch.nn.Module):
+                         onnx_model = model_to_onnx(model_filepath, model_input=model_input)
+                     else:
+                         onnx_model = model_to_onnx(model_filepath)
+                 except ImportError:
+                     onnx_model = model_to_onnx(model_filepath)
+                 model_filepath = onnx_model
+
+             except Exception:
+                 print("Timeout: Model to ONNX conversion is taking longer than expected. This can be the case for big models.")
+
+                 # Detect CI/testing environment for non-interactive fallback
+                 is_non_interactive = (
+                     os.environ.get("PYTEST_CURRENT_TEST") is not None or
+                     os.environ.get("AIMS_NON_INTERACTIVE") == "1"
+                 )
+
+                 if is_non_interactive:
+                     # Auto-fallback to predictions-only in CI/testing environment
+                     print("Non-interactive environment detected. Falling back to predictions-only submission.")
+                     model_filepath = None
+                 else:
+                     # Interactive prompt for manual runs
+                     response = ''
+                     while response not in {"1", "2"}:
+                         response = input("Do you want to keep trying (1) or submit predictions only (2)? ")
+
+                     if response == "1":
+                         try:
+                             import torch
+                             if isinstance(model_filepath, torch.nn.Module):
+                                 onnx_model = model_to_onnx(model_filepath, model_input=model_input)
+                             else:
+                                 onnx_model = model_to_onnx(model_filepath)
+                         except Exception as e:
+                             # Final fallback - if torch-specific handling failed, try generic conversion
+                             # This handles cases where torch module detection fails but conversion might still work
+                             warnings.warn(f"PyTorch-specific ONNX conversion failed ({e}), attempting generic conversion")
+                             onnx_model = model_to_onnx(model_filepath, model_input=model_input)
+                         model_filepath = onnx_model
+
+                     elif response == "2":
+                         model_filepath = None
+
+             finally:
+                 print()
+                 try:  # signal.alarm is not available on Windows
+                     signal.alarm(0)
+                 except AttributeError:
+                     pass
+
+     return model_filepath
+
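For context, the POSIX alarm pattern used above in isolation; the handler must be registered with signal.signal before the alarm can interrupt anything, and the alarm must always be cancelled (Unix-only sketch; slow_conversion is a hypothetical long-running call):

import signal

def _handler(signum, frame):
    raise TimeoutError("conversion timed out")

signal.signal(signal.SIGALRM, _handler)  # register before the alarm fires
signal.alarm(5)                          # deliver SIGALRM in 5 seconds
try:
    slow_conversion()                    # hypothetical long-running call
finally:
    signal.alarm(0)                      # always cancel the pending alarm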
+ def _get_metadata(onnx_model):
+     '''Fetches previously extracted model metadata from an ONNX object
+     and returns the model metadata dict.'''
+
+     # double check this
+     # assert isinstance(onnx_model, onnx.onnx_ml_pb2.ModelProto), \
+     #     "Please pass an onnx model object."
+
+     # Handle None input gracefully - always return a dict
+     if onnx_model is None:
+         if os.environ.get("AIMODELSHARE_DEBUG_METADATA"):
+             print("[DEBUG] _get_metadata: onnx_model is None, returning empty dict")
+         return {}
+
+     try:
+         onnx_meta = onnx_model.metadata_props
+
+         onnx_meta_dict = {'model_metadata': ''}
+
+         for i in onnx_meta:
+             onnx_meta_dict[i.key] = i.value
+
+         onnx_meta_dict = ast.literal_eval(onnx_meta_dict['model_metadata'])
+
+         # Handle case where metadata is stored as a list instead of dict
+         if isinstance(onnx_meta_dict, list):
+             if os.environ.get("AIMODELSHARE_DEBUG_METADATA"):
+                 print(f"[DEBUG] _get_metadata: metadata is a list of length {len(onnx_meta_dict)}")
+             if len(onnx_meta_dict) > 0 and isinstance(onnx_meta_dict[0], dict):
+                 onnx_meta_dict = onnx_meta_dict[0]
+                 if os.environ.get("AIMODELSHARE_DEBUG_METADATA"):
+                     print("[DEBUG] _get_metadata: Extracted first dict from list")
+             else:
+                 # Return empty dict if list doesn't contain valid dicts
+                 if os.environ.get("AIMODELSHARE_DEBUG_METADATA"):
+                     print("[DEBUG] _get_metadata: List does not contain valid dicts, returning empty dict")
+                 return {}
+
+         # Ensure we have a dict at this point
+         if not isinstance(onnx_meta_dict, dict):
+             if os.environ.get("AIMODELSHARE_DEBUG_METADATA"):
+                 print(f"[DEBUG] _get_metadata: Unexpected metadata type {type(onnx_meta_dict)}, returning empty dict")
+             return {}
+
+         # if onnx_meta_dict['model_config'] is not None and \
+         #         onnx_meta_dict['ml_framework'] != 'pytorch':
+         #     onnx_meta_dict['model_config'] = ast.literal_eval(onnx_meta_dict['model_config'])
+
+         # Attempt to parse nested fields only if they are string representations of dicts
+         if 'model_architecture' in onnx_meta_dict and onnx_meta_dict['model_architecture'] is not None:
+             try:
+                 if isinstance(onnx_meta_dict['model_architecture'], str):
+                     onnx_meta_dict['model_architecture'] = ast.literal_eval(onnx_meta_dict['model_architecture'])
+             except (ValueError, SyntaxError):
+                 # Keep as-is if parsing fails
+                 pass
+
+         if 'model_config' in onnx_meta_dict and onnx_meta_dict['model_config'] is not None:
+             try:
+                 if isinstance(onnx_meta_dict['model_config'], str):
+                     onnx_meta_dict['model_config'] = ast.literal_eval(onnx_meta_dict['model_config'])
+             except (ValueError, SyntaxError):
+                 # Keep as-is if parsing fails
+                 pass
+
+         if 'metadata_onnx' in onnx_meta_dict and onnx_meta_dict['metadata_onnx'] is not None:
+             try:
+                 if isinstance(onnx_meta_dict['metadata_onnx'], str):
+                     onnx_meta_dict['metadata_onnx'] = ast.literal_eval(onnx_meta_dict['metadata_onnx'])
+             except (ValueError, SyntaxError):
+                 # Keep as-is if parsing fails
+                 pass
+
+         # onnx_meta_dict['model_image'] = onnx_to_image(onnx_model)
+
+     except Exception as e:
+
+         if os.environ.get("AIMODELSHARE_DEBUG_METADATA"):
+             print(f"[DEBUG] _get_metadata: Exception during metadata extraction: {e}")
+
+         try:
+             onnx_meta_dict = ast.literal_eval(onnx_meta_dict)
+             # Handle list case in exception path as well
+             if isinstance(onnx_meta_dict, list) and len(onnx_meta_dict) > 0 and isinstance(onnx_meta_dict[0], dict):
+                 onnx_meta_dict = onnx_meta_dict[0]
+             elif not isinstance(onnx_meta_dict, dict):
+                 onnx_meta_dict = {}
+         except Exception:
+             onnx_meta_dict = {}
+
+     # Final safety check: ensure we always return a dict
+     if not isinstance(onnx_meta_dict, dict):
+         if os.environ.get("AIMODELSHARE_DEBUG_METADATA"):
+             print(f"[DEBUG] _get_metadata: Final check failed, returning empty dict instead of {type(onnx_meta_dict)}")
+         return {}
+
+     return onnx_meta_dict
+
+
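A round-trip sketch: the conversion helpers stash a stringified dict under the 'model_metadata' key in metadata_props, and _get_metadata recovers it (fitted_model stands in for any supported framework object):

onx = model_to_onnx(fitted_model)  # fitted_model: hypothetical fitted model
meta = _get_metadata(onx)
print(meta.get("ml_framework"), meta.get("model_type"))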
+ def _get_leaderboard_data(onnx_model, eval_metrics=None):
+     '''Extract leaderboard data from ONNX model or return defaults.
+
+     This function performs single-pass normalization and safely handles:
+     - None onnx_model (returns defaults)
+     - Invalid metadata structures
+     - Missing keys in metadata
+     '''
+
+     # Start with eval_metrics if provided, otherwise empty dict
+     if eval_metrics is not None:
+         metadata = dict(eval_metrics) if isinstance(eval_metrics, dict) else {}
+     else:
+         metadata = {}
+
+     # Handle None onnx_model gracefully
+     if onnx_model is None:
+         if os.environ.get("AIMODELSHARE_DEBUG_METADATA"):
+             print("[DEBUG] _get_leaderboard_data: onnx_model is None, using default metadata")
+         # Return metadata with safe defaults injected
+         metadata['ml_framework'] = metadata.get('ml_framework', None)
+         metadata['transfer_learning'] = metadata.get('transfer_learning', None)
+         metadata['deep_learning'] = metadata.get('deep_learning', None)
+         metadata['model_type'] = metadata.get('model_type', None)
+         metadata['depth'] = metadata.get('depth', 0)
+         metadata['num_params'] = metadata.get('num_params', 0)
+         return metadata
+
+     # Get metadata from ONNX - _get_metadata always returns a dict
+     metadata_raw = _get_metadata(onnx_model)
+
+     if os.environ.get("AIMODELSHARE_DEBUG_METADATA"):
+         print(f"[DEBUG] _get_leaderboard_data: metadata_raw type={type(metadata_raw)}, keys={list(metadata_raw.keys()) if isinstance(metadata_raw, dict) else 'N/A'}")
+
+     # Single-pass normalization: ensure metadata_raw is a dict
+     if not isinstance(metadata_raw, dict):
+         if os.environ.get("AIMODELSHARE_DEBUG_METADATA"):
+             print(f"[DEBUG] _get_leaderboard_data: metadata_raw is not a dict (type={type(metadata_raw)}), using empty dict")
+         metadata_raw = {}
+
+     # get list of current layer types
+     layer_list_keras, activation_list_keras = _get_layer_names()
+     layer_list_pytorch, activation_list_pytorch = _get_layer_names_pytorch()
+
+     layer_list = list(set(layer_list_keras + layer_list_pytorch))
+     activation_list = list(set(activation_list_keras + activation_list_pytorch))
+
+     # get general model info - use .get() for safety
+     metadata['ml_framework'] = metadata_raw.get('ml_framework')
+     metadata['transfer_learning'] = metadata_raw.get('transfer_learning')
+     metadata['deep_learning'] = metadata_raw.get('deep_learning')
+     metadata['model_type'] = metadata_raw.get('model_type')
+
+     # get neural network metrics
+     # model_architecture must be a dict here to prevent a TypeError below
+     if (metadata_raw.get('ml_framework') in ['keras', 'pytorch'] or
+             metadata_raw.get('model_type') in ['MLPClassifier', 'MLPRegressor']) and \
+             isinstance(metadata_raw.get('model_architecture'), dict):
+
+         metadata['depth'] = metadata_raw['model_architecture'].get('layers_number', 0)
+         metadata['num_params'] = sum(metadata_raw['model_architecture'].get('layers_n_params', []))
+
+         for i in layer_list:
+             layers_summary = metadata_raw['model_architecture'].get('layers_summary', {})
+             if i in layers_summary:
+                 metadata[i.lower() + '_layers'] = layers_summary[i]
+             elif i.lower() + '_layers' not in metadata.keys():
+                 metadata[i.lower() + '_layers'] = 0
+
+         for i in activation_list:
+             activations_summary = metadata_raw['model_architecture'].get('activations_summary', {})
+             if i in activations_summary:
+                 if i.lower() + '_act' in metadata:
+                     metadata[i.lower() + '_act'] += activations_summary[i]
+                 else:
+                     metadata[i.lower() + '_act'] = activations_summary[i]
+             else:
+                 if i.lower() + '_act' not in metadata:
+                     metadata[i.lower() + '_act'] = 0
+
+         metadata['loss'] = metadata_raw['model_architecture'].get('loss')
+         metadata['optimizer'] = metadata_raw['model_architecture'].get('optimizer')
+         metadata['model_config'] = metadata_raw.get('model_config')
+         metadata['epochs'] = metadata_raw.get('epochs')
+         metadata['memory_size'] = metadata_raw.get('memory_size')
+
+     # get sklearn & pyspark model metrics
+     elif metadata_raw.get('ml_framework') in ['sklearn', 'xgboost', 'pyspark']:
+         metadata['depth'] = 0
+
+         try:
+             if isinstance(metadata_raw.get('model_architecture'), dict):
+                 metadata['num_params'] = sum(metadata_raw['model_architecture'].get('layers_n_params', []))
+             else:
+                 metadata['num_params'] = 0
+         except Exception:
+             metadata['num_params'] = 0
+
+         for i in layer_list:
+             metadata[i.lower() + '_layers'] = 0
+
+         for i in activation_list:
+             metadata[i.lower() + '_act'] = 0
+
+         metadata['loss'] = None
+
+         try:
+             if isinstance(metadata_raw.get('model_architecture'), dict):
+                 metadata['optimizer'] = metadata_raw['model_architecture'].get('optimizer')
+             else:
+                 metadata['optimizer'] = None
+         except Exception:
+             metadata['optimizer'] = None
+
+         try:
+             metadata['model_config'] = metadata_raw.get('model_config')
+         except Exception:
+             metadata['model_config'] = None
+
+     # Default handling for unknown frameworks
+     else:
+         if os.environ.get("AIMODELSHARE_DEBUG_METADATA"):
+             print(f"[DEBUG] _get_leaderboard_data: Unknown framework '{metadata_raw.get('ml_framework')}', using defaults")
+         metadata.setdefault('depth', 0)
+         metadata.setdefault('num_params', 0)
+         for i in layer_list:
+             metadata.setdefault(i.lower() + '_layers', 0)
+         for i in activation_list:
+             metadata.setdefault(i.lower() + '_act', 0)
+
+     return metadata
+
+
+ def _model_summary(meta_dict, from_onnx=False):
+     '''Creates model summary table from model metadata dict.'''
+     import io
+
+     assert isinstance(meta_dict, dict), \
+         "Please pass a valid metadata dict."
+
+     assert 'model_summary' in meta_dict.keys(), \
+         "Please make sure model summary data is included."
+
+     if from_onnx:
+         model_summary = pd.read_json(io.StringIO(meta_dict['metadata_onnx']["model_summary"]))
+     else:
+         model_summary = pd.read_json(io.StringIO(meta_dict["model_summary"]))
+
+     return model_summary
+
+
+ def onnx_to_image(model):
+     '''Creates model graph image in pydot format.'''
+
+     OP_STYLE = {
+         'shape': 'box',
+         'color': '#0F9D58',
+         'style': 'filled',
+         'fontcolor': '#FFFFFF'
+     }
+
+     pydot_graph = GetPydotGraph(
+         model.graph,
+         name=model.graph.name,
+         rankdir='TB',
+         node_producer=GetOpNodeProducer(
+             embed_docstring=False,
+             **OP_STYLE
+         )
+     )
+
+     return pydot_graph
+
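A sketch of rendering the returned pydot graph to an image (output path invented; rendering requires Graphviz on the PATH):

pydot_graph = onnx_to_image(onnx_model)  # onnx_model: any loaded onnx.ModelProto
pydot_graph.write_png("model_graph.png")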
1375
+ def inspect_model(apiurl, version=None, naming_convention = None, submission_type="competition"):
1376
+ if all(["username" in os.environ,
1377
+ "password" in os.environ]):
1378
+ pass
1379
+ else:
1380
+ return print("'Inspect Model' unsuccessful. Please provide credentials with set_credentials().")
1381
+
1382
+ post_dict = {"y_pred": [],
1383
+ "return_eval": "False",
1384
+ "return_y": "False",
1385
+ "inspect_model": "True",
1386
+ "version": version,
1387
+ "submission_type": submission_type
1388
+ }
1389
+
1390
+ headers = { 'Content-Type':'application/json', 'authorizationToken': os.environ.get("AWS_TOKEN"),}
1391
+
1392
+ apiurl_eval=apiurl[:-1]+"eval"
1393
+
1394
+ inspect_json = requests.post(apiurl_eval,headers=headers,data=json.dumps(post_dict))
1395
+
1396
+ inspect_pd = pd.DataFrame(json.loads(inspect_json.text))
1397
+
1398
+ return inspect_pd
1399
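# Reviewer sketch (not part of the diff): typical call pattern for inspect_model.
# The apiurl is a placeholder; credentials must already be in os.environ (via
# set_credentials()) and AWS_TOKEN must be set for the POST to succeed.
apiurl = "https://example.execute-api.us-east-1.amazonaws.com/prod/m"  # placeholder
model_info = inspect_model(apiurl, version=1)   # DataFrame of model metadata
print(model_info.head())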
+
+
+def color_pal_assign(val, naming_convention=None):
+
+    # find path of color mapping
+    path = Path(__file__).parent
+    if naming_convention == "keras":
+        col_map = pd.read_csv(path / "color_mappings/color_mapping_keras.csv")
+    elif naming_convention == "pytorch":
+        col_map = pd.read_csv(path / "color_mappings/color_mapping_pytorch.csv")
+    else:
+        # no naming convention given: fall back to a neutral background
+        return 'background: white'
+
+    # get color for layer key
+    try:
+        color = col_map[col_map.iloc[:, 1] == val].iloc[:, 2].values[0]
+    except:
+        color = "white"
+
+    return 'background: %s' % color
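# Reviewer sketch (not part of the diff): because color_pal_assign returns a CSS
# fragment, it plugs directly into pandas Styler.applymap, which is how
# stylize_model_comparison below uses it. Assumes the keras color-mapping CSV
# shipped with the package is available.
import pandas as pd

df = pd.DataFrame({"layer": ["Conv2D", "Dense"]})
styled = df.style.applymap(color_pal_assign, naming_convention="keras")
html = styled.to_html()   # each cell carries e.g. 'background: <color>'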
+
+
+def stylize_model_comparison(comp_dict_out, naming_convention=None):
+
+    for i in comp_dict_out.keys():
+
+        if i == 'nn':
+
+            df_styled = comp_dict_out['nn'].style.applymap(color_pal_assign, naming_convention=naming_convention)
+
+            df_styled = df_styled.set_properties(**{'color': 'black'})
+
+            df_styled = df_styled.set_caption('Model type: ' + 'Neural Network').set_table_styles(
+                [{'selector': 'caption', 'props': [('color', 'black'), ('font-size', '18px')]}])
+
+            display(HTML(Styler.to_html(df_styled)))
+
+        elif 'undefined' in i:
+
+            version = i.split('_')[-1]
+
+            df_styled = comp_dict_out[i].style
+
+            df_styled = df_styled.set_caption("No metadata available for model " + str(version)).set_table_styles(
+                [{'selector': 'caption', 'props': [('color', 'black'), ('font-size', '18px')]}])
+
+            display(HTML(Styler.to_html(df_styled)))
+            print('\n\n')
+
+        else:
+
+            comp_pd = comp_dict_out[i]
+
+            df_styled = comp_pd.style.apply(lambda x: ["background: tomato" if v != x.iloc[0] else "" for v in x],
+                                            axis=1, subset=comp_pd.columns[1:])
+
+            df_styled = df_styled.set_caption('Model type: ' + i).set_table_styles(
+                [{'selector': 'caption', 'props': [('color', 'black'), ('font-size', '18px')]}])
+
+            display(HTML(Styler.to_html(df_styled)))
+            print('\n\n')
+
+
+def compare_models(apiurl, version_list="None",
+                   by_model_type=None, best_model=None, verbose=1, naming_convention=None, submission_type="competition"):
+    if all(["username" in os.environ,
+            "password" in os.environ]):
+        pass
+    else:
+        return print("'Compare Models' unsuccessful. Please provide credentials with set_credentials().")
+
+    post_dict = {"y_pred": [],
+                 "return_eval": "False",
+                 "return_y": "False",
+                 "inspect_model": "False",
+                 "version": "None",
+                 "compare_models": "True",
+                 "version_list": version_list,
+                 "verbose": verbose,
+                 "naming_convention": naming_convention,
+                 "submission_type": submission_type}
+
+    headers = {'Content-Type': 'application/json', 'authorizationToken': os.environ.get("AWS_TOKEN")}
+
+    apiurl_eval = apiurl[:-1] + "eval"
+
+    compare_json = requests.post(apiurl_eval, headers=headers, data=json.dumps(post_dict))
+
+    compare_dict = json.loads(compare_json.text)
+
+    comp_dict_out = {i: pd.DataFrame(json.loads(compare_dict[i])) for i in compare_dict}
+
+    return comp_dict_out
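# Reviewer sketch (not part of the diff): compare_models returns a dict of
# {model_type: DataFrame}; pairing it with stylize_model_comparison renders the
# comparison tables in a notebook. The apiurl is a placeholder and credentials
# are assumed to be set.
apiurl = "https://example.execute-api.us-east-1.amazonaws.com/prod/m"  # placeholder
comp = compare_models(apiurl, version_list=[1, 2, 3], naming_convention="keras")
stylize_model_comparison(comp, naming_convention="keras")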
+
+
+def _get_onnx_from_string(onnx_string):
+    # generate tempfile for onnx object
+    temp_dir = tempfile.gettempdir()
+    temp_path = os.path.join(temp_dir, 'temp_file_name')
+
+    # save onnx to temporary path
+    with open(temp_path, "wb") as f:
+        f.write(onnx_string)
+
+    # load onnx file from temporary path
+    onx = onnx.load(temp_path)
+    return onx
+
+
+def _get_onnx_from_bucket(apiurl, aws_client, version=None):
+
+    # generate name of onnx model in bucket
+    onnx_model_name = "/onnx_model_v{version}.onnx".format(version=version)
+
+    # Get bucket and model_id for user
+    response, error = run_function_on_lambda(
+        apiurl, **{"delete": "FALSE", "versionupdateget": "TRUE"}
+    )
+    if error is not None:
+        raise error
+
+    _, bucket, model_id = json.loads(response.content.decode("utf-8"))
+
+    try:
+        onnx_string = aws_client["client"].get_object(
+            Bucket=bucket, Key=model_id + onnx_model_name
+        )
+
+        onnx_string = onnx_string['Body'].read()
+
+    except Exception as err:
+        raise err
+
+    onx = _get_onnx_from_string(onnx_string)
+
+    return onx
+
+
+def instantiate_model(apiurl, version=None, trained=False, reproduce=False, submission_type="competition"):
+    # Confirm that creds are loaded, print warning if not
+    if all(["username" in os.environ,
+            "password" in os.environ]):
+        pass
+    else:
+        return print("'Instantiate Model' unsuccessful. Please provide credentials with set_credentials().")
+
+    post_dict = {
+        "y_pred": [],
+        "return_eval": "False",
+        "return_y": "False",
+        "inspect_model": "False",
+        "version": "None",
+        "compare_models": "False",
+        "version_list": "None",
+        "get_leaderboard": "False",
+        "instantiate_model": "True",
+        "reproduce": reproduce,
+        "trained": trained,
+        "model_version": version,
+        "submission_type": submission_type
+    }
+
+    headers = {'Content-Type': 'application/json', 'authorizationToken': os.environ.get("AWS_TOKEN")}
+
+    apiurl_eval = apiurl[:-1] + "eval"
+
+    resp = requests.post(apiurl_eval, headers=headers, data=json.dumps(post_dict))
+
+    # Check the response status from Lambda before parsing it.
+    try:
+        resp.raise_for_status()
+    except requests.exceptions.HTTPError:
+        raise Exception(f"Error: Received {resp.status_code} from AWS. Please check if the model version is correct.")
+
+    resp_dict = json.loads(resp.text)
+
+    if resp_dict['model_metadata'] is None:
+        print("Model for this version doesn't exist or was not submitted by the author")
+        return None
+
+    if reproduce:
+        if resp_dict['reproducibility_env'] is not None:
+            set_reproducibility_env(resp_dict['reproducibility_env'])
+            print("Your reproducibility environment is successfully set up")
+        else:
+            print("Reproducibility environment was not found")
+
+    print("Instantiating the model from metadata...")
+
+    model_metadata = resp_dict['model_metadata']
+    model_weight_url = resp_dict['model_weight_url']
+    model_config = ast.literal_eval(model_metadata['model_config'])
+    ml_framework = model_metadata['ml_framework']
+
+    if ml_framework == 'sklearn':
+        if not trained or reproduce:
+            model_type = model_metadata['model_type']
+            model_class = model_from_string(model_type)
+            model = model_class(**model_config)
+
+        elif trained:
+            model_pkl = None
+            temp = tempfile.mkdtemp()
+            temp_path = temp + "/" + "onnx_model_v{}.onnx".format(version)
+
+            # Download the submitted onnx model and unpack the pickled weights
+            status = wget.download(model_weight_url, out=temp_path)
+            onnx_model = onnx.load(temp_path)
+            model_pkl = _get_metadata(onnx_model)['model_weights']
+
+            temp_dir = tempfile.gettempdir()
+            temp_path = os.path.join(temp_dir, 'temp_file_name')
+
+            with open(temp_path, "wb") as f:
+                f.write(model_pkl)
+
+            with open(temp_path, 'rb') as f:
+                model = pickle.load(f)
+
+    if ml_framework == 'pyspark':
+        try:
+            if pyspark is None:
+                raise Exception("Error: Please install pyspark to enable pyspark features")
+        except:
+            raise Exception("Error: Please install pyspark to enable pyspark features")
+
+        if not trained or reproduce:
+            print("Pyspark models can only be instantiated in trained mode.")
+            print("Please rerun the function with the proper parameters.")
+            return None
+
+        # A pyspark model object is always trained: the unfitted / untrained
+        # object is the Estimator, which cannot be treated as a model. The
+        # model is a Transformer, created by the Estimator.
+        model_pkl = None
+        temp = tempfile.mkdtemp()
+        temp_path = temp + "/" + "onnx_model_v{}.onnx".format(version)
+
+        # Download the submitted onnx model and unpack the zipped pyspark model
+        status = wget.download(model_weight_url, out=temp_path)
+        onnx_model = onnx.load(temp_path)
+        model_pkl = _get_metadata(onnx_model)['model_weights']
+
+        temp_dir = tempfile.gettempdir()
+        temp_path_zip = os.path.join(temp_dir, 'temp_pyspark_model.zip')
+        temp_path = os.path.join(temp_dir, 'temp_pyspark_model')
+
+        if not os.path.exists(temp_path):
+            os.mkdir(temp_path)
+
+        with open(temp_path_zip, "wb") as f:
+            f.write(model_pkl)
+
+        with ZipFile(temp_path_zip, 'r') as zip_file:
+            zip_file.extractall(temp_path)
+
+        model_type = model_metadata['model_type']
+        model_class = pyspark_model_from_string(model_type)
+        # model_config is for the Estimator, not the Transformer / Model
+        model = model_class()
+
+        # Need a spark session and context to instantiate the model object
+        spark = SparkSession \
+            .builder \
+            .appName('Pyspark Model') \
+            .getOrCreate()
+
+        model = model.load(temp_path)
+
+        # clean up temp directory files for future runs
+        try:
+            shutil.rmtree(temp_path)
+            os.remove(temp_path_zip)
+        except:
+            pass
+
+    if ml_framework == 'keras':
+        if not trained or reproduce:
+            model = tf.keras.Sequential.from_config(model_config)
+
+        elif trained:
+            model_weights = None
+            temp = tempfile.mkdtemp()
+            temp_path = temp + "/" + "onnx_model_v{}.onnx".format(version)
+
+            # Download the submitted onnx model and unpack the pickled weights
+            status = wget.download(model_weight_url, out=temp_path)
+            onnx_model = onnx.load(temp_path)
+            model_weights = pickle.loads(_get_metadata(onnx_model)['model_weights'])
+
+            model = tf.keras.Sequential.from_config(model_config)
+
+            model.set_weights(model_weights)
+
+    print("Your model is successfully instantiated.")
+    return model
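# Reviewer sketch (not part of the diff): the three main call patterns for
# instantiate_model, with a placeholder apiurl. trained=True downloads the
# submitted ONNX file and restores fitted weights; reproduce=True also restores
# the saved reproducibility environment when one was submitted.
apiurl = "https://example.execute-api.us-east-1.amazonaws.com/prod/m"  # placeholder
untrained = instantiate_model(apiurl, version=2)                  # fresh model from config
fitted = instantiate_model(apiurl, version=2, trained=True)       # with submitted weights
replica = instantiate_model(apiurl, version=2, reproduce=True)    # env + untrained model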
+
+
+def _get_layer_names():
+
+    try:
+        activation_list = [i for i in dir(tf.keras.activations)]
+        activation_list = [i for i in activation_list if callable(getattr(tf.keras.activations, i))]
+        activation_list = [i for i in activation_list if not i.startswith("_")]
+        activation_list.remove('deserialize')
+        activation_list.remove('get')
+        activation_list.remove('linear')
+        activation_list = activation_list + ['Activation', 'ReLU', 'Softmax', 'LeakyReLU', 'PReLU', 'ELU', 'ThresholdedReLU']
+
+        layer_list = [i for i in dir(tf.keras.layers) if callable(getattr(tf.keras.layers, i))]
+        layer_list = [i for i in layer_list if not i.startswith("_")]
+        layer_list = [i for i in layer_list if re.match('^[A-Z]', i)]
+        layer_list = [i for i in layer_list if i.lower() not in [i.lower() for i in activation_list]]
+    except:
+        return [], []
+
+    return layer_list, activation_list
+
+
+def _get_layer_names_pytorch():
+
+    activation_list = ['ELU', 'Hardshrink', 'Hardsigmoid', 'Hardtanh', 'Hardswish', 'LeakyReLU', 'LogSigmoid',
+                       'MultiheadAttention', 'PReLU', 'ReLU', 'ReLU6', 'RReLU', 'SELU', 'CELU', 'GELU', 'Sigmoid',
+                       'SiLU', 'Mish', 'Softplus', 'Softshrink', 'Softsign', 'Tanh', 'Tanhshrink', 'Threshold',
+                       'GLU', 'Softmin', 'Softmax', 'Softmax2d', 'LogSoftmax', 'AdaptiveLogSoftmaxWithLoss']
+    try:
+        import torch
+        layer_list = [i for i in dir(torch.nn) if callable(getattr(torch.nn, i))]
+        layer_list = [i for i in layer_list if i not in activation_list and 'Loss' not in i]
+    except:
+        layer_list = ["no layers found"]
+
+    return layer_list, activation_list
+
+
+def _get_sklearn_modules():
+
+    try:
+        import sklearn
+
+        sklearn_modules = ['ensemble', 'gaussian_process', 'isotonic',
+                           'linear_model', 'mixture', 'multiclass', 'naive_bayes',
+                           'neighbors', 'neural_network', 'svm', 'tree',
+                           'discriminant_analysis', 'calibration']
+
+        models_modules_dict = {}
+
+        for i in sklearn_modules:
+            models_list = [j for j in dir(eval('sklearn.' + i)) if callable(getattr(eval('sklearn.' + i), j))]
+            models_list = [j for j in models_list if re.match('^[A-Z]', j)]
+
+            for k in models_list:
+                models_modules_dict[k] = 'sklearn.' + i
+    except:
+        models_modules_dict = {}
+
+    return models_modules_dict
+
+
+def model_from_string(model_type):
+    models_modules_dict = _get_sklearn_modules()
+    try:
+        module = models_modules_dict[model_type]
+        model_class = getattr(importlib.import_module(module), model_type)
+        return model_class
+    except KeyError:
+        # Return a placeholder class if estimator not found
+        import warnings
+        warnings.warn(f"Model type '{model_type}' not found in sklearn modules. Returning placeholder class.")
+
+        # Create a minimal placeholder class that can be instantiated
+        class PlaceholderModel:
+            def __init__(self, **kwargs):
+                self._model_type = model_type
+                self._params = kwargs
+
+            def get_params(self, deep=True):
+                return self._params
+
+            def __str__(self):
+                return f"PlaceholderModel({self._model_type})"
+
+            def __repr__(self):
+                return f"PlaceholderModel({self._model_type})"
+
+        return PlaceholderModel
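# Reviewer sketch (not part of the diff): model_from_string resolves an estimator
# class by name across common sklearn submodules; names that fail to resolve
# (or a lookup table that could not be built) fall back to the PlaceholderModel
# defined above instead of raising.
cls = model_from_string("RandomForestClassifier")
clf = cls(n_estimators=100)   # sklearn.ensemble.RandomForestClassifier when resolvable
fallback = model_from_string("NoSuchEstimator")()   # PlaceholderModel instance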
+
+
+def _get_pyspark_modules():
+    try:
+        if pyspark is None:
+            raise Exception("Error: Please install pyspark to enable pyspark features")
+    except:
+        raise Exception("Error: Please install pyspark to enable pyspark features")
+
+    pyspark_modules = ['ml', 'ml.feature', 'ml.classification', 'ml.clustering', 'ml.regression']
+
+    models_modules_dict = {}
+
+    for i in pyspark_modules:
+        models_list = [j for j in dir(eval('pyspark.' + i)) if callable(getattr(eval('pyspark.' + i), j))]
+        models_list = [j for j in models_list if re.match('^[A-Z]', j)]
+
+        for k in models_list:
+            models_modules_dict[k] = 'pyspark.' + i
+
+    return models_modules_dict
+
+
+def pyspark_model_from_string(model_type):
+    try:
+        if pyspark is None:
+            raise Exception("Error: Please install pyspark to enable pyspark features")
+    except:
+        raise Exception("Error: Please install pyspark to enable pyspark features")
+
+    models_modules_dict = _get_pyspark_modules()
+    module = models_modules_dict[model_type]
+    model_class = getattr(importlib.import_module(module), model_type)
+    return model_class
+
+
+def print_y_stats(y_stats):
+
+    print("y_test example: ", y_stats['ytest_example'])
+    print("y_test class labels: ", y_stats['class_labels'])
+    print("y_test class balance: ", y_stats['class_balance'])
+    print("y_test label dtypes: ", y_stats['label_dtypes'])
+
+
+def inspect_y_test(apiurl, submission_type):
+
+    # Confirm that creds are loaded, print warning if not
+    if all(["username" in os.environ,
+            "password" in os.environ]):
+        pass
+    else:
+        return print("'Inspect y_test' unsuccessful. Please provide credentials with set_credentials().")
+
+    post_dict = {"y_pred": [],
+                 "return_eval": "False",
+                 "return_y": "True",
+                 "submission_type": submission_type}
+
+    headers = {'Content-Type': 'application/json', 'authorizationToken': os.environ.get("AWS_TOKEN")}
+
+    apiurl_eval = apiurl[:-1] + "eval"
+
+    y_stats = requests.post(apiurl_eval, headers=headers, data=json.dumps(post_dict))
+
+    y_stats_dict = json.loads(y_stats.text)
+
+    # print_y_stats(y_stats_dict)
+
+    return y_stats_dict
+
+
+def model_summary_keras(model):
+
+    # extract model architecture metadata
+    layer_names = []
+    layer_types = []
+    layer_n_params = []
+    layer_shapes = []
+    layer_connect = []
+    activations = []
+
+    keras_layers = keras_unpack(model)
+
+    for i in keras_layers:
+
+        try:
+            layer_names.append(i.name)
+        except:
+            layer_names.append(None)
+
+        try:
+            layer_types.append(i.__class__.__name__)
+        except:
+            layer_types.append(None)
+
+        try:
+            layer_shapes.append(i.output_shape)
+        except:
+            layer_shapes.append(None)
+
+        try:
+            layer_n_params.append(i.count_params())
+        except:
+            layer_n_params.append(None)
+
+        try:
+            if isinstance(i.inbound_nodes[0].inbound_layers, list):
+                layer_connect.append([x.name for x in i.inbound_nodes[0].inbound_layers])
+            else:
+                layer_connect.append(i.inbound_nodes[0].inbound_layers.name)
+        except:
+            layer_connect.append(None)
+
+        try:
+            activations.append(i.activation.__name__)
+        except:
+            activations.append(None)
+
+    model_summary = pd.DataFrame({"Name": layer_names,
+                                  "Layer": layer_types,
+                                  "Shape": layer_shapes,
+                                  "Params": layer_n_params,
+                                  "Connect": layer_connect,
+                                  "Activation": activations})
+
+    return model_summary
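# Reviewer sketch (not part of the diff): quick self-contained check of
# model_summary_keras on a toy Sequential model (assumes TensorFlow is
# installed).
import tensorflow as tf

toy = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(8, activation="relu"),
    tf.keras.layers.Dense(2, activation="softmax"),
])
print(model_summary_keras(toy))   # Name / Layer / Shape / Params / Connect / Activation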
+
+
+def model_graph_keras(model):
+
+    # extract model architecture metadata
+    layer_names = []
+    layer_types = []
+    layer_n_params = []
+    layer_shapes = []
+    layer_connect = []
+    activations = []
+
+    graph_nodes = []
+    graph_edges = []
+
+    for i in model.layers:
+
+        try:
+            layer_name = i.name
+        except:
+            layer_name = None
+        finally:
+            layer_names.append(layer_name)
+
+        try:
+            layer_type = i.__class__.__name__
+        except:
+            layer_type = None
+        finally:
+            layer_types.append(layer_type)
+
+        try:
+            layer_shape = i.output_shape
+        except:
+            layer_shape = None
+        finally:
+            layer_shapes.append(layer_shape)
+
+        try:
+            layer_params = i.count_params()
+        except:
+            layer_params = None
+        finally:
+            layer_n_params.append(layer_params)
+
+        try:
+            if isinstance(i.inbound_nodes[0].inbound_layers, list):
+                layer_input = [x.name for x in i.inbound_nodes[0].inbound_layers]
+            else:
+                layer_input = i.inbound_nodes[0].inbound_layers.name
+        except:
+            layer_input = None
+        finally:
+            layer_connect.append(layer_input)
+
+        try:
+            activation = i.activation.__name__
+        except:
+            activation = None
+        finally:
+            activations.append(activation)
+
+        layer_color = color_pal_assign(layer_type, naming_convention="keras")
+        layer_color = layer_color.split(' ')[-1]
+
+        graph_nodes.append((layer_name, {"label": layer_type + '\n' + str(layer_shape),
+                                         "URL": "https://keras.io/search.html?query=" + layer_type.lower(),
+                                         "color": layer_color,
+                                         "style": "bold",
+                                         "Name": layer_name,
+                                         "Layer": layer_type,
+                                         "Shape": layer_shape,
+                                         "Params": layer_params,
+                                         "Activation": activation}))
+
+        if isinstance(layer_input, list):
+            for src in layer_input:
+                graph_edges.append((src, layer_name))
+        else:
+            graph_edges.append((layer_input, layer_name))
+
+    G = nx.DiGraph()
+    G.add_nodes_from(graph_nodes)
+    G.add_edges_from(graph_edges)
+
+    G_pydot = nx.drawing.nx_pydot.to_pydot(G)
+
+    return G_pydot
+
+
+def plot_keras(model):
+
+    G = model_graph_keras(model)
+
+    display(SVG(G.create_svg()))
+
+
+def torch_unpack(model):
+
+    layers = []
+    keys = []
+
+    for key, module in model._modules.items():
+
+        if len(module._modules):
+
+            layers_out, keys_out = torch_unpack(module)
+
+            layers = layers + layers_out
+            keys = keys + keys_out
+
+        else:
+
+            layers.append(module)
+            keys.append(key)
+
+    return layers, keys
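# Reviewer sketch (not part of the diff): torch_unpack flattens nested
# containers into a list of leaf modules plus their keys (assumes torch is
# installed).
import torch

net = torch.nn.Sequential(
    torch.nn.Linear(4, 8),
    torch.nn.Sequential(torch.nn.ReLU(), torch.nn.Linear(8, 2)),
)
layers, keys = torch_unpack(net)
print([m._get_name() for m in layers])   # ['Linear', 'ReLU', 'Linear']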
+
+
+def keras_unpack(model):
+    layers = []
+    try:
+        for module in model.layers:
+            if isinstance(module, (tf.keras.Model, tf.keras.Sequential)):
+                layers += keras_unpack(module)
+            else:
+                layers.append(module)
+    except:
+        pass
+    return layers
+
+
+def torch_metadata(model):
+
+    name_list_out = []
+    layer_list = []
+    param_list = []
+    weight_list = []
+    activation_list = []
+
+    try:
+        layers, name_list = torch_unpack(model)
+
+        layer_names, activation_names = _get_layer_names_pytorch()
+
+        for module, name in zip(layers, name_list):
+
+            module_name = module._get_name()
+
+            if module_name in layer_names:
+
+                name_list_out.append(name)
+
+                layer_list.append(module_name)
+
+                params = sum([np.prod(p.size()) for p in module.parameters()])
+                param_list.append(params)
+
+                weights = tuple([tuple(p.size()) for p in module.parameters()])
+                weight_list.append(weights)
+
+            if module_name in activation_names:
+
+                activation_list.append(module_name)
+    except:
+        pass
+
+    return name_list_out, layer_list, param_list, weight_list, activation_list
+
+
+def layer_mapping(direction='torch_to_keras', activation=False):
+
+    torch_keras = {'AdaptiveAvgPool1d': 'AvgPool1D',
+                   'AdaptiveAvgPool2d': 'AvgPool2D',
+                   'AdaptiveAvgPool3d': 'AvgPool3D',
+                   'AdaptiveMaxPool1d': 'MaxPool1D',
+                   'AdaptiveMaxPool2d': 'MaxPool2D',
+                   'AdaptiveMaxPool3d': 'MaxPool3D',
+                   'AlphaDropout': None,
+                   'AvgPool1d': 'AvgPool1D',
+                   'AvgPool2d': 'AvgPool2D',
+                   'AvgPool3d': 'AvgPool3D',
+                   'BatchNorm1d': 'BatchNormalization',
+                   'BatchNorm2d': 'BatchNormalization',
+                   'BatchNorm3d': 'BatchNormalization',
+                   'Bilinear': None,
+                   'ConstantPad1d': None,
+                   'ConstantPad2d': None,
+                   'ConstantPad3d': None,
+                   'Container': None,
+                   'Conv1d': 'Conv1D',
+                   'Conv2d': 'Conv2D',
+                   'Conv3d': 'Conv3D',
+                   'ConvTranspose1d': 'Conv1DTranspose',
+                   'ConvTranspose2d': 'Conv2DTranspose',
+                   'ConvTranspose3d': 'Conv3DTranspose',
+                   'CosineSimilarity': None,
+                   'CrossMapLRN2d': None,
+                   'DataParallel': None,
+                   'Dropout': 'Dropout',
+                   'Dropout2d': 'Dropout',
+                   'Dropout3d': 'Dropout',
+                   'Embedding': 'Embedding',
+                   'EmbeddingBag': 'Embedding',
+                   'FeatureAlphaDropout': None,
+                   'Flatten': 'Flatten',
+                   'Fold': None,
+                   'FractionalMaxPool2d': 'MaxPool2D',
+                   'FractionalMaxPool3d': 'MaxPool3D',
+                   'GRU': 'GRU',
+                   'GRUCell': 'GRUCell',
+                   'GroupNorm': None,
+                   'Identity': None,
+                   'InstanceNorm1d': None,
+                   'InstanceNorm2d': None,
+                   'InstanceNorm3d': None,
+                   'LPPool1d': None,
+                   'LPPool2d': None,
+                   'LSTM': 'LSTM',
+                   'LSTMCell': 'LSTMCell',
+                   'LayerNorm': None,
+                   'Linear': 'Dense',
+                   'LocalResponseNorm': None,
+                   'MaxPool1d': 'MaxPool1D',
+                   'MaxPool2d': 'MaxPool2D',
+                   'MaxPool3d': 'MaxPool3D',
+                   'MaxUnpool1d': None,
+                   'MaxUnpool2d': None,
+                   'MaxUnpool3d': None,
+                   'Module': None,
+                   'ModuleDict': None,
+                   'ModuleList': None,
+                   'PairwiseDistance': None,
+                   'Parameter': None,
+                   'ParameterDict': None,
+                   'ParameterList': None,
+                   'PixelShuffle': None,
+                   'RNN': 'RNN',
+                   'RNNBase': None,
+                   'RNNCell': None,
+                   'RNNCellBase': None,
+                   'ReflectionPad1d': None,
+                   'ReflectionPad2d': None,
+                   'ReplicationPad1d': None,
+                   'ReplicationPad2d': None,
+                   'ReplicationPad3d': None,
+                   'Sequential': None,
+                   'SyncBatchNorm': None,
+                   'Transformer': None,
+                   'TransformerDecoder': None,
+                   'TransformerDecoderLayer': None,
+                   'TransformerEncoder': None,
+                   'TransformerEncoderLayer': None,
+                   'Unfold': None,
+                   'Upsample': 'UpSampling1D',
+                   'UpsamplingBilinear2d': 'UpSampling2D',
+                   'UpsamplingNearest2d': 'UpSampling2D',
+                   'ZeroPad2d': 'ZeroPadding2D'}
+
+    keras_torch = {'AbstractRNNCell': None,
+                   'Activation': None,
+                   'ActivityRegularization': None,
+                   'Add': None,
+                   'AdditiveAttention': None,
+                   'AlphaDropout': None,
+                   'Attention': None,
+                   'Average': None,
+                   'AveragePooling1D': 'AvgPool1d',
+                   'AveragePooling2D': 'AvgPool2d',
+                   'AveragePooling3D': 'AvgPool3d',
+                   'AvgPool1D': 'AvgPool1d',
+                   'AvgPool2D': 'AvgPool2d',
+                   'AvgPool3D': 'AvgPool3d',
+                   'BatchNormalization': None,
+                   'Bidirectional': None,
+                   'Concatenate': None,
+                   'Conv1D': 'Conv1d',
+                   'Conv1DTranspose': 'ConvTranspose1d',
+                   'Conv2D': 'Conv2d',
+                   'Conv2DTranspose': 'ConvTranspose2d',
+                   'Conv3D': 'Conv3d',
+                   'Conv3DTranspose': 'ConvTranspose3d',
+                   'ConvLSTM2D': None,
+                   'Convolution1D': None,
+                   'Convolution1DTranspose': None,
+                   'Convolution2D': None,
+                   'Convolution2DTranspose': None,
+                   'Convolution3D': None,
+                   'Convolution3DTranspose': None,
+                   'Cropping1D': None,
+                   'Cropping2D': None,
+                   'Cropping3D': None,
+                   'Dense': 'Linear',
+                   'DenseFeatures': None,
+                   'DepthwiseConv2D': None,
+                   'Dot': None,
+                   'Dropout': 'Dropout',
+                   'Embedding': 'Embedding',
+                   'Flatten': 'Flatten',
+                   'GRU': 'GRU',
+                   'GRUCell': 'GRUCell',
+                   'GaussianDropout': None,
+                   'GaussianNoise': None,
+                   'GlobalAveragePooling1D': None,
+                   'GlobalAveragePooling2D': None,
+                   'GlobalAveragePooling3D': None,
+                   'GlobalAvgPool1D': None,
+                   'GlobalAvgPool2D': None,
+                   'GlobalAvgPool3D': None,
+                   'GlobalMaxPool1D': None,
+                   'GlobalMaxPool2D': None,
+                   'GlobalMaxPool3D': None,
+                   'GlobalMaxPooling1D': None,
+                   'GlobalMaxPooling2D': None,
+                   'GlobalMaxPooling3D': None,
+                   'Input': None,
+                   'InputLayer': None,
+                   'InputSpec': None,
+                   'LSTM': 'LSTM',
+                   'LSTMCell': 'LSTMCell',
+                   'Lambda': None,
+                   'Layer': None,
+                   'LayerNormalization': None,
+                   'LocallyConnected1D': None,
+                   'LocallyConnected2D': None,
+                   'Masking': None,
+                   'MaxPool1D': 'MaxPool1d',
+                   'MaxPool2D': 'MaxPool2d',
+                   'MaxPool3D': 'MaxPool3d',
+                   'MaxPooling1D': 'MaxPool1d',
+                   'MaxPooling2D': 'MaxPool2d',
+                   'MaxPooling3D': 'MaxPool3d',
+                   'Maximum': None,
+                   'Minimum': None,
+                   'MultiHeadAttention': None,
+                   'Multiply': None,
+                   'Permute': None,
+                   'RNN': 'RNN',
+                   'RepeatVector': None,
+                   'Reshape': None,
+                   'SeparableConv1D': None,
+                   'SeparableConv2D': None,
+                   'SeparableConvolution1D': None,
+                   'SeparableConvolution2D': None,
+                   'SimpleRNN': None,
+                   'SimpleRNNCell': None,
+                   'SpatialDropout1D': None,
+                   'SpatialDropout2D': None,
+                   'SpatialDropout3D': None,
+                   'StackedRNNCells': None,
+                   'Subtract': None,
+                   'TimeDistributed': None,
+                   'UpSampling1D': 'Upsample',
+                   'UpSampling2D': None,
+                   'UpSampling3D': None,
+                   'Wrapper': None,
+                   'ZeroPadding1D': None,
+                   'ZeroPadding2D': 'ZeroPad2d',
+                   'ZeroPadding3D': None}
+
+    torch_keras_act = {
+        'AdaptiveLogSoftmaxWithLoss': None,
+        'CELU': None,
+        'ELU': 'elu',
+        'GELU': 'gelu',
+        'GLU': None,
+        'Hardshrink': None,
+        'Hardsigmoid': 'hard_sigmoid',
+        'Hardswish': None,
+        'Hardtanh': None,
+        'LeakyReLU': 'LeakyReLU',
+        'LogSigmoid': None,
+        'LogSoftmax': None,
+        'Mish': None,
+        'MultiheadAttention': None,
+        'PReLU': 'PReLU',
+        'RReLU': None,
+        'ReLU': 'relu',
+        'ReLU6': 'relu',
+        'SELU': 'selu',
+        'SiLU': 'swish',
+        'Sigmoid': 'sigmoid',
+        'Softmax': 'softmax',
+        'Softmax2d': None,
+        'Softmin': None,
+        'Softplus': 'softplus',
+        'Softshrink': None,
+        'Softsign': 'softsign',
+        'Tanh': 'tanh',
+        'Tanhshrink': None,
+        'Threshold': None}
+
+    keras_torch_act = {
+        'ELU': 'ELU',
+        'LeakyReLU': 'LeakyReLU',
+        'PReLU': 'PReLU',
+        'ReLU': 'ReLU',
+        'Softmax': 'Softmax',
+        'ThresholdedReLU': None,
+        'elu': 'ELU',
+        'exponential': None,
+        'gelu': 'GELU',
+        'hard_sigmoid': 'Hardsigmoid',
+        'relu': 'ReLU',
+        'selu': 'SELU',
+        'serialize': None,
+        'sigmoid': 'Sigmoid',
+        'softmax': 'Softmax',
+        'softplus': 'Softplus',
+        'softsign': 'Softsign',
+        'swish': 'SiLU',
+        'tanh': 'Tanh'}
+
+    if direction == 'torch_to_keras' and activation:
+
+        return torch_keras_act
+
+    elif direction == 'keras_to_torch' and activation:
+
+        return keras_torch_act
+
+    elif direction == 'torch_to_keras':
+
+        return torch_keras
+
+    elif direction == 'keras_to_torch':
+
+        return keras_torch
+
+
+def rename_layers(in_layers, direction="torch_to_keras", activation=False):
+
+    mapping_dict = layer_mapping(direction=direction, activation=activation)
+
+    out_layers = []
+
+    for i in in_layers:
+
+        layer_name_temp = mapping_dict.get(i, None)
+
+        if layer_name_temp is None:
+            out_layers.append(i)
+        else:
+            out_layers.append(layer_name_temp)
+
+    return out_layers
+
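# Reviewer sketch (not part of the diff): rename_layers maps layer-type names
# between frameworks using the tables above; names without a mapping pass
# through unchanged.
keras_names = rename_layers(['Conv2d', 'Linear', 'Dropout'], direction='torch_to_keras')
# -> ['Conv2D', 'Dense', 'Dropout']
torch_acts = rename_layers(['relu', 'softmax'], direction='keras_to_torch', activation=True)
# -> ['ReLU', 'Softmax']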