workbench 0.8.174__py3-none-any.whl → 0.8.227__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of workbench might be problematic. Click here for more details.
- workbench/__init__.py +1 -0
- workbench/algorithms/dataframe/__init__.py +1 -2
- workbench/algorithms/dataframe/compound_dataset_overlap.py +321 -0
- workbench/algorithms/dataframe/feature_space_proximity.py +168 -75
- workbench/algorithms/dataframe/fingerprint_proximity.py +422 -86
- workbench/algorithms/dataframe/projection_2d.py +44 -21
- workbench/algorithms/dataframe/proximity.py +259 -305
- workbench/algorithms/graph/light/proximity_graph.py +12 -11
- workbench/algorithms/models/cleanlab_model.py +382 -0
- workbench/algorithms/models/noise_model.py +388 -0
- workbench/algorithms/sql/column_stats.py +0 -1
- workbench/algorithms/sql/correlations.py +0 -1
- workbench/algorithms/sql/descriptive_stats.py +0 -1
- workbench/algorithms/sql/outliers.py +3 -3
- workbench/api/__init__.py +5 -1
- workbench/api/df_store.py +17 -108
- workbench/api/endpoint.py +14 -12
- workbench/api/feature_set.py +117 -11
- workbench/api/meta.py +0 -1
- workbench/api/meta_model.py +289 -0
- workbench/api/model.py +52 -21
- workbench/api/parameter_store.py +3 -52
- workbench/cached/cached_meta.py +0 -1
- workbench/cached/cached_model.py +49 -11
- workbench/core/artifacts/__init__.py +11 -2
- workbench/core/artifacts/artifact.py +7 -7
- workbench/core/artifacts/data_capture_core.py +8 -1
- workbench/core/artifacts/df_store_core.py +114 -0
- workbench/core/artifacts/endpoint_core.py +323 -205
- workbench/core/artifacts/feature_set_core.py +249 -45
- workbench/core/artifacts/model_core.py +133 -101
- workbench/core/artifacts/parameter_store_core.py +98 -0
- workbench/core/cloud_platform/aws/aws_account_clamp.py +48 -2
- workbench/core/cloud_platform/cloud_meta.py +0 -1
- workbench/core/pipelines/pipeline_executor.py +1 -1
- workbench/core/transforms/features_to_model/features_to_model.py +60 -44
- workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +43 -10
- workbench/core/transforms/pandas_transforms/pandas_to_features.py +38 -2
- workbench/core/views/training_view.py +113 -42
- workbench/core/views/view.py +53 -3
- workbench/core/views/view_utils.py +4 -4
- workbench/model_script_utils/model_script_utils.py +339 -0
- workbench/model_script_utils/pytorch_utils.py +405 -0
- workbench/model_script_utils/uq_harness.py +277 -0
- workbench/model_scripts/chemprop/chemprop.template +774 -0
- workbench/model_scripts/chemprop/generated_model_script.py +774 -0
- workbench/model_scripts/chemprop/model_script_utils.py +339 -0
- workbench/model_scripts/chemprop/requirements.txt +3 -0
- workbench/model_scripts/custom_models/chem_info/fingerprints.py +175 -0
- workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +18 -7
- workbench/model_scripts/custom_models/chem_info/mol_standardize.py +80 -58
- workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +0 -1
- workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -2
- workbench/model_scripts/custom_models/proximity/feature_space_proximity.py +194 -0
- workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +8 -10
- workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
- workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +20 -21
- workbench/model_scripts/custom_models/uq_models/feature_space_proximity.py +194 -0
- workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
- workbench/model_scripts/custom_models/uq_models/ngboost.template +15 -16
- workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +15 -17
- workbench/model_scripts/meta_model/generated_model_script.py +209 -0
- workbench/model_scripts/meta_model/meta_model.template +209 -0
- workbench/model_scripts/pytorch_model/generated_model_script.py +443 -499
- workbench/model_scripts/pytorch_model/model_script_utils.py +339 -0
- workbench/model_scripts/pytorch_model/pytorch.template +440 -496
- workbench/model_scripts/pytorch_model/pytorch_utils.py +405 -0
- workbench/model_scripts/pytorch_model/requirements.txt +1 -1
- workbench/model_scripts/pytorch_model/uq_harness.py +277 -0
- workbench/model_scripts/scikit_learn/generated_model_script.py +7 -12
- workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
- workbench/model_scripts/script_generation.py +15 -12
- workbench/model_scripts/uq_models/generated_model_script.py +248 -0
- workbench/model_scripts/xgb_model/generated_model_script.py +371 -403
- workbench/model_scripts/xgb_model/model_script_utils.py +339 -0
- workbench/model_scripts/xgb_model/uq_harness.py +277 -0
- workbench/model_scripts/xgb_model/xgb_model.template +367 -399
- workbench/repl/workbench_shell.py +18 -14
- workbench/resources/open_source_api.key +1 -1
- workbench/scripts/endpoint_test.py +162 -0
- workbench/scripts/lambda_test.py +73 -0
- workbench/scripts/meta_model_sim.py +35 -0
- workbench/scripts/ml_pipeline_sqs.py +122 -6
- workbench/scripts/training_test.py +85 -0
- workbench/themes/dark/custom.css +59 -0
- workbench/themes/dark/plotly.json +5 -5
- workbench/themes/light/custom.css +153 -40
- workbench/themes/light/plotly.json +9 -9
- workbench/themes/midnight_blue/custom.css +59 -0
- workbench/utils/aws_utils.py +0 -1
- workbench/utils/chem_utils/fingerprints.py +87 -46
- workbench/utils/chem_utils/mol_descriptors.py +18 -7
- workbench/utils/chem_utils/mol_standardize.py +80 -58
- workbench/utils/chem_utils/projections.py +16 -6
- workbench/utils/chem_utils/vis.py +25 -27
- workbench/utils/chemprop_utils.py +141 -0
- workbench/utils/config_manager.py +2 -6
- workbench/utils/endpoint_utils.py +5 -7
- workbench/utils/license_manager.py +2 -6
- workbench/utils/markdown_utils.py +57 -0
- workbench/utils/meta_model_simulator.py +499 -0
- workbench/utils/metrics_utils.py +256 -0
- workbench/utils/model_utils.py +274 -87
- workbench/utils/pipeline_utils.py +0 -1
- workbench/utils/plot_utils.py +159 -34
- workbench/utils/pytorch_utils.py +87 -0
- workbench/utils/shap_utils.py +11 -57
- workbench/utils/theme_manager.py +95 -30
- workbench/utils/xgboost_local_crossfold.py +267 -0
- workbench/utils/xgboost_model_utils.py +127 -220
- workbench/web_interface/components/experiments/outlier_plot.py +0 -1
- workbench/web_interface/components/model_plot.py +16 -2
- workbench/web_interface/components/plugin_unit_test.py +5 -3
- workbench/web_interface/components/plugins/ag_table.py +2 -4
- workbench/web_interface/components/plugins/confusion_matrix.py +3 -6
- workbench/web_interface/components/plugins/model_details.py +48 -80
- workbench/web_interface/components/plugins/scatter_plot.py +192 -92
- workbench/web_interface/components/settings_menu.py +184 -0
- workbench/web_interface/page_views/main_page.py +0 -1
- {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/METADATA +31 -17
- {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/RECORD +125 -111
- {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/entry_points.txt +4 -0
- {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/licenses/LICENSE +1 -1
- workbench/core/cloud_platform/aws/aws_df_store.py +0 -404
- workbench/core/cloud_platform/aws/aws_parameter_store.py +0 -280
- workbench/model_scripts/custom_models/meta_endpoints/example.py +0 -53
- workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
- workbench/model_scripts/custom_models/proximity/proximity.py +0 -384
- workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -393
- workbench/model_scripts/custom_models/uq_models/mapie.template +0 -502
- workbench/model_scripts/custom_models/uq_models/meta_uq.template +0 -386
- workbench/model_scripts/custom_models/uq_models/proximity.py +0 -384
- workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
- workbench/model_scripts/quant_regression/quant_regression.template +0 -279
- workbench/model_scripts/quant_regression/requirements.txt +0 -1
- workbench/themes/quartz/base_css.url +0 -1
- workbench/themes/quartz/custom.css +0 -117
- workbench/themes/quartz/plotly.json +0 -642
- workbench/themes/quartz_dark/base_css.url +0 -1
- workbench/themes/quartz_dark/custom.css +0 -131
- workbench/themes/quartz_dark/plotly.json +0 -642
- workbench/utils/fast_inference.py +0 -167
- workbench/utils/resource_utils.py +0 -39
- {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/WHEEL +0 -0
- {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/top_level.txt +0 -0
|
@@ -1,16 +1,25 @@
|
|
|
1
|
+
# flake8: noqa: E402
|
|
2
|
+
import os
|
|
3
|
+
import sys
|
|
4
|
+
import logging
|
|
5
|
+
import importlib
|
|
6
|
+
import webbrowser
|
|
7
|
+
import readline # noqa: F401
|
|
8
|
+
|
|
9
|
+
# Disable OpenMP parallelism to avoid segfaults with PyTorch in iPython
|
|
10
|
+
# This is a known issue on macOS where libomp crashes during thread synchronization
|
|
11
|
+
# Must be set before importing numpy/pandas/torch or any library that uses OpenMP
|
|
12
|
+
os.environ.setdefault("OMP_NUM_THREADS", "1")
|
|
13
|
+
os.environ.setdefault("MKL_NUM_THREADS", "1")
|
|
14
|
+
|
|
1
15
|
import IPython
|
|
2
16
|
from IPython import start_ipython
|
|
3
17
|
from distutils.version import LooseVersion
|
|
4
18
|
from IPython.terminal.prompts import Prompts
|
|
5
19
|
from IPython.terminal.ipapp import load_default_config
|
|
6
20
|
from pygments.token import Token
|
|
7
|
-
import sys
|
|
8
|
-
import logging
|
|
9
|
-
import importlib
|
|
10
21
|
import botocore
|
|
11
|
-
import webbrowser
|
|
12
22
|
import pandas as pd
|
|
13
|
-
import readline # noqa
|
|
14
23
|
|
|
15
24
|
try:
|
|
16
25
|
import matplotlib.pyplot as plt # noqa
|
|
@@ -293,11 +302,6 @@ class WorkbenchShell:
|
|
|
293
302
|
self.commands["PandasToView"] = importlib.import_module("workbench.core.views.pandas_to_view").PandasToView
|
|
294
303
|
self.commands["Pipeline"] = importlib.import_module("workbench.api.pipeline").Pipeline
|
|
295
304
|
|
|
296
|
-
# Algorithms
|
|
297
|
-
self.commands["FSP"] = importlib.import_module(
|
|
298
|
-
"workbench.algorithms.dataframe.feature_space_proximity"
|
|
299
|
-
).FeatureSpaceProximity
|
|
300
|
-
|
|
301
305
|
# These are 'nice to have' imports
|
|
302
306
|
self.commands["pd"] = importlib.import_module("pandas")
|
|
303
307
|
self.commands["wr"] = importlib.import_module("awswrangler")
|
|
@@ -525,7 +529,7 @@ class WorkbenchShell:
|
|
|
525
529
|
def get_meta(self):
|
|
526
530
|
return self.meta
|
|
527
531
|
|
|
528
|
-
def plot_manager(self, data, plot_type: str = "
|
|
532
|
+
def plot_manager(self, data, plot_type: str = "scatter", **kwargs):
|
|
529
533
|
"""Plot Manager for Workbench"""
|
|
530
534
|
from workbench.web_interface.components.plugins import ag_table, graph_plot, scatter_plot
|
|
531
535
|
|
|
@@ -560,14 +564,14 @@ class WorkbenchShell:
|
|
|
560
564
|
from workbench.web_interface.components.plugin_unit_test import PluginUnitTest
|
|
561
565
|
|
|
562
566
|
# Get kwargs
|
|
563
|
-
theme = kwargs.get("theme", "
|
|
567
|
+
theme = kwargs.get("theme", "midnight_blue")
|
|
564
568
|
|
|
565
569
|
plugin_test = PluginUnitTest(plugin_class, theme=theme, input_data=data, **kwargs)
|
|
566
570
|
|
|
567
|
-
#
|
|
568
|
-
plugin_test.run()
|
|
571
|
+
# Open the browser and run the dash server
|
|
569
572
|
url = f"http://127.0.0.1:{plugin_test.port}"
|
|
570
573
|
webbrowser.open(url)
|
|
574
|
+
plugin_test.run()
|
|
571
575
|
|
|
572
576
|
|
|
573
577
|
# Launch Shell Entry Point
|
|
@@ -1 +1 @@
|
|
|
1
|
-
|
|
1
|
+
eyJsaWNlbnNlX2lkIjogIk9wZW5fU291cmNlX0xpY2Vuc2UiLCAiY29tcGFueSI6ICIiLCAiYXdzX2FjY291bnRfaWQiOiAiIiwgInRpZXIiOiAiRW50ZXJwcmlzZSBQcm8iLCAiZmVhdHVyZXMiOiBbInBsdWdpbnMiLCAicGFnZXMiLCAidGhlbWVzIiwgInBpcGVsaW5lcyIsICJicmFuZGluZyJdLCAiZXhwaXJlcyI6ICIyMDI2LTEyLTA1In1IsmpkuybFALADkRj_RfmkQ0LAIsQeXRE7Uoc3DL1UrDr-rSnwu-PDqsKBUkX6jPRFZV3DLxNjBapxPeEIFhfvxvjzz_sc6CwtxNpZ3bPmxSPs2W-j3xZS4-XyEqIilcwSkWh-NU1u27gCuuivn5eiUmIYJGAp0wdVkeE6_Z9dlg==
|
|
@@ -0,0 +1,162 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Local test harness for SageMaker model scripts.
|
|
3
|
+
|
|
4
|
+
Usage:
|
|
5
|
+
python model_script_harness.py <local_script.py> <model_name>
|
|
6
|
+
|
|
7
|
+
Example:
|
|
8
|
+
python model_script_harness.py pytorch.py aqsol-reg-pytorch
|
|
9
|
+
|
|
10
|
+
This allows you to test LOCAL changes to a model script against deployed model artifacts.
|
|
11
|
+
Evaluation data is automatically pulled from the FeatureSet (training = FALSE rows).
|
|
12
|
+
|
|
13
|
+
Optional: testing/env.json with additional environment variables
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
import os
|
|
17
|
+
import sys
|
|
18
|
+
import json
|
|
19
|
+
import importlib.util
|
|
20
|
+
import tempfile
|
|
21
|
+
import shutil
|
|
22
|
+
import pandas as pd
|
|
23
|
+
import torch
|
|
24
|
+
|
|
25
|
+
# Workbench Imports
|
|
26
|
+
from workbench.api import Model, FeatureSet
|
|
27
|
+
from workbench.utils.pytorch_utils import download_and_extract_model
|
|
28
|
+
|
|
29
|
+
# Force CPU mode BEFORE any PyTorch imports to avoid MPS/CUDA issues on Mac
|
|
30
|
+
os.environ["CUDA_VISIBLE_DEVICES"] = ""
|
|
31
|
+
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
|
32
|
+
torch.set_default_device("cpu")
|
|
33
|
+
# Disable MPS entirely
|
|
34
|
+
if hasattr(torch.backends, "mps"):
|
|
35
|
+
torch.backends.mps.is_available = lambda: False
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def get_eval_data(workbench_model: Model) -> pd.DataFrame:
|
|
39
|
+
"""Get evaluation data from the FeatureSet associated with this model."""
|
|
40
|
+
# Get the FeatureSet
|
|
41
|
+
fs_name = workbench_model.get_input()
|
|
42
|
+
fs = FeatureSet(fs_name)
|
|
43
|
+
if not fs.exists():
|
|
44
|
+
raise ValueError(f"No FeatureSet found: {fs_name}")
|
|
45
|
+
|
|
46
|
+
# Get evaluation data (training = FALSE)
|
|
47
|
+
table = workbench_model.training_view().table
|
|
48
|
+
print(f"Querying evaluation data from {table}...")
|
|
49
|
+
eval_df = fs.query(f'SELECT * FROM "{table}" WHERE training = FALSE')
|
|
50
|
+
print(f"Retrieved {len(eval_df)} evaluation rows")
|
|
51
|
+
|
|
52
|
+
return eval_df
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def load_model_script(script_path: str):
|
|
56
|
+
"""Dynamically load the model script module."""
|
|
57
|
+
if not os.path.exists(script_path):
|
|
58
|
+
raise FileNotFoundError(f"Script not found: {script_path}")
|
|
59
|
+
|
|
60
|
+
spec = importlib.util.spec_from_file_location("model_script", script_path)
|
|
61
|
+
module = importlib.util.module_from_spec(spec)
|
|
62
|
+
|
|
63
|
+
# Add to sys.modules so imports within the script work
|
|
64
|
+
sys.modules["model_script"] = module
|
|
65
|
+
|
|
66
|
+
spec.loader.exec_module(module)
|
|
67
|
+
return module
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def main():
|
|
71
|
+
if len(sys.argv) < 3:
|
|
72
|
+
print("Usage: python model_script_harness.py <local_script.py> <model_name>")
|
|
73
|
+
print("\nArguments:")
|
|
74
|
+
print(" local_script.py - Path to your LOCAL model script to test")
|
|
75
|
+
print(" model_name - Workbench model name (e.g., aqsol-reg-pytorch)")
|
|
76
|
+
print("\nOptional: testing/env.json with additional environment variables")
|
|
77
|
+
sys.exit(1)
|
|
78
|
+
|
|
79
|
+
script_path = sys.argv[1]
|
|
80
|
+
model_name = sys.argv[2]
|
|
81
|
+
|
|
82
|
+
# Validate local script exists
|
|
83
|
+
if not os.path.exists(script_path):
|
|
84
|
+
print(f"Error: Local script not found: {script_path}")
|
|
85
|
+
sys.exit(1)
|
|
86
|
+
|
|
87
|
+
# Initialize Workbench model
|
|
88
|
+
print(f"Loading Workbench model: {model_name}")
|
|
89
|
+
workbench_model = Model(model_name)
|
|
90
|
+
print(f"Model Framework: {workbench_model.model_framework}")
|
|
91
|
+
print()
|
|
92
|
+
|
|
93
|
+
# Create a temporary model directory
|
|
94
|
+
model_dir = tempfile.mkdtemp(prefix="model_harness_")
|
|
95
|
+
print(f"Using model directory: {model_dir}")
|
|
96
|
+
|
|
97
|
+
try:
|
|
98
|
+
# Load environment variables from env.json if it exists
|
|
99
|
+
if os.path.exists("testing/env.json"):
|
|
100
|
+
print("Loading environment variables from testing/env.json")
|
|
101
|
+
with open("testing/env.json") as f:
|
|
102
|
+
env_vars = json.load(f)
|
|
103
|
+
for key, value in env_vars.items():
|
|
104
|
+
os.environ[key] = value
|
|
105
|
+
print(f" Set {key} = {value}")
|
|
106
|
+
print()
|
|
107
|
+
|
|
108
|
+
# Set up SageMaker environment variables
|
|
109
|
+
os.environ["SM_MODEL_DIR"] = model_dir
|
|
110
|
+
print(f"Set SM_MODEL_DIR = {model_dir}")
|
|
111
|
+
|
|
112
|
+
# Download and extract model artifacts
|
|
113
|
+
s3_uri = workbench_model.model_data_url()
|
|
114
|
+
download_and_extract_model(s3_uri, model_dir)
|
|
115
|
+
print()
|
|
116
|
+
|
|
117
|
+
# Load the LOCAL model script
|
|
118
|
+
print(f"Loading LOCAL model script: {script_path}")
|
|
119
|
+
module = load_model_script(script_path)
|
|
120
|
+
print()
|
|
121
|
+
|
|
122
|
+
# Check for required functions
|
|
123
|
+
if not hasattr(module, "model_fn"):
|
|
124
|
+
raise AttributeError("Model script must have a model_fn function")
|
|
125
|
+
if not hasattr(module, "predict_fn"):
|
|
126
|
+
raise AttributeError("Model script must have a predict_fn function")
|
|
127
|
+
|
|
128
|
+
# Load the model
|
|
129
|
+
print("Calling model_fn...")
|
|
130
|
+
print("-" * 50)
|
|
131
|
+
model = module.model_fn(model_dir)
|
|
132
|
+
print("-" * 50)
|
|
133
|
+
print(f"Model loaded: {type(model)}")
|
|
134
|
+
print()
|
|
135
|
+
|
|
136
|
+
# Get evaluation data from FeatureSet
|
|
137
|
+
print("Pulling evaluation data from FeatureSet...")
|
|
138
|
+
df = get_eval_data(workbench_model)
|
|
139
|
+
print(f"Input shape: {df.shape}")
|
|
140
|
+
print(f"Columns: {df.columns.tolist()}")
|
|
141
|
+
print()
|
|
142
|
+
|
|
143
|
+
print("Calling predict_fn...")
|
|
144
|
+
print("-" * 50)
|
|
145
|
+
result = module.predict_fn(df, model)
|
|
146
|
+
print("-" * 50)
|
|
147
|
+
print()
|
|
148
|
+
|
|
149
|
+
print("Prediction result:")
|
|
150
|
+
print(f"Output shape: {result.shape}")
|
|
151
|
+
print(f"Output columns: {result.columns.tolist()}")
|
|
152
|
+
print()
|
|
153
|
+
print(result.head(10).to_string())
|
|
154
|
+
|
|
155
|
+
finally:
|
|
156
|
+
# Cleanup
|
|
157
|
+
print(f"\nCleaning up model directory: {model_dir}")
|
|
158
|
+
shutil.rmtree(model_dir, ignore_errors=True)
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
if __name__ == "__main__":
|
|
162
|
+
main()
|
|
@@ -0,0 +1,73 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Local test harness for AWS Lambda scripts.
|
|
3
|
+
|
|
4
|
+
Usage:
|
|
5
|
+
lambda_test <lambda_script.py>
|
|
6
|
+
|
|
7
|
+
Required: testing/event.json with the event definition
|
|
8
|
+
Options: testing/env.json with a set of ENV vars
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
import sys
|
|
12
|
+
import os
|
|
13
|
+
import json
|
|
14
|
+
import importlib.util
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def main():
|
|
18
|
+
if len(sys.argv) != 2:
|
|
19
|
+
print("Usage: lambda_launcher <handler_module_name>")
|
|
20
|
+
print("\nOptional: testing/event.json with test event")
|
|
21
|
+
print("Optional: testing/env.json with environment variables")
|
|
22
|
+
sys.exit(1)
|
|
23
|
+
|
|
24
|
+
handler_file = sys.argv[1]
|
|
25
|
+
|
|
26
|
+
# Add .py if not present
|
|
27
|
+
if not handler_file.endswith(".py"):
|
|
28
|
+
handler_file += ".py"
|
|
29
|
+
|
|
30
|
+
# Check if file exists
|
|
31
|
+
if not os.path.exists(handler_file):
|
|
32
|
+
print(f"Error: File '{handler_file}' not found")
|
|
33
|
+
sys.exit(1)
|
|
34
|
+
|
|
35
|
+
# Load environment variables from env.json if it exists
|
|
36
|
+
if os.path.exists("testing/env.json"):
|
|
37
|
+
print("Loading environment variables from testing/env.json")
|
|
38
|
+
with open("testing/env.json") as f:
|
|
39
|
+
env_vars = json.load(f)
|
|
40
|
+
for key, value in env_vars.items():
|
|
41
|
+
os.environ[key] = value
|
|
42
|
+
print(f" Set {key} = {value}")
|
|
43
|
+
print()
|
|
44
|
+
|
|
45
|
+
# Load event configuration
|
|
46
|
+
if os.path.exists("testing/event.json"):
|
|
47
|
+
print("Loading event from testing/event.json")
|
|
48
|
+
with open("testing/event.json") as f:
|
|
49
|
+
event = json.load(f)
|
|
50
|
+
else:
|
|
51
|
+
print("No testing/event.json found, using empty event")
|
|
52
|
+
event = {}
|
|
53
|
+
|
|
54
|
+
# Load the module dynamically
|
|
55
|
+
spec = importlib.util.spec_from_file_location("lambda_module", handler_file)
|
|
56
|
+
lambda_module = importlib.util.module_from_spec(spec)
|
|
57
|
+
spec.loader.exec_module(lambda_module)
|
|
58
|
+
|
|
59
|
+
# Call the lambda_handler
|
|
60
|
+
print(f"Invoking lambda_handler from {handler_file}...")
|
|
61
|
+
print("-" * 50)
|
|
62
|
+
print(f"Event: {json.dumps(event, indent=2)}")
|
|
63
|
+
print("-" * 50)
|
|
64
|
+
|
|
65
|
+
result = lambda_module.lambda_handler(event, {})
|
|
66
|
+
|
|
67
|
+
print("-" * 50)
|
|
68
|
+
print("Result:")
|
|
69
|
+
print(json.dumps(result, indent=2))
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
if __name__ == "__main__":
|
|
73
|
+
main()
|
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
"""MetaModelSimulator: Simulate and analyze ensemble model performance.
|
|
2
|
+
|
|
3
|
+
This class helps evaluate whether a meta model (ensemble) would outperform
|
|
4
|
+
individual child models by analyzing endpoint inference predictions.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import argparse
|
|
8
|
+
from workbench.utils.meta_model_simulator import MetaModelSimulator
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def main():
|
|
12
|
+
parser = argparse.ArgumentParser(
|
|
13
|
+
description="Simulate and analyze ensemble model performance using MetaModelSimulator."
|
|
14
|
+
)
|
|
15
|
+
parser.add_argument(
|
|
16
|
+
"models",
|
|
17
|
+
nargs="+",
|
|
18
|
+
help="List of model endpoint names to include in the ensemble simulation.",
|
|
19
|
+
)
|
|
20
|
+
parser.add_argument(
|
|
21
|
+
"--id-column",
|
|
22
|
+
default="molecule_name",
|
|
23
|
+
help="Name of the ID column (default: molecule_name)",
|
|
24
|
+
)
|
|
25
|
+
args = parser.parse_args()
|
|
26
|
+
models = args.models
|
|
27
|
+
id_column = args.id_column
|
|
28
|
+
|
|
29
|
+
# Create MetaModelSimulator instance and generate report
|
|
30
|
+
sim = MetaModelSimulator(models, id_column=id_column)
|
|
31
|
+
sim.report()
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
if __name__ == "__main__":
|
|
35
|
+
main()
|
|
@@ -1,6 +1,8 @@
|
|
|
1
1
|
import argparse
|
|
2
|
+
import ast
|
|
2
3
|
import logging
|
|
3
4
|
import json
|
|
5
|
+
import re
|
|
4
6
|
from pathlib import Path
|
|
5
7
|
|
|
6
8
|
# Workbench Imports
|
|
@@ -13,27 +15,105 @@ cm = ConfigManager()
|
|
|
13
15
|
workbench_bucket = cm.get_config("WORKBENCH_BUCKET")
|
|
14
16
|
|
|
15
17
|
|
|
16
|
-
def
|
|
18
|
+
def parse_workbench_batch(script_content: str) -> dict | None:
|
|
19
|
+
"""Parse WORKBENCH_BATCH config from a script.
|
|
20
|
+
|
|
21
|
+
Looks for a dictionary assignment like:
|
|
22
|
+
WORKBENCH_BATCH = {
|
|
23
|
+
"outputs": ["feature_set_xyz"],
|
|
24
|
+
}
|
|
25
|
+
or:
|
|
26
|
+
WORKBENCH_BATCH = {
|
|
27
|
+
"inputs": ["feature_set_xyz"],
|
|
28
|
+
}
|
|
29
|
+
|
|
30
|
+
Args:
|
|
31
|
+
script_content: The Python script content as a string
|
|
32
|
+
|
|
33
|
+
Returns:
|
|
34
|
+
The parsed dictionary or None if not found
|
|
35
|
+
"""
|
|
36
|
+
pattern = r"WORKBENCH_BATCH\s*=\s*(\{[^}]+\})"
|
|
37
|
+
match = re.search(pattern, script_content, re.DOTALL)
|
|
38
|
+
if match:
|
|
39
|
+
try:
|
|
40
|
+
return ast.literal_eval(match.group(1))
|
|
41
|
+
except (ValueError, SyntaxError) as e:
|
|
42
|
+
print(f"⚠️ Warning: Failed to parse WORKBENCH_BATCH: {e}")
|
|
43
|
+
return None
|
|
44
|
+
return None
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def get_message_group_id(batch_config: dict | None) -> str:
|
|
48
|
+
"""Derive MessageGroupId from outputs or inputs.
|
|
49
|
+
|
|
50
|
+
- Scripts with outputs use first output as group
|
|
51
|
+
- Scripts with inputs use first input as group
|
|
52
|
+
- Default to "ml-pipeline-jobs" if no config
|
|
53
|
+
"""
|
|
54
|
+
if not batch_config:
|
|
55
|
+
return "ml-pipeline-jobs"
|
|
56
|
+
|
|
57
|
+
outputs = batch_config.get("outputs", [])
|
|
58
|
+
inputs = batch_config.get("inputs", [])
|
|
59
|
+
|
|
60
|
+
if outputs:
|
|
61
|
+
return outputs[0]
|
|
62
|
+
elif inputs:
|
|
63
|
+
return inputs[0]
|
|
64
|
+
else:
|
|
65
|
+
return "ml-pipeline-jobs"
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def submit_to_sqs(
|
|
69
|
+
script_path: str,
|
|
70
|
+
size: str = "small",
|
|
71
|
+
realtime: bool = False,
|
|
72
|
+
dt: bool = False,
|
|
73
|
+
promote: bool = False,
|
|
74
|
+
) -> None:
|
|
17
75
|
"""
|
|
18
76
|
Upload script to S3 and submit message to SQS queue for processing.
|
|
77
|
+
|
|
19
78
|
Args:
|
|
20
79
|
script_path: Local path to the ML pipeline script
|
|
21
80
|
size: Job size tier - "small" (default), "medium", or "large"
|
|
81
|
+
realtime: If True, sets serverless=False for real-time processing (default: False)
|
|
82
|
+
dt: If True, sets DT=True in environment (default: False)
|
|
83
|
+
promote: If True, sets PROMOTE=True in environment (default: False)
|
|
84
|
+
|
|
85
|
+
Raises:
|
|
86
|
+
ValueError: If size is invalid or script file not found
|
|
22
87
|
"""
|
|
23
88
|
print(f"\n{'=' * 60}")
|
|
24
89
|
print("🚀 SUBMITTING ML PIPELINE JOB")
|
|
25
90
|
print(f"{'=' * 60}")
|
|
26
|
-
|
|
27
91
|
if size not in ["small", "medium", "large"]:
|
|
28
92
|
raise ValueError(f"Invalid size '{size}'. Must be 'small', 'medium', or 'large'")
|
|
93
|
+
|
|
29
94
|
# Validate script exists
|
|
30
95
|
script_file = Path(script_path)
|
|
31
96
|
if not script_file.exists():
|
|
32
97
|
raise FileNotFoundError(f"Script not found: {script_path}")
|
|
33
98
|
|
|
99
|
+
# Read script content and parse WORKBENCH_BATCH config
|
|
100
|
+
script_content = script_file.read_text()
|
|
101
|
+
batch_config = parse_workbench_batch(script_content)
|
|
102
|
+
group_id = get_message_group_id(batch_config)
|
|
103
|
+
outputs = (batch_config or {}).get("outputs", [])
|
|
104
|
+
inputs = (batch_config or {}).get("inputs", [])
|
|
105
|
+
|
|
34
106
|
print(f"📄 Script: {script_file.name}")
|
|
35
107
|
print(f"📏 Size tier: {size}")
|
|
108
|
+
print(f"⚡ Mode: {'Real-time' if realtime else 'Serverless'} (serverless={'False' if realtime else 'True'})")
|
|
109
|
+
print(f"🔄 DynamicTraining: {dt}")
|
|
110
|
+
print(f"🆕 Promote: {promote}")
|
|
36
111
|
print(f"🪣 Bucket: {workbench_bucket}")
|
|
112
|
+
if outputs:
|
|
113
|
+
print(f"📤 Outputs: {outputs}")
|
|
114
|
+
if inputs:
|
|
115
|
+
print(f"📥 Inputs: {inputs}")
|
|
116
|
+
print(f"📦 Batch Group: {group_id}")
|
|
37
117
|
sqs = AWSAccountClamp().boto3_session.client("sqs")
|
|
38
118
|
script_name = script_file.name
|
|
39
119
|
|
|
@@ -59,7 +139,7 @@ def submit_to_sqs(script_path: str, size: str = "small") -> None:
|
|
|
59
139
|
print(f" Destination: {s3_path}")
|
|
60
140
|
|
|
61
141
|
try:
|
|
62
|
-
upload_content_to_s3(
|
|
142
|
+
upload_content_to_s3(script_content, s3_path)
|
|
63
143
|
print("✅ Script uploaded successfully")
|
|
64
144
|
except Exception as e:
|
|
65
145
|
print(f"❌ Upload failed: {e}")
|
|
@@ -88,14 +168,21 @@ def submit_to_sqs(script_path: str, size: str = "small") -> None:
|
|
|
88
168
|
|
|
89
169
|
# Prepare message
|
|
90
170
|
message = {"script_path": s3_path, "size": size}
|
|
91
|
-
|
|
171
|
+
|
|
172
|
+
# Set environment variables
|
|
173
|
+
message["environment"] = {
|
|
174
|
+
"SERVERLESS": "False" if realtime else "True",
|
|
175
|
+
"DT": str(dt),
|
|
176
|
+
"PROMOTE": str(promote),
|
|
177
|
+
}
|
|
92
178
|
|
|
93
179
|
# Send the message to SQS
|
|
94
180
|
try:
|
|
181
|
+
print("\n📨 Sending message to SQS...")
|
|
95
182
|
response = sqs.send_message(
|
|
96
183
|
QueueUrl=queue_url,
|
|
97
184
|
MessageBody=json.dumps(message, indent=2),
|
|
98
|
-
MessageGroupId=
|
|
185
|
+
MessageGroupId=group_id, # From WORKBENCH_BATCH or default
|
|
99
186
|
)
|
|
100
187
|
message_id = response["MessageId"]
|
|
101
188
|
print("✅ Message sent successfully!")
|
|
@@ -110,6 +197,14 @@ def submit_to_sqs(script_path: str, size: str = "small") -> None:
|
|
|
110
197
|
print(f"{'=' * 60}")
|
|
111
198
|
print(f"📄 Script: {script_name}")
|
|
112
199
|
print(f"📏 Size: {size}")
|
|
200
|
+
print(f"⚡ Mode: {'Real-time' if realtime else 'Serverless'} (SERVERLESS={'False' if realtime else 'True'})")
|
|
201
|
+
print(f"🔄 DynamicTraining: {dt}")
|
|
202
|
+
print(f"🆕 Promote: {promote}")
|
|
203
|
+
if outputs:
|
|
204
|
+
print(f"📤 Outputs: {outputs}")
|
|
205
|
+
if inputs:
|
|
206
|
+
print(f"📥 Inputs: {inputs}")
|
|
207
|
+
print(f"📦 Batch Group: {group_id}")
|
|
113
208
|
print(f"🆔 Message ID: {message_id}")
|
|
114
209
|
print("\n🔍 MONITORING LOCATIONS:")
|
|
115
210
|
print(f" • SQS Queue: AWS Console → SQS → {queue_name}")
|
|
@@ -126,9 +221,30 @@ def main():
|
|
|
126
221
|
parser.add_argument(
|
|
127
222
|
"--size", default="small", choices=["small", "medium", "large"], help="Job size tier (default: small)"
|
|
128
223
|
)
|
|
224
|
+
parser.add_argument(
|
|
225
|
+
"--realtime",
|
|
226
|
+
action="store_true",
|
|
227
|
+
help="Create realtime endpoints (default is serverless)",
|
|
228
|
+
)
|
|
229
|
+
parser.add_argument(
|
|
230
|
+
"--dt",
|
|
231
|
+
action="store_true",
|
|
232
|
+
help="Set DT=True (models and endpoints will have '-dt' suffix)",
|
|
233
|
+
)
|
|
234
|
+
parser.add_argument(
|
|
235
|
+
"--promote",
|
|
236
|
+
action="store_true",
|
|
237
|
+
help="Set Promote=True (models and endpoints will use promoted naming",
|
|
238
|
+
)
|
|
129
239
|
args = parser.parse_args()
|
|
130
240
|
try:
|
|
131
|
-
submit_to_sqs(
|
|
241
|
+
submit_to_sqs(
|
|
242
|
+
args.script_file,
|
|
243
|
+
args.size,
|
|
244
|
+
realtime=args.realtime,
|
|
245
|
+
dt=args.dt,
|
|
246
|
+
promote=args.promote,
|
|
247
|
+
)
|
|
132
248
|
except Exception as e:
|
|
133
249
|
print(f"\n❌ ERROR: {e}")
|
|
134
250
|
log.error(f"Error: {e}")
|
|
@@ -0,0 +1,85 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Local test harness for SageMaker training scripts.
|
|
3
|
+
|
|
4
|
+
Usage:
|
|
5
|
+
python training_test.py <model_script.py> <featureset_name>
|
|
6
|
+
|
|
7
|
+
Example:
|
|
8
|
+
python training_test.py ../model_scripts/pytorch_model/generated_model_script.py caco2-class-features
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
import os
|
|
12
|
+
import shutil
|
|
13
|
+
import subprocess
|
|
14
|
+
import sys
|
|
15
|
+
import tempfile
|
|
16
|
+
|
|
17
|
+
import pandas as pd
|
|
18
|
+
|
|
19
|
+
from workbench.api import FeatureSet
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def get_training_data(featureset_name: str) -> pd.DataFrame:
|
|
23
|
+
"""Get training data from the FeatureSet."""
|
|
24
|
+
fs = FeatureSet(featureset_name)
|
|
25
|
+
return fs.pull_dataframe()
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def main():
    """CLI entry point: run a SageMaker training script locally.

    Pulls the FeatureSet data to a CSV inside a temp "train channel"
    directory, sets the SM_* environment variables a real SageMaker
    training job would provide, runs the script as a subprocess, and
    always cleans up the temp directories afterwards.

    Exits nonzero when arguments are missing, the script path does not
    exist, or the training subprocess itself fails.
    """
    if len(sys.argv) < 3:
        print("Usage: python training_test.py <model_script.py> <featureset_name>")
        sys.exit(1)

    script_path = sys.argv[1]
    featureset_name = sys.argv[2]

    if not os.path.exists(script_path):
        print(f"Error: Script not found: {script_path}")
        sys.exit(1)

    # Create temp directories that stand in for SageMaker's job directories
    model_dir = tempfile.mkdtemp(prefix="training_model_")
    train_dir = tempfile.mkdtemp(prefix="training_data_")
    output_dir = tempfile.mkdtemp(prefix="training_output_")

    print(f"Model dir: {model_dir}")
    print(f"Train dir: {train_dir}")

    try:
        # Get training data and save to CSV (contents of the "train" channel)
        print(f"Loading FeatureSet: {featureset_name}")
        df = get_training_data(featureset_name)
        print(f"Data shape: {df.shape}")

        train_file = os.path.join(train_dir, "training_data.csv")
        df.to_csv(train_file, index=False)

        # Mimic the environment variables SageMaker sets for training jobs
        env = os.environ.copy()
        env["SM_MODEL_DIR"] = model_dir
        env["SM_CHANNEL_TRAIN"] = train_dir
        env["SM_OUTPUT_DATA_DIR"] = output_dir

        print("\n" + "=" * 60)
        print("Starting training...")
        print("=" * 60 + "\n")

        # Run the training script the way the SageMaker container would
        cmd = [sys.executable, script_path, "--model-dir", model_dir, "--train", train_dir]
        result = subprocess.run(cmd, env=env)

        print("\n" + "=" * 60)
        if result.returncode == 0:
            print("Training completed successfully!")
        else:
            print(f"Training failed with return code: {result.returncode}")
        print("=" * 60)

    finally:
        # Always clean up, even if loading or training raised
        shutil.rmtree(model_dir, ignore_errors=True)
        shutil.rmtree(train_dir, ignore_errors=True)
        shutil.rmtree(output_dir, ignore_errors=True)

    # BUGFIX: propagate the training script's failure to our own exit code
    # so CI / shell callers can detect a failed run (previously always 0).
    if result.returncode != 0:
        sys.exit(result.returncode)
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
# Run the harness only when this file is executed directly, not on import.
if __name__ == "__main__":
    main()
|
workbench/themes/dark/custom.css
CHANGED
|
@@ -110,6 +110,40 @@ a:hover {
|
|
|
110
110
|
color: rgb(100, 255, 100);
|
|
111
111
|
}
|
|
112
112
|
|
|
113
|
+
/* Dropdown styling (dcc.Dropdown) - override Bootstrap's variables */
|
|
114
|
+
.dash-dropdown {
|
|
115
|
+
--bs-body-bg: rgb(35, 35, 35);
|
|
116
|
+
--bs-body-color: rgb(210, 210, 210);
|
|
117
|
+
--bs-border-color: rgb(60, 60, 60);
|
|
118
|
+
}
|
|
119
|
+
|
|
120
|
+
/* Bootstrap form controls (dbc components) */
|
|
121
|
+
.form-select, .form-control {
|
|
122
|
+
background-color: rgb(35, 35, 35) !important;
|
|
123
|
+
border: 1px solid rgb(60, 60, 60) !important;
|
|
124
|
+
color: rgb(210, 210, 210) !important;
|
|
125
|
+
}
|
|
126
|
+
|
|
127
|
+
.form-select:focus, .form-control:focus {
|
|
128
|
+
background-color: rgb(45, 45, 45) !important;
|
|
129
|
+
border-color: rgb(80, 80, 80) !important;
|
|
130
|
+
box-shadow: 0 0 0 0.2rem rgba(80, 80, 80, 0.25) !important;
|
|
131
|
+
}
|
|
132
|
+
|
|
133
|
+
.dropdown-menu {
|
|
134
|
+
background-color: rgb(35, 35, 35) !important;
|
|
135
|
+
border: 1px solid rgb(60, 60, 60) !important;
|
|
136
|
+
}
|
|
137
|
+
|
|
138
|
+
.dropdown-item {
|
|
139
|
+
color: rgb(210, 210, 210) !important;
|
|
140
|
+
}
|
|
141
|
+
|
|
142
|
+
.dropdown-item:hover, .dropdown-item:focus {
|
|
143
|
+
background-color: rgb(50, 50, 50) !important;
|
|
144
|
+
color: rgb(230, 230, 230) !important;
|
|
145
|
+
}
|
|
146
|
+
|
|
113
147
|
/* Table styling */
|
|
114
148
|
table {
|
|
115
149
|
width: 100%;
|
|
@@ -128,4 +162,29 @@ td {
|
|
|
128
162
|
padding: 5px;
|
|
129
163
|
border: 0.5px solid #444;
|
|
130
164
|
text-align: center !important;
|
|
165
|
+
}
|
|
166
|
+
|
|
167
|
+
/* AG Grid table header colors - gradient theme */
|
|
168
|
+
/* Data Sources tables - red gradient */
|
|
169
|
+
#main_data_sources .ag-header,
|
|
170
|
+
#data_sources_table .ag-header {
|
|
171
|
+
background: linear-gradient(180deg, rgb(140, 60, 60) 0%, rgb(80, 35, 35) 100%) !important;
|
|
172
|
+
}
|
|
173
|
+
|
|
174
|
+
/* Feature Sets tables - yellow/olive gradient */
|
|
175
|
+
#main_feature_sets .ag-header,
|
|
176
|
+
#feature_sets_table .ag-header {
|
|
177
|
+
background: linear-gradient(180deg, rgb(120, 115, 55) 0%, rgb(70, 65, 30) 100%) !important;
|
|
178
|
+
}
|
|
179
|
+
|
|
180
|
+
/* Models tables - green gradient */
|
|
181
|
+
#main_models .ag-header,
|
|
182
|
+
#models_table .ag-header {
|
|
183
|
+
background: linear-gradient(180deg, rgb(55, 110, 55) 0%, rgb(30, 60, 30) 100%) !important;
|
|
184
|
+
}
|
|
185
|
+
|
|
186
|
+
/* Endpoints tables - purple gradient */
|
|
187
|
+
#main_endpoints .ag-header,
|
|
188
|
+
#endpoints_table .ag-header {
|
|
189
|
+
background: linear-gradient(180deg, rgb(100, 60, 120) 0%, rgb(55, 30, 70) 100%) !important;
|
|
131
190
|
}
|