aimodelshare 0.1.21__py3-none-any.whl → 0.1.62__py3-none-any.whl
Potentially problematic release: this version of aimodelshare has been flagged as possibly problematic.
- aimodelshare/__init__.py +94 -14
- aimodelshare/aimsonnx.py +417 -262
- aimodelshare/api.py +8 -7
- aimodelshare/auth.py +163 -0
- aimodelshare/aws.py +4 -4
- aimodelshare/base_image.py +1 -1
- aimodelshare/containerisation.py +1 -1
- aimodelshare/data_sharing/download_data.py +145 -88
- aimodelshare/generatemodelapi.py +7 -6
- aimodelshare/main/eval_lambda.txt +81 -13
- aimodelshare/model.py +493 -197
- aimodelshare/modeluser.py +89 -1
- aimodelshare/moral_compass/README.md +408 -0
- aimodelshare/moral_compass/__init__.py +37 -0
- aimodelshare/moral_compass/_version.py +3 -0
- aimodelshare/moral_compass/api_client.py +601 -0
- aimodelshare/moral_compass/apps/__init__.py +17 -0
- aimodelshare/moral_compass/apps/tutorial.py +198 -0
- aimodelshare/moral_compass/challenge.py +365 -0
- aimodelshare/moral_compass/config.py +187 -0
- aimodelshare/playground.py +26 -14
- aimodelshare/preprocessormodules.py +60 -6
- aimodelshare/reproducibility.py +20 -5
- aimodelshare/utils/__init__.py +78 -0
- aimodelshare/utils/optional_deps.py +38 -0
- aimodelshare-0.1.62.dist-info/METADATA +298 -0
- {aimodelshare-0.1.21.dist-info → aimodelshare-0.1.62.dist-info}/RECORD +30 -22
- {aimodelshare-0.1.21.dist-info → aimodelshare-0.1.62.dist-info}/WHEEL +1 -1
- aimodelshare-0.1.62.dist-info/licenses/LICENSE +5 -0
- {aimodelshare-0.1.21.dist-info → aimodelshare-0.1.62.dist-info}/top_level.txt +0 -1
- aimodelshare-0.1.21.dist-info/LICENSE +0 -22
- aimodelshare-0.1.21.dist-info/METADATA +0 -68
- tests/__init__.py +0 -0
- tests/test_aimsonnx.py +0 -135
- tests/test_playground.py +0 -721
aimodelshare/moral_compass/config.py ADDED
@@ -0,0 +1,187 @@
+"""
+Configuration module for moral_compass API client.
+
+Provides API base URL discovery via:
+1. Environment variable MORAL_COMPASS_API_BASE_URL or AIMODELSHARE_API_BASE_URL
+2. Cached terraform outputs file (infra/terraform_outputs.json)
+3. Terraform command execution (fallback)
+
+Also provides AWS region discovery for region-aware table naming.
+"""
+
+import os
+import json
+import logging
+import subprocess
+from pathlib import Path
+from typing import Optional
+
+logger = logging.getLogger("aimodelshare.moral_compass")
+
+
+def get_aws_region() -> Optional[str]:
+    """
+    Discover AWS region from multiple sources.
+
+    Resolution order:
+    1. AWS_REGION environment variable
+    2. AWS_DEFAULT_REGION environment variable
+    3. Cached terraform outputs file
+    4. None (caller should handle default)
+
+    Returns:
+        Optional[str]: AWS region name or None
+    """
+    # Strategy 1: Check environment variables
+    region = os.getenv("AWS_REGION") or os.getenv("AWS_DEFAULT_REGION")
+    if region:
+        logger.debug(f"Using AWS region from environment: {region}")
+        return region
+
+    # Strategy 2: Try cached terraform outputs
+    cached_region = _get_region_from_cached_outputs()
+    if cached_region:
+        logger.debug(f"Using AWS region from cached terraform outputs: {cached_region}")
+        return cached_region
+
+    # No region found - return None and let caller decide default
+    logger.debug("AWS region not found, caller should use default")
+    return None
+
+
+def get_api_base_url() -> str:
+    """
+    Discover API base URL using multiple strategies in order:
+    1. Environment variables (MORAL_COMPASS_API_BASE_URL or AIMODELSHARE_API_BASE_URL)
+    2. Cached terraform outputs file
+    3. Terraform command execution
+
+    Returns:
+        str: The API base URL
+
+    Raises:
+        RuntimeError: If API base URL cannot be determined
+    """
+    # Strategy 1: Check environment variables
+    env_url = os.getenv("MORAL_COMPASS_API_BASE_URL") or os.getenv("AIMODELSHARE_API_BASE_URL")
+    if env_url:
+        logger.debug(f"Using API base URL from environment: {env_url}")
+        return env_url.rstrip("/")
+
+    # Strategy 2: Try cached terraform outputs
+    cached_url = _get_url_from_cached_outputs()
+    if cached_url:
+        logger.debug(f"Using API base URL from cached terraform outputs: {cached_url}")
+        return cached_url
+
+    # Strategy 3: Try terraform command (last resort)
+    terraform_url = _get_url_from_terraform_command()
+    if terraform_url:
+        logger.debug(f"Using API base URL from terraform command: {terraform_url}")
+        return terraform_url
+
+    raise RuntimeError(
+        "Could not determine API base URL. Please set MORAL_COMPASS_API_BASE_URL "
+        "environment variable or ensure terraform outputs are accessible."
+    )
+
+
+def _get_url_from_cached_outputs() -> Optional[str]:
+    """
+    Read API base URL from cached terraform outputs file.
+
+    Returns:
+        Optional[str]: API base URL if found in cache, None otherwise
+    """
+    # Look for terraform_outputs.json in infra directory
+    repo_root = Path(__file__).parent.parent.parent.parent
+    outputs_file = repo_root / "infra" / "terraform_outputs.json"
+
+    if not outputs_file.exists():
+        logger.debug(f"Cached terraform outputs not found at {outputs_file}")
+        return None
+
+    try:
+        with open(outputs_file, "r") as f:
+            outputs = json.load(f)
+
+        # Handle both formats: {"api_base_url": {"value": "..."}} or {"api_base_url": "..."}
+        api_base_url = outputs.get("api_base_url")
+        if isinstance(api_base_url, dict):
+            url = api_base_url.get("value")
+        else:
+            url = api_base_url
+
+        if url and url != "null":
+            return url.rstrip("/")
+    except (json.JSONDecodeError, IOError) as e:
+        logger.warning(f"Error reading cached terraform outputs: {e}")
+
+    return None
+
+
+def _get_region_from_cached_outputs() -> Optional[str]:
+    """
+    Read AWS region from cached terraform outputs file.
+
+    Returns:
+        Optional[str]: AWS region if found in cache, None otherwise
+    """
+    # Look for terraform_outputs.json in infra directory
+    repo_root = Path(__file__).parent.parent.parent.parent
+    outputs_file = repo_root / "infra" / "terraform_outputs.json"
+
+    if not outputs_file.exists():
+        logger.debug(f"Cached terraform outputs not found at {outputs_file}")
+        return None
+
+    try:
+        with open(outputs_file, "r") as f:
+            outputs = json.load(f)
+
+        # Handle both formats: {"region": {"value": "..."}} or {"region": "..."}
+        region = outputs.get("region") or outputs.get("aws_region")
+        if isinstance(region, dict):
+            region_value = region.get("value")
+        else:
+            region_value = region
+
+        if region_value and region_value != "null":
+            return region_value
+    except (json.JSONDecodeError, IOError) as e:
+        logger.warning(f"Error reading cached terraform outputs: {e}")
+
+    return None
+
+
+def _get_url_from_terraform_command() -> Optional[str]:
+    """
+    Execute terraform command to get API base URL.
+
+    Returns:
+        Optional[str]: API base URL if terraform command succeeds, None otherwise
+    """
+    repo_root = Path(__file__).parent.parent.parent.parent
+    infra_dir = repo_root / "infra"
+
+    if not infra_dir.exists():
+        logger.debug(f"Infra directory not found at {infra_dir}")
+        return None
+
+    try:
+        result = subprocess.run(
+            ["terraform", "output", "-raw", "api_base_url"],
+            cwd=infra_dir,
+            capture_output=True,
+            text=True,
+            timeout=10
+        )
+
+        if result.returncode == 0:
+            url = result.stdout.strip()
+            if url and url != "null":
+                return url.rstrip("/")
+    except (subprocess.TimeoutExpired, FileNotFoundError, subprocess.SubprocessError) as e:
+        logger.debug(f"Terraform command failed: {e}")
+
+    return None
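For orientation, a minimal usage sketch of the new discovery helpers. The environment variable name comes from the code above; the endpoint URL is hypothetical:

    import os
    from aimodelshare.moral_compass.config import get_api_base_url, get_aws_region

    os.environ["MORAL_COMPASS_API_BASE_URL"] = "https://api.example.com/v1/"  # hypothetical endpoint

    print(get_api_base_url())  # "https://api.example.com/v1" (trailing slash stripped)
    print(get_aws_region())    # None unless AWS_REGION/AWS_DEFAULT_REGION or cached outputs exist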
aimodelshare/playground.py
CHANGED
@@ -1246,13 +1246,19 @@ class ModelPlayground:
         with HiddenPrints():
             competition = Competition(self.playground_url)
 
-
-
-
-
-
-
-
+            comp_result = competition.submit_model(model=model,
+                                                   prediction_submission=prediction_submission,
+                                                   preprocessor=preprocessor,
+                                                   reproducibility_env_filepath=reproducibility_env_filepath,
+                                                   custom_metadata=custom_metadata,
+                                                   input_dict=input_dict,
+                                                   print_output=False)
+
+            # Validate return structure before unpacking
+            if not isinstance(comp_result, tuple) or len(comp_result) != 2:
+                raise RuntimeError(f"Invalid return from competition.submit_model: expected (version, url) tuple, got {type(comp_result)}")
+
+            version_comp, model_page = comp_result
 
         print(f"Your model has been submitted to competition as model version {version_comp}.")
 
@@ -1260,13 +1266,19 @@ class ModelPlayground:
         with HiddenPrints():
             experiment = Experiment(self.playground_url)
 
-
-
-
-
-
-
-
+            exp_result = experiment.submit_model(model=model,
+                                                 prediction_submission=prediction_submission,
+                                                 preprocessor=preprocessor,
+                                                 reproducibility_env_filepath=reproducibility_env_filepath,
+                                                 custom_metadata=custom_metadata,
+                                                 input_dict=input_dict,
+                                                 print_output=False)
+
+            # Validate return structure before unpacking
+            if not isinstance(exp_result, tuple) or len(exp_result) != 2:
+                raise RuntimeError(f"Invalid return from experiment.submit_model: expected (version, url) tuple, got {type(exp_result)}")
+
+            version_exp, model_page = exp_result
 
         print(f"Your model has been submitted to experiment as model version {version_exp}.")
 
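Both hunks apply the same validate-before-unpack guard; a standalone sketch of the pattern, with names that are illustrative rather than taken from the diff:

    def safe_unpack(result):
        # A failed submission may return None or an error object; the guard
        # surfaces a clear RuntimeError instead of an opaque unpacking error.
        if not isinstance(result, tuple) or len(result) != 2:
            raise RuntimeError(f"expected (version, url) tuple, got {type(result)}")
        return result

    version, url = safe_unpack(("v3", "https://example.com/model"))  # ok
    # safe_unpack(None) raises RuntimeError rather than a TypeError on unpacking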
aimodelshare/preprocessormodules.py CHANGED
@@ -116,6 +116,26 @@ def import_preprocessor(filepath):
 
     return preprocessor
 
+def _test_object_serialization(obj, obj_name):
+    """
+    Test if an object can be serialized with pickle.
+
+    Args:
+        obj: Object to test
+        obj_name: Name of the object for error reporting
+
+    Returns:
+        tuple: (success: bool, error_msg: str or None)
+    """
+    import pickle
+
+    try:
+        pickle.dumps(obj)
+        return True, None
+    except Exception as e:
+        return False, f"{type(e).__name__}: {str(e)}"
+
+
 def export_preprocessor(preprocessor_fxn,directory, globs=globals()):
     """
     Exports preprocessor and related objects into zip file for model deployment
@@ -167,7 +187,7 @@ def export_preprocessor(preprocessor_fxn,directory, globs=globals()):
     function_objects=list(inspect.getclosurevars(preprocessor_fxn).globals.keys())
 
     import sys
-    import
+    import importlib.util
     modulenames = ["sklearn","keras","tensorflow","cv2","resize","pytorch","librosa","pyspark"]
 
     # List all standard libraries not covered by sys.builtin_module_names
@@ -185,9 +205,12 @@ def export_preprocessor(preprocessor_fxn,directory, globs=globals()):
                 modulenames.append(module_name)
                 continue
 
-
-
-
+            # Use importlib.util instead of deprecated imp
+            spec = importlib.util.find_spec(module_name)
+            if spec and spec.origin:
+                module_path = spec.origin
+                if os.path.dirname(module_path) in stdlib:
+                    modulenames.append(module_name)
         except Exception as e:
             # print(e)
             continue
@@ -232,12 +255,19 @@ def export_preprocessor(preprocessor_fxn,directory, globs=globals()):
 
     export_methods = []
     savedpreprocessorobjectslist = []
+    failed_objects = []  # Track failed serializations for better diagnostics
+
     for function_objects_nomodule in function_objects_nomodules:
         try:
             savedpreprocessorobjectslist.append(savetopickle(function_objects_nomodule))
             export_methods.append("pickle")
         except Exception as e:
-            #
+            # Track this failure for diagnostics
+            can_serialize, error_msg = _test_object_serialization(
+                globals().get(function_objects_nomodule),
+                function_objects_nomodule
+            )
+
            try:
                os.remove(os.path.join(temp_dir, function_objects_nomodule+".pkl"))
            except:
@@ -246,7 +276,14 @@ def export_preprocessor(preprocessor_fxn,directory, globs=globals()):
            try:
                savedpreprocessorobjectslist.append(save_to_zip(function_objects_nomodule))
                export_methods.append("zip")
-           except Exception as
+           except Exception as zip_e:
+               # Both pickle and zip failed - record this
+               failed_objects.append({
+                   'name': function_objects_nomodule,
+                   'type': type(globals().get(function_objects_nomodule, None)).__name__,
+                   'pickle_error': str(e),
+                   'zip_error': str(zip_e)
+               })
                # print(e)
                pass
 
@@ -265,6 +302,20 @@ def export_preprocessor(preprocessor_fxn,directory, globs=globals()):
     # close the Zip File
     zipObj.close()
 
+    # If any critical objects failed to serialize, raise an error with details
+    if failed_objects:
+        failed_names = [obj['name'] for obj in failed_objects]
+        error_details = "\n".join([
+            f"  - {obj['name']} (type: {obj['type']}): {obj['pickle_error'][:100]}"
+            for obj in failed_objects
+        ])
+        raise RuntimeError(
+            f"Preprocessor export encountered serialization failures for {len(failed_objects)} closure variable(s): "
+            f"{', '.join(failed_names)}.\n\nDetails:\n{error_details}\n\n"
+            f"These objects are referenced by your preprocessor function but cannot be serialized. "
+            f"Common causes include open file handles, database connections, or thread locks."
+        )
+
     try:
         # clean up temp directory files for future runs
         os.remove(os.path.join(temp_dir,"preprocessor.py"))
@@ -279,6 +330,9 @@ def export_preprocessor(preprocessor_fxn,directory, globs=globals()):
             pass
 
     except Exception as e:
+        # Re-raise RuntimeError with preserved message
+        if isinstance(e, RuntimeError):
+            raise
         print(e)
 
     return print("Your preprocessor is now saved to 'preprocessor.zip'")
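The value of the new failed_objects bookkeeping is easiest to see with a closure variable that pickle rejects; a sketch assuming a captured thread lock (illustrative only):

    import pickle
    import threading

    lock = threading.Lock()  # thread locks cannot be pickled

    def _test_object_serialization(obj, obj_name):
        # Same probe as the helper added above: attempt pickling, report the error
        try:
            pickle.dumps(obj)
            return True, None
        except Exception as e:
            return False, f"{type(e).__name__}: {str(e)}"

    ok, err = _test_object_serialization(lock, "lock")
    print(ok, err)  # False "TypeError: cannot pickle '_thread.lock' object"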
aimodelshare/reproducibility.py
CHANGED
@@ -3,11 +3,22 @@ import sys
 import json
 import random
 import tempfile
-import pkg_resources
 import requests
 
 import numpy as np
-
+
+# TensorFlow is optional - only needed for reproducibility setup with TF models
+try:
+    import tensorflow as tf
+    _TF_AVAILABLE = True
+except ImportError:
+    _TF_AVAILABLE = False
+    tf = None
+
+try:
+    import importlib.metadata as md
+except ImportError:  # pragma: no cover
+    import importlib_metadata as md
 
 from aimodelshare.aws import get_s3_iam_client, run_function_on_lambda, get_aws_client
 
@@ -44,9 +55,13 @@ def export_reproducibility_env(seed, directory, mode="gpu"):
     else:
         raise Exception("Error: unknown 'mode' value, expected 'gpu' or 'cpu'")
 
-
-    installed_packages_list =
-
+    # Get installed packages using importlib.metadata
+    installed_packages_list = []
+    for dist in md.distributions():
+        name = dist.metadata.get("Name") or "unknown"
+        version = dist.version
+        installed_packages_list.append(f"{name}=={version}")
+    installed_packages_list = sorted(installed_packages_list)
 
     data["session_runtime_info"] = {
         "installed_packages": installed_packages_list,
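The pkg_resources-to-importlib.metadata swap can be checked in isolation; a minimal sketch (output depends on the local environment):

    import importlib.metadata as md

    # Same enumeration export_reproducibility_env now performs
    packages = sorted(
        f"{dist.metadata.get('Name') or 'unknown'}=={dist.version}"
        for dist in md.distributions()
    )
    print(packages[:3])  # first few installed distributions, environment-dependent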
aimodelshare/utils/__init__.py ADDED
@@ -0,0 +1,78 @@
+"""Utility modules for aimodelshare."""
+import os
+import sys
+import shutil
+import tempfile
+import functools
+import warnings
+from typing import Type
+
+from .optional_deps import check_optional
+
+
+def delete_files_from_temp_dir(temp_dir_file_deletion_list):
+    temp_dir = tempfile.gettempdir()
+    for file_name in temp_dir_file_deletion_list:
+        file_path = os.path.join(temp_dir, file_name)
+        if os.path.exists(file_path):
+            os.remove(file_path)
+
+
+def delete_folder(folder_path):
+    if os.path.exists(folder_path):
+        shutil.rmtree(folder_path)
+
+
+def make_folder(folder_path):
+    os.makedirs(folder_path, exist_ok=True)
+
+
+class HiddenPrints:
+    """Context manager that suppresses stdout and stderr (used for silencing noisy outputs)."""
+    def __enter__(self):
+        self._original_stdout = sys.stdout
+        self._original_stderr = sys.stderr
+        self._devnull_stdout = open(os.devnull, 'w')
+        self._devnull_stderr = open(os.devnull, 'w')
+        sys.stdout = self._devnull_stdout
+        sys.stderr = self._devnull_stderr
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        sys.stdout = self._original_stdout
+        sys.stderr = self._original_stderr
+        self._devnull_stdout.close()
+        self._devnull_stderr.close()
+
+
+def ignore_warning(warning: Type[Warning]):
+    """
+    Ignore a given warning occurring during method execution.
+
+    Args:
+        warning (Warning): warning type to ignore.
+
+    Returns:
+        the inner function
+    """
+
+    def inner(func):
+        @functools.wraps(func)
+        def wrapper(*args, **kwargs):
+            with warnings.catch_warnings():
+                warnings.filterwarnings("ignore", category=warning)
+                return func(*args, **kwargs)
+
+        return wrapper
+
+    return inner
+
+
+__all__ = [
+    "check_optional",
+    "HiddenPrints",
+    "ignore_warning",
+    "delete_files_from_temp_dir",
+    "delete_folder",
+    "make_folder",
+]
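A quick sketch of the two main utilities in use, with behavior as defined above:

    import warnings
    from aimodelshare.utils import HiddenPrints, ignore_warning

    with HiddenPrints():
        print("suppressed")  # routed to os.devnull

    @ignore_warning(DeprecationWarning)
    def noisy():
        warnings.warn("old API", DeprecationWarning)
        return 42

    print(noisy())  # 42, with the DeprecationWarning silenced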
aimodelshare/utils/optional_deps.py ADDED
@@ -0,0 +1,38 @@
+"""Optional dependency checking utilities."""
+import os
+import importlib.util
+import warnings
+
+_DEF_SUPPRESS_ENV = "AIMODELSHARE_SUPPRESS_OPTIONAL_WARNINGS"
+
+
+def check_optional(name: str, feature_label: str, suppress_env: str = _DEF_SUPPRESS_ENV) -> bool:
+    """Check if an optional dependency is available.
+
+    Print a single warning (via warnings) if missing and suppression env var is not set.
+    Returns True if available, False otherwise.
+
+    Parameters
+    ----------
+    name : str
+        The name of the module to check (e.g., 'xgboost', 'pyspark')
+    feature_label : str
+        A human-readable label for the feature that requires this dependency
+    suppress_env : str, optional
+        Environment variable name to check for suppression (default: AIMODELSHARE_SUPPRESS_OPTIONAL_WARNINGS)
+
+    Returns
+    -------
+    bool
+        True if the module is available, False otherwise
+    """
+    spec = importlib.util.find_spec(name)
+    if spec is None:
+        if not os.environ.get(suppress_env):
+            warnings.warn(
+                f"{feature_label} support unavailable. Install `{name}` to enable.",
+                category=UserWarning,
+                stacklevel=2,
+            )
+        return False
+    return True
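Finally, a sketch of how check_optional gates an optional feature; 'xgboost' is the docstring's own example, and the feature label is illustrative:

    from aimodelshare.utils.optional_deps import check_optional

    if check_optional("xgboost", "Gradient boosting"):
        import xgboost  # safe: find_spec confirmed the module is importable
    else:
        # A UserWarning was emitted unless AIMODELSHARE_SUPPRESS_OPTIONAL_WARNINGS is set
        pass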