workbench-0.8.162-py3-none-any.whl → workbench-0.8.220-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
- workbench/algorithms/dataframe/__init__.py +1 -2
- workbench/algorithms/dataframe/compound_dataset_overlap.py +321 -0
- workbench/algorithms/dataframe/feature_space_proximity.py +168 -75
- workbench/algorithms/dataframe/fingerprint_proximity.py +422 -86
- workbench/algorithms/dataframe/projection_2d.py +44 -21
- workbench/algorithms/dataframe/proximity.py +259 -305
- workbench/algorithms/graph/light/proximity_graph.py +14 -12
- workbench/algorithms/models/cleanlab_model.py +382 -0
- workbench/algorithms/models/noise_model.py +388 -0
- workbench/algorithms/sql/outliers.py +3 -3
- workbench/api/__init__.py +5 -1
- workbench/api/compound.py +1 -1
- workbench/api/df_store.py +17 -108
- workbench/api/endpoint.py +18 -5
- workbench/api/feature_set.py +121 -15
- workbench/api/meta.py +5 -2
- workbench/api/meta_model.py +289 -0
- workbench/api/model.py +55 -21
- workbench/api/monitor.py +1 -16
- workbench/api/parameter_store.py +3 -52
- workbench/cached/cached_model.py +4 -4
- workbench/core/artifacts/__init__.py +11 -2
- workbench/core/artifacts/artifact.py +16 -8
- workbench/core/artifacts/data_capture_core.py +355 -0
- workbench/core/artifacts/df_store_core.py +114 -0
- workbench/core/artifacts/endpoint_core.py +382 -253
- workbench/core/artifacts/feature_set_core.py +249 -45
- workbench/core/artifacts/model_core.py +135 -80
- workbench/core/artifacts/monitor_core.py +33 -248
- workbench/core/artifacts/parameter_store_core.py +98 -0
- workbench/core/cloud_platform/aws/aws_account_clamp.py +50 -1
- workbench/core/cloud_platform/aws/aws_meta.py +12 -5
- workbench/core/cloud_platform/aws/aws_session.py +4 -4
- workbench/core/pipelines/pipeline_executor.py +1 -1
- workbench/core/transforms/data_to_features/light/molecular_descriptors.py +4 -4
- workbench/core/transforms/features_to_model/features_to_model.py +62 -40
- workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +76 -15
- workbench/core/transforms/pandas_transforms/pandas_to_features.py +38 -2
- workbench/core/views/training_view.py +113 -42
- workbench/core/views/view.py +53 -3
- workbench/core/views/view_utils.py +4 -4
- workbench/model_script_utils/model_script_utils.py +339 -0
- workbench/model_script_utils/pytorch_utils.py +405 -0
- workbench/model_script_utils/uq_harness.py +278 -0
- workbench/model_scripts/chemprop/chemprop.template +649 -0
- workbench/model_scripts/chemprop/generated_model_script.py +649 -0
- workbench/model_scripts/chemprop/model_script_utils.py +339 -0
- workbench/model_scripts/chemprop/requirements.txt +3 -0
- workbench/model_scripts/custom_models/chem_info/fingerprints.py +175 -0
- workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +483 -0
- workbench/model_scripts/custom_models/chem_info/mol_standardize.py +450 -0
- workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +7 -9
- workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -1
- workbench/model_scripts/custom_models/proximity/feature_space_proximity.py +194 -0
- workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +8 -10
- workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
- workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +20 -21
- workbench/model_scripts/custom_models/uq_models/feature_space_proximity.py +194 -0
- workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
- workbench/model_scripts/custom_models/uq_models/ngboost.template +30 -18
- workbench/model_scripts/custom_models/uq_models/requirements.txt +1 -3
- workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +15 -17
- workbench/model_scripts/meta_model/generated_model_script.py +209 -0
- workbench/model_scripts/meta_model/meta_model.template +209 -0
- workbench/model_scripts/pytorch_model/generated_model_script.py +444 -500
- workbench/model_scripts/pytorch_model/model_script_utils.py +339 -0
- workbench/model_scripts/pytorch_model/pytorch.template +440 -496
- workbench/model_scripts/pytorch_model/pytorch_utils.py +405 -0
- workbench/model_scripts/pytorch_model/requirements.txt +1 -1
- workbench/model_scripts/pytorch_model/uq_harness.py +278 -0
- workbench/model_scripts/scikit_learn/generated_model_script.py +7 -12
- workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
- workbench/model_scripts/script_generation.py +20 -11
- workbench/model_scripts/uq_models/generated_model_script.py +248 -0
- workbench/model_scripts/xgb_model/generated_model_script.py +372 -404
- workbench/model_scripts/xgb_model/model_script_utils.py +339 -0
- workbench/model_scripts/xgb_model/uq_harness.py +278 -0
- workbench/model_scripts/xgb_model/xgb_model.template +369 -401
- workbench/repl/workbench_shell.py +28 -19
- workbench/resources/open_source_api.key +1 -1
- workbench/scripts/endpoint_test.py +162 -0
- workbench/scripts/lambda_test.py +73 -0
- workbench/scripts/meta_model_sim.py +35 -0
- workbench/scripts/ml_pipeline_batch.py +137 -0
- workbench/scripts/ml_pipeline_sqs.py +186 -0
- workbench/scripts/monitor_cloud_watch.py +20 -100
- workbench/scripts/training_test.py +85 -0
- workbench/utils/aws_utils.py +4 -3
- workbench/utils/chem_utils/__init__.py +0 -0
- workbench/utils/chem_utils/fingerprints.py +175 -0
- workbench/utils/chem_utils/misc.py +194 -0
- workbench/utils/chem_utils/mol_descriptors.py +483 -0
- workbench/utils/chem_utils/mol_standardize.py +450 -0
- workbench/utils/chem_utils/mol_tagging.py +348 -0
- workbench/utils/chem_utils/projections.py +219 -0
- workbench/utils/chem_utils/salts.py +256 -0
- workbench/utils/chem_utils/sdf.py +292 -0
- workbench/utils/chem_utils/toxicity.py +250 -0
- workbench/utils/chem_utils/vis.py +253 -0
- workbench/utils/chemprop_utils.py +141 -0
- workbench/utils/cloudwatch_handler.py +1 -1
- workbench/utils/cloudwatch_utils.py +137 -0
- workbench/utils/config_manager.py +3 -7
- workbench/utils/endpoint_utils.py +5 -7
- workbench/utils/license_manager.py +2 -6
- workbench/utils/meta_model_simulator.py +499 -0
- workbench/utils/metrics_utils.py +256 -0
- workbench/utils/model_utils.py +278 -79
- workbench/utils/monitor_utils.py +44 -62
- workbench/utils/pandas_utils.py +3 -3
- workbench/utils/pytorch_utils.py +87 -0
- workbench/utils/shap_utils.py +11 -57
- workbench/utils/workbench_logging.py +0 -3
- workbench/utils/workbench_sqs.py +1 -1
- workbench/utils/xgboost_local_crossfold.py +267 -0
- workbench/utils/xgboost_model_utils.py +127 -219
- workbench/web_interface/components/model_plot.py +14 -2
- workbench/web_interface/components/plugin_unit_test.py +5 -2
- workbench/web_interface/components/plugins/dashboard_status.py +3 -1
- workbench/web_interface/components/plugins/generated_compounds.py +1 -1
- workbench/web_interface/components/plugins/model_details.py +38 -74
- workbench/web_interface/components/plugins/scatter_plot.py +6 -10
- {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/METADATA +31 -9
- {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/RECORD +128 -96
- workbench-0.8.220.dist-info/entry_points.txt +11 -0
- {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/licenses/LICENSE +1 -1
- workbench/core/cloud_platform/aws/aws_df_store.py +0 -404
- workbench/core/cloud_platform/aws/aws_parameter_store.py +0 -280
- workbench/model_scripts/custom_models/chem_info/local_utils.py +0 -769
- workbench/model_scripts/custom_models/chem_info/tautomerize.py +0 -83
- workbench/model_scripts/custom_models/meta_endpoints/example.py +0 -53
- workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
- workbench/model_scripts/custom_models/proximity/proximity.py +0 -384
- workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -393
- workbench/model_scripts/custom_models/uq_models/mapie_xgb.template +0 -203
- workbench/model_scripts/custom_models/uq_models/meta_uq.template +0 -273
- workbench/model_scripts/custom_models/uq_models/proximity.py +0 -384
- workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
- workbench/model_scripts/quant_regression/quant_regression.template +0 -279
- workbench/model_scripts/quant_regression/requirements.txt +0 -1
- workbench/utils/chem_utils.py +0 -1556
- workbench/utils/execution_environment.py +0 -211
- workbench/utils/fast_inference.py +0 -167
- workbench/utils/resource_utils.py +0 -39
- workbench-0.8.162.dist-info/entry_points.txt +0 -5
- {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/WHEEL +0 -0
- {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/top_level.txt +0 -0
workbench/repl/workbench_shell.py
@@ -1,14 +1,25 @@
+# flake8: noqa: E402
+import os
+import sys
+import logging
+import importlib
+import webbrowser
+import readline  # noqa: F401
+
+# Disable OpenMP parallelism to avoid segfaults with PyTorch in iPython
+# This is a known issue on macOS where libomp crashes during thread synchronization
+# Must be set before importing numpy/pandas/torch or any library that uses OpenMP
+os.environ.setdefault("OMP_NUM_THREADS", "1")
+os.environ.setdefault("MKL_NUM_THREADS", "1")
+
+import IPython
 from IPython import start_ipython
+from distutils.version import LooseVersion
 from IPython.terminal.prompts import Prompts
 from IPython.terminal.ipapp import load_default_config
 from pygments.token import Token
-import sys
-import logging
-import importlib
 import botocore
-import webbrowser
 import pandas as pd
-import readline  # noqa
 
 try:
     import matplotlib.pyplot as plt  # noqa
@@ -39,7 +50,7 @@ from workbench.cached.cached_meta import CachedMeta
 try:
     import rdkit  # noqa
     import mordred  # noqa
-    from workbench.utils import
+    from workbench.utils.chem_utils import vis
 
     HAVE_CHEM_UTILS = True
 except ImportError:
@@ -70,7 +81,7 @@ if not ConfigManager().config_okay():
 
 # Set the log level to important
 log = logging.getLogger("workbench")
-log.setLevel(
+log.setLevel(logging.INFO)
 log.addFilter(
     lambda record: not (
         record.getMessage().startswith("Async: Metadata") or record.getMessage().startswith("Updated Metadata")
@@ -176,12 +187,12 @@ class WorkbenchShell:
 
         # Add cheminformatics utils if available
         if HAVE_CHEM_UTILS:
-            self.commands["show"] =
+            self.commands["show"] = vis.show
 
     def start(self):
         """Start the Workbench IPython shell"""
         cprint("magenta", "\nWelcome to Workbench!")
-        if self.aws_status
+        if not self.aws_status:
             cprint("red", "AWS Account Connection Failed...Review/Fix the Workbench Config:")
             cprint("red", f"Path: {self.cm.site_config_path}")
             self.show_config()
@@ -202,7 +213,10 @@ class WorkbenchShell:
 
         # Start IPython with the config and commands in the namespace
         try:
-
+            if LooseVersion(IPython.__version__) >= LooseVersion("9.0.0"):
+                ipython_argv = ["--no-tip", "--theme", "linux"]
+            else:
+                ipython_argv = []
             start_ipython(ipython_argv, user_ns=locs, config=config)
         finally:
             spinner = self.spinner_start("Goodbye to AWS:")
@@ -288,11 +302,6 @@ class WorkbenchShell:
         self.commands["PandasToView"] = importlib.import_module("workbench.core.views.pandas_to_view").PandasToView
         self.commands["Pipeline"] = importlib.import_module("workbench.api.pipeline").Pipeline
 
-        # Algorithms
-        self.commands["FSP"] = importlib.import_module(
-            "workbench.algorithms.dataframe.feature_space_proximity"
-        ).FeatureSpaceProximity
-
         # These are 'nice to have' imports
         self.commands["pd"] = importlib.import_module("pandas")
         self.commands["wr"] = importlib.import_module("awswrangler")
@@ -520,7 +529,7 @@ class WorkbenchShell:
     def get_meta(self):
         return self.meta
 
-    def plot_manager(self, data, plot_type: str = "
+    def plot_manager(self, data, plot_type: str = "scatter", **kwargs):
        """Plot Manager for Workbench"""
         from workbench.web_interface.components.plugins import ag_table, graph_plot, scatter_plot
 
@@ -555,14 +564,14 @@ class WorkbenchShell:
         from workbench.web_interface.components.plugin_unit_test import PluginUnitTest
 
         # Get kwargs
-        theme = kwargs.get("theme", "
+        theme = kwargs.get("theme", "midnight_blue")
 
         plugin_test = PluginUnitTest(plugin_class, theme=theme, input_data=data, **kwargs)
 
-        #
-        plugin_test.run()
+        # Open the browser and run the dash server
         url = f"http://127.0.0.1:{plugin_test.port}"
         webbrowser.open(url)
+        plugin_test.run()
 
 
 # Launch Shell Entry Point
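The new early-import block only works because the environment guards run before anything that loads libomp; a minimal sketch of the pattern, with numpy standing in for any OpenMP-backed import:

import os

# Cap OpenMP/MKL threads before the first libomp-backed import;
# setdefault() respects a value the user has already exported.
os.environ.setdefault("OMP_NUM_THREADS", "1")
os.environ.setdefault("MKL_NUM_THREADS", "1")

import numpy as np  # noqa: E402  (late import is deliberate, hence the module-wide E402 waiver above)

print(np.__version__, os.environ["OMP_NUM_THREADS"])  # "1" unless the user set it first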
workbench/resources/open_source_api.key
@@ -1 +1 @@
-
+eyJsaWNlbnNlX2lkIjogIk9wZW5fU291cmNlX0xpY2Vuc2UiLCAiY29tcGFueSI6ICIiLCAiYXdzX2FjY291bnRfaWQiOiAiIiwgInRpZXIiOiAiRW50ZXJwcmlzZSBQcm8iLCAiZmVhdHVyZXMiOiBbInBsdWdpbnMiLCAicGFnZXMiLCAidGhlbWVzIiwgInBpcGVsaW5lcyIsICJicmFuZGluZyJdLCAiZXhwaXJlcyI6ICIyMDI2LTEyLTA1In1IsmpkuybFALADkRj_RfmkQ0LAIsQeXRE7Uoc3DL1UrDr-rSnwu-PDqsKBUkX6jPRFZV3DLxNjBapxPeEIFhfvxvjzz_sc6CwtxNpZ3bPmxSPs2W-j3xZS4-XyEqIilcwSkWh-NU1u27gCuuivn5eiUmIYJGAp0wdVkeE6_Z9dlg==
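The replaced key looks like a urlsafe base64 blob: a JSON license header followed by a binary signature. A hedged sketch of peeking at the readable portion (this assumes the encoding; the signature format is not documented here):

import base64
import json

# Path inside an unpacked wheel; adjust to wherever the key lives locally
blob = open("workbench/resources/open_source_api.key").read().strip()

decoded = base64.urlsafe_b64decode(blob + "=" * (-len(blob) % 4))
header, _, _signature = decoded.partition(b"}")  # the JSON header ends at the first "}"
print(json.loads(header + b"}"))
# e.g. {'license_id': 'Open_Source_License', ..., 'expires': '2026-12-05'}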
workbench/scripts/endpoint_test.py (new file)
@@ -0,0 +1,162 @@
"""
Local test harness for SageMaker model scripts.

Usage:
    python model_script_harness.py <local_script.py> <model_name>

Example:
    python model_script_harness.py pytorch.py aqsol-reg-pytorch

This allows you to test LOCAL changes to a model script against deployed model artifacts.
Evaluation data is automatically pulled from the FeatureSet (training = FALSE rows).

Optional: testing/env.json with additional environment variables
"""

import os
import sys
import json
import importlib.util
import tempfile
import shutil
import pandas as pd
import torch

# Workbench Imports
from workbench.api import Model, FeatureSet
from workbench.utils.pytorch_utils import download_and_extract_model

# Force CPU mode before any CUDA/MPS use to avoid MPS/CUDA issues on Mac
os.environ["CUDA_VISIBLE_DEVICES"] = ""
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
torch.set_default_device("cpu")
# Disable MPS entirely
if hasattr(torch.backends, "mps"):
    torch.backends.mps.is_available = lambda: False


def get_eval_data(workbench_model: Model) -> pd.DataFrame:
    """Get evaluation data from the FeatureSet associated with this model."""
    # Get the FeatureSet
    fs_name = workbench_model.get_input()
    fs = FeatureSet(fs_name)
    if not fs.exists():
        raise ValueError(f"No FeatureSet found: {fs_name}")

    # Get evaluation data (training = FALSE)
    table = workbench_model.training_view().table
    print(f"Querying evaluation data from {table}...")
    eval_df = fs.query(f'SELECT * FROM "{table}" WHERE training = FALSE')
    print(f"Retrieved {len(eval_df)} evaluation rows")

    return eval_df


def load_model_script(script_path: str):
    """Dynamically load the model script module."""
    if not os.path.exists(script_path):
        raise FileNotFoundError(f"Script not found: {script_path}")

    spec = importlib.util.spec_from_file_location("model_script", script_path)
    module = importlib.util.module_from_spec(spec)

    # Add to sys.modules so imports within the script work
    sys.modules["model_script"] = module

    spec.loader.exec_module(module)
    return module


def main():
    if len(sys.argv) < 3:
        print("Usage: python model_script_harness.py <local_script.py> <model_name>")
        print("\nArguments:")
        print("  local_script.py - Path to your LOCAL model script to test")
        print("  model_name      - Workbench model name (e.g., aqsol-reg-pytorch)")
        print("\nOptional: testing/env.json with additional environment variables")
        sys.exit(1)

    script_path = sys.argv[1]
    model_name = sys.argv[2]

    # Validate local script exists
    if not os.path.exists(script_path):
        print(f"Error: Local script not found: {script_path}")
        sys.exit(1)

    # Initialize Workbench model
    print(f"Loading Workbench model: {model_name}")
    workbench_model = Model(model_name)
    print(f"Model Framework: {workbench_model.model_framework}")
    print()

    # Create a temporary model directory
    model_dir = tempfile.mkdtemp(prefix="model_harness_")
    print(f"Using model directory: {model_dir}")

    try:
        # Load environment variables from env.json if it exists
        if os.path.exists("testing/env.json"):
            print("Loading environment variables from testing/env.json")
            with open("testing/env.json") as f:
                env_vars = json.load(f)
            for key, value in env_vars.items():
                os.environ[key] = value
                print(f"  Set {key} = {value}")
            print()

        # Set up SageMaker environment variables
        os.environ["SM_MODEL_DIR"] = model_dir
        print(f"Set SM_MODEL_DIR = {model_dir}")

        # Download and extract model artifacts
        s3_uri = workbench_model.model_data_url()
        download_and_extract_model(s3_uri, model_dir)
        print()

        # Load the LOCAL model script
        print(f"Loading LOCAL model script: {script_path}")
        module = load_model_script(script_path)
        print()

        # Check for required functions
        if not hasattr(module, "model_fn"):
            raise AttributeError("Model script must have a model_fn function")
        if not hasattr(module, "predict_fn"):
            raise AttributeError("Model script must have a predict_fn function")

        # Load the model
        print("Calling model_fn...")
        print("-" * 50)
        model = module.model_fn(model_dir)
        print("-" * 50)
        print(f"Model loaded: {type(model)}")
        print()

        # Get evaluation data from FeatureSet
        print("Pulling evaluation data from FeatureSet...")
        df = get_eval_data(workbench_model)
        print(f"Input shape: {df.shape}")
        print(f"Columns: {df.columns.tolist()}")
        print()

        print("Calling predict_fn...")
        print("-" * 50)
        result = module.predict_fn(df, model)
        print("-" * 50)
        print()

        print("Prediction result:")
        print(f"Output shape: {result.shape}")
        print(f"Output columns: {result.columns.tolist()}")
        print()
        print(result.head(10).to_string())

    finally:
        # Cleanup
        print(f"\nCleaning up model directory: {model_dir}")
        shutil.rmtree(model_dir, ignore_errors=True)


if __name__ == "__main__":
    main()
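testing/env.json, when present, is read as a flat string-to-string map and copied verbatim into os.environ before model_fn runs. A sketch of writing one (the keys are illustrative, nothing here is required by the harness):

import json
import os

os.makedirs("testing", exist_ok=True)
# Any keys work; the harness just does os.environ[key] = value per entry
env = {"WORKBENCH_BUCKET": "my-workbench-bucket", "LOG_LEVEL": "DEBUG"}
with open("testing/env.json", "w") as f:
    json.dump(env, f, indent=2)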
workbench/scripts/lambda_test.py (new file)
@@ -0,0 +1,73 @@
"""
Local test harness for AWS Lambda scripts.

Usage:
    lambda_test <lambda_script.py>

Optional: testing/event.json with the event definition
Optional: testing/env.json with a set of ENV vars
"""

import sys
import os
import json
import importlib.util


def main():
    if len(sys.argv) != 2:
        print("Usage: lambda_test <lambda_script.py>")
        print("\nOptional: testing/event.json with test event")
        print("Optional: testing/env.json with environment variables")
        sys.exit(1)

    handler_file = sys.argv[1]

    # Add .py if not present
    if not handler_file.endswith(".py"):
        handler_file += ".py"

    # Check if file exists
    if not os.path.exists(handler_file):
        print(f"Error: File '{handler_file}' not found")
        sys.exit(1)

    # Load environment variables from env.json if it exists
    if os.path.exists("testing/env.json"):
        print("Loading environment variables from testing/env.json")
        with open("testing/env.json") as f:
            env_vars = json.load(f)
        for key, value in env_vars.items():
            os.environ[key] = value
            print(f"  Set {key} = {value}")
        print()

    # Load event configuration
    if os.path.exists("testing/event.json"):
        print("Loading event from testing/event.json")
        with open("testing/event.json") as f:
            event = json.load(f)
    else:
        print("No testing/event.json found, using empty event")
        event = {}

    # Load the module dynamically
    spec = importlib.util.spec_from_file_location("lambda_module", handler_file)
    lambda_module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(lambda_module)

    # Call the lambda_handler
    print(f"Invoking lambda_handler from {handler_file}...")
    print("-" * 50)
    print(f"Event: {json.dumps(event, indent=2)}")
    print("-" * 50)

    result = lambda_module.lambda_handler(event, {})

    print("-" * 50)
    print("Result:")
    print(json.dumps(result, indent=2))


if __name__ == "__main__":
    main()
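The harness only requires that the target file expose lambda_handler(event, context). A minimal handler to exercise it (file name and event are illustrative):

# my_lambda.py -- a minimal handler for the harness above
import json


def lambda_handler(event, context):
    # Echo the event back in a Lambda-style response dict
    return {"statusCode": 200, "body": json.dumps(event)}


# To try it:
#   echo '{"name": "test"}' > testing/event.json
#   lambda_test my_lambda.py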
workbench/scripts/meta_model_sim.py (new file)
@@ -0,0 +1,35 @@
"""MetaModelSimulator: Simulate and analyze ensemble model performance.

This class helps evaluate whether a meta model (ensemble) would outperform
individual child models by analyzing endpoint inference predictions.
"""

import argparse
from workbench.utils.meta_model_simulator import MetaModelSimulator


def main():
    parser = argparse.ArgumentParser(
        description="Simulate and analyze ensemble model performance using MetaModelSimulator."
    )
    parser.add_argument(
        "models",
        nargs="+",
        help="List of model endpoint names to include in the ensemble simulation.",
    )
    parser.add_argument(
        "--id-column",
        default="molecule_name",
        help="Name of the ID column (default: molecule_name)",
    )
    args = parser.parse_args()
    models = args.models
    id_column = args.id_column

    # Create MetaModelSimulator instance and generate report
    sim = MetaModelSimulator(models, id_column=id_column)
    sim.report()


if __name__ == "__main__":
    main()
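The script is a thin argparse wrapper, so the simulator can also be driven directly; a sketch with illustrative endpoint names:

from workbench.utils.meta_model_simulator import MetaModelSimulator

# Compare an ensemble of two endpoints against each child model on shared IDs
sim = MetaModelSimulator(["aqsol-reg-pytorch", "aqsol-reg-xgb"], id_column="molecule_name")
sim.report()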
workbench/scripts/ml_pipeline_batch.py (new file)
@@ -0,0 +1,137 @@
import argparse
import logging
import time
from datetime import datetime
from pathlib import Path

# Workbench Imports
from workbench.core.cloud_platform.aws.aws_account_clamp import AWSAccountClamp
from workbench.utils.config_manager import ConfigManager
from workbench.utils.s3_utils import upload_content_to_s3
from workbench.utils.cloudwatch_utils import get_cloudwatch_logs_url

log = logging.getLogger("workbench")
cm = ConfigManager()
workbench_bucket = cm.get_config("WORKBENCH_BUCKET")


def get_ecr_image_uri() -> str:
    """Get the ECR image URI for the current region."""
    region = AWSAccountClamp().region
    return f"507740646243.dkr.ecr.{region}.amazonaws.com/aws-ml-images/py312-ml-pipelines:0.1"


def get_batch_role_arn() -> str:
    """Get the Batch execution role ARN."""
    account_id = AWSAccountClamp().account_id
    return f"arn:aws:iam::{account_id}:role/Workbench-BatchRole"


def _log_cloudwatch_link(job: dict, message_prefix: str = "View logs") -> None:
    """
    Helper method to log CloudWatch logs link with clickable URL and full URL display.

    Args:
        job: Batch job description dictionary
        message_prefix: Prefix for the log message (default: "View logs")
    """
    log_stream = job.get("container", {}).get("logStreamName")
    logs_url = get_cloudwatch_logs_url(log_group="/aws/batch/job", log_stream=log_stream)
    if logs_url:
        clickable_url = f"\033]8;;{logs_url}\033\\{logs_url}\033]8;;\033\\"
        log.info(f"{message_prefix}: {clickable_url}")
    else:
        log.info("Check AWS Batch console for logs")


def run_batch_job(script_path: str, size: str = "small") -> int:
    """
    Submit and monitor an AWS Batch job for ML pipeline execution.

    Uploads script to S3, submits Batch job, monitors until completion or 2 minutes of RUNNING.

    Args:
        script_path: Local path to the ML pipeline script
        size: Job size tier - "small" (default), "medium", or "large"
            - small: 2 vCPU, 4GB RAM for lightweight processing
            - medium: 4 vCPU, 8GB RAM for standard ML workloads
            - large: 8 vCPU, 16GB RAM for heavy training/inference

    Returns:
        Exit code (0 for success/disconnected, non-zero for failure)
    """
    if size not in ["small", "medium", "large"]:
        raise ValueError(f"Invalid size '{size}'. Must be 'small', 'medium', or 'large'")

    batch = AWSAccountClamp().boto3_session.client("batch")
    script_name = Path(script_path).stem

    # Upload script to S3
    s3_path = f"s3://{workbench_bucket}/batch-jobs/{Path(script_path).name}"
    log.info(f"Uploading script to {s3_path}")
    upload_content_to_s3(Path(script_path).read_text(), s3_path)

    # Submit job
    job_name = f"workbench_{script_name}_{datetime.now():%Y%m%d_%H%M%S}"
    response = batch.submit_job(
        jobName=job_name,
        jobQueue="workbench-job-queue",
        jobDefinition=f"workbench-batch-{size}",
        containerOverrides={
            "environment": [
                {"name": "ML_PIPELINE_S3_PATH", "value": s3_path},
                {"name": "WORKBENCH_BUCKET", "value": workbench_bucket},
            ]
        },
    )
    job_id = response["jobId"]
    log.info(f"Submitted job: {job_name} ({job_id}) using {size} tier")

    # Monitor job
    last_status, running_start = None, None
    while True:
        job = batch.describe_jobs(jobs=[job_id])["jobs"][0]
        status = job["status"]

        if status != last_status:
            log.info(f"Job status: {status}")
            last_status = status
            if status == "RUNNING":
                running_start = time.time()

        # Disconnect after 2 minutes of running
        if status == "RUNNING" and running_start and (time.time() - running_start >= 120):
            log.info("✅ ML Pipeline is running successfully!")
            _log_cloudwatch_link(job, "📊 Monitor logs")
            return 0

        # Handle completion
        if status in ["SUCCEEDED", "FAILED"]:
            exit_code = job.get("attempts", [{}])[-1].get("exitCode", 1)
            msg = (
                "Job completed successfully"
                if status == "SUCCEEDED"
                else f"Job failed: {job.get('statusReason', 'Unknown')}"
            )
            log.info(msg) if status == "SUCCEEDED" else log.error(msg)
            _log_cloudwatch_link(job)
            return exit_code

        time.sleep(10)


def main():
    """CLI entry point for running ML pipelines on AWS Batch."""
    parser = argparse.ArgumentParser(description="Run ML pipeline script on AWS Batch")
    parser.add_argument("script_file", help="Local path to ML pipeline script")
    args = parser.parse_args()
    try:
        exit_code = run_batch_job(args.script_file)
        exit(exit_code)
    except Exception as e:
        log.error(f"Error: {e}")
        exit(1)


if __name__ == "__main__":
    main()
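run_batch_job can likewise be called programmatically, for example to pick a larger tier than the CLI default; the script path here is illustrative:

from workbench.scripts.ml_pipeline_batch import run_batch_job

# "medium" selects the workbench-batch-medium job definition (4 vCPU / 8GB)
exit_code = run_batch_job("pipelines/train_aqsol.py", size="medium")
print(f"Batch job finished with exit code {exit_code}")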