workbench 0.8.162__py3-none-any.whl → 0.8.202__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- workbench/algorithms/dataframe/__init__.py +1 -2
- workbench/algorithms/dataframe/fingerprint_proximity.py +2 -2
- workbench/algorithms/dataframe/proximity.py +261 -235
- workbench/algorithms/graph/light/proximity_graph.py +10 -8
- workbench/api/__init__.py +2 -1
- workbench/api/compound.py +1 -1
- workbench/api/endpoint.py +11 -0
- workbench/api/feature_set.py +11 -8
- workbench/api/meta.py +5 -2
- workbench/api/model.py +16 -15
- workbench/api/monitor.py +1 -16
- workbench/core/artifacts/__init__.py +11 -2
- workbench/core/artifacts/artifact.py +11 -3
- workbench/core/artifacts/data_capture_core.py +355 -0
- workbench/core/artifacts/endpoint_core.py +256 -118
- workbench/core/artifacts/feature_set_core.py +265 -16
- workbench/core/artifacts/model_core.py +107 -60
- workbench/core/artifacts/monitor_core.py +33 -248
- workbench/core/cloud_platform/aws/aws_account_clamp.py +50 -1
- workbench/core/cloud_platform/aws/aws_meta.py +12 -5
- workbench/core/cloud_platform/aws/aws_parameter_store.py +18 -2
- workbench/core/cloud_platform/aws/aws_session.py +4 -4
- workbench/core/transforms/data_to_features/light/molecular_descriptors.py +4 -4
- workbench/core/transforms/features_to_model/features_to_model.py +42 -32
- workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +36 -6
- workbench/core/transforms/pandas_transforms/pandas_to_features.py +27 -0
- workbench/core/views/training_view.py +113 -42
- workbench/core/views/view.py +53 -3
- workbench/core/views/view_utils.py +4 -4
- workbench/model_scripts/chemprop/chemprop.template +852 -0
- workbench/model_scripts/chemprop/generated_model_script.py +852 -0
- workbench/model_scripts/chemprop/requirements.txt +11 -0
- workbench/model_scripts/custom_models/chem_info/fingerprints.py +134 -0
- workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +483 -0
- workbench/model_scripts/custom_models/chem_info/mol_standardize.py +450 -0
- workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +7 -9
- workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -1
- workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +3 -5
- workbench/model_scripts/custom_models/proximity/proximity.py +261 -235
- workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
- workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +20 -21
- workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
- workbench/model_scripts/custom_models/uq_models/meta_uq.template +166 -62
- workbench/model_scripts/custom_models/uq_models/ngboost.template +30 -18
- workbench/model_scripts/custom_models/uq_models/proximity.py +261 -235
- workbench/model_scripts/custom_models/uq_models/requirements.txt +1 -3
- workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +15 -17
- workbench/model_scripts/pytorch_model/generated_model_script.py +373 -190
- workbench/model_scripts/pytorch_model/pytorch.template +370 -187
- workbench/model_scripts/scikit_learn/generated_model_script.py +7 -12
- workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
- workbench/model_scripts/script_generation.py +17 -9
- workbench/model_scripts/uq_models/generated_model_script.py +605 -0
- workbench/model_scripts/uq_models/mapie.template +605 -0
- workbench/model_scripts/uq_models/requirements.txt +1 -0
- workbench/model_scripts/xgb_model/generated_model_script.py +37 -46
- workbench/model_scripts/xgb_model/xgb_model.template +44 -46
- workbench/repl/workbench_shell.py +28 -14
- workbench/scripts/endpoint_test.py +162 -0
- workbench/scripts/lambda_test.py +73 -0
- workbench/scripts/ml_pipeline_batch.py +137 -0
- workbench/scripts/ml_pipeline_sqs.py +186 -0
- workbench/scripts/monitor_cloud_watch.py +20 -100
- workbench/utils/aws_utils.py +4 -3
- workbench/utils/chem_utils/__init__.py +0 -0
- workbench/utils/chem_utils/fingerprints.py +134 -0
- workbench/utils/chem_utils/misc.py +194 -0
- workbench/utils/chem_utils/mol_descriptors.py +483 -0
- workbench/utils/chem_utils/mol_standardize.py +450 -0
- workbench/utils/chem_utils/mol_tagging.py +348 -0
- workbench/utils/chem_utils/projections.py +209 -0
- workbench/utils/chem_utils/salts.py +256 -0
- workbench/utils/chem_utils/sdf.py +292 -0
- workbench/utils/chem_utils/toxicity.py +250 -0
- workbench/utils/chem_utils/vis.py +253 -0
- workbench/utils/chemprop_utils.py +760 -0
- workbench/utils/cloudwatch_handler.py +1 -1
- workbench/utils/cloudwatch_utils.py +137 -0
- workbench/utils/config_manager.py +3 -7
- workbench/utils/endpoint_utils.py +5 -7
- workbench/utils/license_manager.py +2 -6
- workbench/utils/model_utils.py +95 -34
- workbench/utils/monitor_utils.py +44 -62
- workbench/utils/pandas_utils.py +3 -3
- workbench/utils/pytorch_utils.py +526 -0
- workbench/utils/shap_utils.py +10 -2
- workbench/utils/workbench_logging.py +0 -3
- workbench/utils/workbench_sqs.py +1 -1
- workbench/utils/xgboost_model_utils.py +371 -156
- workbench/web_interface/components/model_plot.py +7 -1
- workbench/web_interface/components/plugin_unit_test.py +5 -2
- workbench/web_interface/components/plugins/dashboard_status.py +3 -1
- workbench/web_interface/components/plugins/generated_compounds.py +1 -1
- workbench/web_interface/components/plugins/model_details.py +9 -7
- workbench/web_interface/components/plugins/scatter_plot.py +3 -3
- {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/METADATA +27 -6
- {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/RECORD +101 -85
- {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/entry_points.txt +4 -0
- {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/licenses/LICENSE +1 -1
- workbench/model_scripts/custom_models/chem_info/local_utils.py +0 -769
- workbench/model_scripts/custom_models/chem_info/tautomerize.py +0 -83
- workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
- workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -393
- workbench/model_scripts/custom_models/uq_models/mapie_xgb.template +0 -203
- workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
- workbench/model_scripts/quant_regression/quant_regression.template +0 -279
- workbench/model_scripts/quant_regression/requirements.txt +0 -1
- workbench/utils/chem_utils.py +0 -1556
- workbench/utils/execution_environment.py +0 -211
- workbench/utils/fast_inference.py +0 -167
- workbench/utils/resource_utils.py +0 -39
- {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/WHEEL +0 -0
- {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/top_level.txt +0 -0

workbench/scripts/endpoint_test.py
@@ -0,0 +1,162 @@
"""
Local test harness for SageMaker model scripts.

Usage:
    python model_script_harness.py <local_script.py> <model_name>

Example:
    python model_script_harness.py pytorch.py aqsol-pytorch-reg

This allows you to test LOCAL changes to a model script against deployed model artifacts.
Evaluation data is automatically pulled from the FeatureSet (training = FALSE rows).

Optional: testing/env.json with additional environment variables
"""

import os
import sys
import json
import importlib.util
import tempfile
import shutil
import pandas as pd
import torch

# Workbench Imports
from workbench.api import Model, FeatureSet
from workbench.utils.pytorch_utils import download_and_extract_model

# Force CPU mode BEFORE any PyTorch imports to avoid MPS/CUDA issues on Mac
os.environ["CUDA_VISIBLE_DEVICES"] = ""
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
torch.set_default_device("cpu")
# Disable MPS entirely
if hasattr(torch.backends, "mps"):
    torch.backends.mps.is_available = lambda: False


def get_eval_data(workbench_model: Model) -> pd.DataFrame:
    """Get evaluation data from the FeatureSet associated with this model."""
    # Get the FeatureSet
    fs_name = workbench_model.get_input()
    fs = FeatureSet(fs_name)
    if not fs.exists():
        raise ValueError(f"No FeatureSet found: {fs_name}")

    # Get evaluation data (training = FALSE)
    table = workbench_model.training_view().table
    print(f"Querying evaluation data from {table}...")
    eval_df = fs.query(f'SELECT * FROM "{table}" WHERE training = FALSE')
    print(f"Retrieved {len(eval_df)} evaluation rows")

    return eval_df


def load_model_script(script_path: str):
    """Dynamically load the model script module."""
    if not os.path.exists(script_path):
        raise FileNotFoundError(f"Script not found: {script_path}")

    spec = importlib.util.spec_from_file_location("model_script", script_path)
    module = importlib.util.module_from_spec(spec)

    # Add to sys.modules so imports within the script work
    sys.modules["model_script"] = module

    spec.loader.exec_module(module)
    return module


def main():
    if len(sys.argv) < 3:
        print("Usage: python model_script_harness.py <local_script.py> <model_name>")
        print("\nArguments:")
        print("  local_script.py - Path to your LOCAL model script to test")
        print("  model_name      - Workbench model name (e.g., aqsol-pytorch-reg)")
        print("\nOptional: testing/env.json with additional environment variables")
        sys.exit(1)

    script_path = sys.argv[1]
    model_name = sys.argv[2]

    # Validate local script exists
    if not os.path.exists(script_path):
        print(f"Error: Local script not found: {script_path}")
        sys.exit(1)

    # Initialize Workbench model
    print(f"Loading Workbench model: {model_name}")
    workbench_model = Model(model_name)
    print(f"Model Framework: {workbench_model.model_framework}")
    print()

    # Create a temporary model directory
    model_dir = tempfile.mkdtemp(prefix="model_harness_")
    print(f"Using model directory: {model_dir}")

    try:
        # Load environment variables from env.json if it exists
        if os.path.exists("testing/env.json"):
            print("Loading environment variables from testing/env.json")
            with open("testing/env.json") as f:
                env_vars = json.load(f)
                for key, value in env_vars.items():
                    os.environ[key] = value
                    print(f"  Set {key} = {value}")
            print()

        # Set up SageMaker environment variables
        os.environ["SM_MODEL_DIR"] = model_dir
        print(f"Set SM_MODEL_DIR = {model_dir}")

        # Download and extract model artifacts
        s3_uri = workbench_model.model_data_url()
        download_and_extract_model(s3_uri, model_dir)
        print()

        # Load the LOCAL model script
        print(f"Loading LOCAL model script: {script_path}")
        module = load_model_script(script_path)
        print()

        # Check for required functions
        if not hasattr(module, "model_fn"):
            raise AttributeError("Model script must have a model_fn function")
        if not hasattr(module, "predict_fn"):
            raise AttributeError("Model script must have a predict_fn function")

        # Load the model
        print("Calling model_fn...")
        print("-" * 50)
        model = module.model_fn(model_dir)
        print("-" * 50)
        print(f"Model loaded: {type(model)}")
        print()

        # Get evaluation data from FeatureSet
        print("Pulling evaluation data from FeatureSet...")
        df = get_eval_data(workbench_model)
        print(f"Input shape: {df.shape}")
        print(f"Columns: {df.columns.tolist()}")
        print()

        print("Calling predict_fn...")
        print("-" * 50)
        result = module.predict_fn(df, model)
        print("-" * 50)
        print()

        print("Prediction result:")
        print(f"Output shape: {result.shape}")
        print(f"Output columns: {result.columns.tolist()}")
        print()
        print(result.head(10).to_string())

    finally:
        # Cleanup
        print(f"\nCleaning up model directory: {model_dir}")
        shutil.rmtree(model_dir, ignore_errors=True)


if __name__ == "__main__":
    main()
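
Note: the harness only requires that the local script expose model_fn(model_dir) and predict_fn(df, model), the two hooks it checks for before running. A minimal sketch of a script satisfying that contract is shown below; the joblib artifact name, feature selection, and prediction column are illustrative assumptions, not behavior taken from the packaged model scripts.

# minimal_model_script.py -- hypothetical script for the harness above.
# The artifact name ("model.joblib") and feature handling are assumptions.
import os
import joblib
import pandas as pd


def model_fn(model_dir: str):
    """Load the artifact the harness downloaded into model_dir."""
    return joblib.load(os.path.join(model_dir, "model.joblib"))


def predict_fn(df: pd.DataFrame, model) -> pd.DataFrame:
    """Add a prediction column and hand the DataFrame back to the harness."""
    features = df.select_dtypes("number")  # assumed feature selection
    df["prediction"] = model.predict(features)
    return df

An optional testing/env.json is just a flat JSON object mapping environment variable names to string values; whatever variables the script expects can be set there before model_fn is called.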
workbench/scripts/lambda_test.py
@@ -0,0 +1,73 @@
"""
Local test harness for AWS Lambda scripts.

Usage:
    lambda_test <lambda_script.py>

Required: testing/event.json with the event definition
Options: testing/env.json with a set of ENV vars
"""

import sys
import os
import json
import importlib.util


def main():
    if len(sys.argv) != 2:
        print("Usage: lambda_launcher <handler_module_name>")
        print("\nOptional: testing/event.json with test event")
        print("Optional: testing/env.json with environment variables")
        sys.exit(1)

    handler_file = sys.argv[1]

    # Add .py if not present
    if not handler_file.endswith(".py"):
        handler_file += ".py"

    # Check if file exists
    if not os.path.exists(handler_file):
        print(f"Error: File '{handler_file}' not found")
        sys.exit(1)

    # Load environment variables from env.json if it exists
    if os.path.exists("testing/env.json"):
        print("Loading environment variables from testing/env.json")
        with open("testing/env.json") as f:
            env_vars = json.load(f)
            for key, value in env_vars.items():
                os.environ[key] = value
                print(f"  Set {key} = {value}")
        print()

    # Load event configuration
    if os.path.exists("testing/event.json"):
        print("Loading event from testing/event.json")
        with open("testing/event.json") as f:
            event = json.load(f)
    else:
        print("No testing/event.json found, using empty event")
        event = {}

    # Load the module dynamically
    spec = importlib.util.spec_from_file_location("lambda_module", handler_file)
    lambda_module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(lambda_module)

    # Call the lambda_handler
    print(f"Invoking lambda_handler from {handler_file}...")
    print("-" * 50)
    print(f"Event: {json.dumps(event, indent=2)}")
    print("-" * 50)

    result = lambda_module.lambda_handler(event, {})

    print("-" * 50)
    print("Result:")
    print(json.dumps(result, indent=2))


if __name__ == "__main__":
    main()
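
Note: since the harness simply loads the given file and calls lambda_handler(event, {}), any handler module with that entry point can be exercised locally. A hypothetical handler and a matching testing/event.json are sketched below; the handler file name and event keys are made up for illustration.

# my_handler.py -- hypothetical handler to exercise with `lambda_test my_handler.py`
import json


def lambda_handler(event, context):
    # A real handler would call an endpoint or kick off a pipeline;
    # this sketch just echoes the event back.
    model_name = event.get("model", "unknown")
    rows = event.get("rows", [])
    return {
        "statusCode": 200,
        "body": json.dumps({"model": model_name, "row_count": len(rows)}),
    }

# testing/event.json (illustrative):
#   {"model": "aqsol-pytorch-reg", "rows": []}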
workbench/scripts/ml_pipeline_batch.py
@@ -0,0 +1,137 @@
import argparse
import logging
import time
from datetime import datetime
from pathlib import Path

# Workbench Imports
from workbench.core.cloud_platform.aws.aws_account_clamp import AWSAccountClamp
from workbench.utils.config_manager import ConfigManager
from workbench.utils.s3_utils import upload_content_to_s3
from workbench.utils.cloudwatch_utils import get_cloudwatch_logs_url

log = logging.getLogger("workbench")
cm = ConfigManager()
workbench_bucket = cm.get_config("WORKBENCH_BUCKET")


def get_ecr_image_uri() -> str:
    """Get the ECR image URI for the current region."""
    region = AWSAccountClamp().region
    return f"507740646243.dkr.ecr.{region}.amazonaws.com/aws-ml-images/py312-ml-pipelines:0.1"


def get_batch_role_arn() -> str:
    """Get the Batch execution role ARN."""
    account_id = AWSAccountClamp().account_id
    return f"arn:aws:iam::{account_id}:role/Workbench-BatchRole"


def _log_cloudwatch_link(job: dict, message_prefix: str = "View logs") -> None:
    """
    Helper method to log CloudWatch logs link with clickable URL and full URL display.

    Args:
        job: Batch job description dictionary
        message_prefix: Prefix for the log message (default: "View logs")
    """
    log_stream = job.get("container", {}).get("logStreamName")
    logs_url = get_cloudwatch_logs_url(log_group="/aws/batch/job", log_stream=log_stream)
    if logs_url:
        clickable_url = f"\033]8;;{logs_url}\033\\{logs_url}\033]8;;\033\\"
        log.info(f"{message_prefix}: {clickable_url}")
    else:
        log.info("Check AWS Batch console for logs")


def run_batch_job(script_path: str, size: str = "small") -> int:
    """
    Submit and monitor an AWS Batch job for ML pipeline execution.

    Uploads script to S3, submits Batch job, monitors until completion or 2 minutes of RUNNING.

    Args:
        script_path: Local path to the ML pipeline script
        size: Job size tier - "small" (default), "medium", or "large"
              - small: 2 vCPU, 4GB RAM for lightweight processing
              - medium: 4 vCPU, 8GB RAM for standard ML workloads
              - large: 8 vCPU, 16GB RAM for heavy training/inference

    Returns:
        Exit code (0 for success/disconnected, non-zero for failure)
    """
    if size not in ["small", "medium", "large"]:
        raise ValueError(f"Invalid size '{size}'. Must be 'small', 'medium', or 'large'")

    batch = AWSAccountClamp().boto3_session.client("batch")
    script_name = Path(script_path).stem

    # Upload script to S3
    s3_path = f"s3://{workbench_bucket}/batch-jobs/{Path(script_path).name}"
    log.info(f"Uploading script to {s3_path}")
    upload_content_to_s3(Path(script_path).read_text(), s3_path)

    # Submit job
    job_name = f"workbench_{script_name}_{datetime.now():%Y%m%d_%H%M%S}"
    response = batch.submit_job(
        jobName=job_name,
        jobQueue="workbench-job-queue",
        jobDefinition=f"workbench-batch-{size}",
        containerOverrides={
            "environment": [
                {"name": "ML_PIPELINE_S3_PATH", "value": s3_path},
                {"name": "WORKBENCH_BUCKET", "value": workbench_bucket},
            ]
        },
    )
    job_id = response["jobId"]
    log.info(f"Submitted job: {job_name} ({job_id}) using {size} tier")

    # Monitor job
    last_status, running_start = None, None
    while True:
        job = batch.describe_jobs(jobs=[job_id])["jobs"][0]
        status = job["status"]

        if status != last_status:
            log.info(f"Job status: {status}")
            last_status = status
            if status == "RUNNING":
                running_start = time.time()

        # Disconnect after 2 minutes of running
        if status == "RUNNING" and running_start and (time.time() - running_start >= 120):
            log.info("ML Pipeline is running successfully!")
            _log_cloudwatch_link(job, "Monitor logs")
            return 0

        # Handle completion
        if status in ["SUCCEEDED", "FAILED"]:
            exit_code = job.get("attempts", [{}])[-1].get("exitCode", 1)
            msg = (
                "Job completed successfully"
                if status == "SUCCEEDED"
                else f"Job failed: {job.get('statusReason', 'Unknown')}"
            )
            log.info(msg) if status == "SUCCEEDED" else log.error(msg)
            _log_cloudwatch_link(job)
            return exit_code

        time.sleep(10)


def main():
    """CLI entry point for running ML pipelines on AWS Batch."""
    parser = argparse.ArgumentParser(description="Run ML pipeline script on AWS Batch")
    parser.add_argument("script_file", help="Local path to ML pipeline script")
    args = parser.parse_args()
    try:
        exit_code = run_batch_job(args.script_file)
        exit(exit_code)
    except Exception as e:
        log.error(f"Error: {e}")
        exit(1)


if __name__ == "__main__":
    main()
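
Note: run_batch_job() passes only two environment variables into the container, ML_PIPELINE_S3_PATH and WORKBENCH_BUCKET, so the container image is presumably responsible for fetching and executing the uploaded script. A hypothetical pipeline script is sketched below; it reuses only API calls that appear elsewhere in this diff (Model.get_input, training_view, FeatureSet.query), and the model name is an assumption.

# example_pipeline.py -- hypothetical ML pipeline script to submit via run_batch_job()
import os
from workbench.api import FeatureSet, Model


def main():
    # WORKBENCH_BUCKET is injected by the Batch submission above
    print(f"Pipeline running against bucket: {os.environ.get('WORKBENCH_BUCKET')}")
    model = Model("aqsol-pytorch-reg")  # assumed model name
    fs = FeatureSet(model.get_input())
    table = model.training_view().table
    holdout = fs.query(f'SELECT * FROM "{table}" WHERE training = FALSE')
    print(f"Pulled {len(holdout)} holdout rows for evaluation")


if __name__ == "__main__":
    main()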
workbench/scripts/ml_pipeline_sqs.py
@@ -0,0 +1,186 @@
import argparse
import logging
import json
from pathlib import Path

# Workbench Imports
from workbench.core.cloud_platform.aws.aws_account_clamp import AWSAccountClamp
from workbench.utils.config_manager import ConfigManager
from workbench.utils.s3_utils import upload_content_to_s3

log = logging.getLogger("workbench")
cm = ConfigManager()
workbench_bucket = cm.get_config("WORKBENCH_BUCKET")


def submit_to_sqs(
    script_path: str,
    size: str = "small",
    realtime: bool = False,
    dt: bool = False,
    promote: bool = False,
) -> None:
    """
    Upload script to S3 and submit message to SQS queue for processing.

    Args:
        script_path: Local path to the ML pipeline script
        size: Job size tier - "small" (default), "medium", or "large"
        realtime: If True, sets serverless=False for real-time processing (default: False)
        dt: If True, sets DT=True in environment (default: False)
        promote: If True, sets PROMOTE=True in environment (default: False)

    Raises:
        ValueError: If size is invalid or script file not found
    """
    print(f"\n{'=' * 60}")
    print("SUBMITTING ML PIPELINE JOB")
    print(f"{'=' * 60}")
    if size not in ["small", "medium", "large"]:
        raise ValueError(f"Invalid size '{size}'. Must be 'small', 'medium', or 'large'")

    # Validate script exists
    script_file = Path(script_path)
    if not script_file.exists():
        raise FileNotFoundError(f"Script not found: {script_path}")

    print(f"Script: {script_file.name}")
    print(f"Size tier: {size}")
    print(f"Mode: {'Real-time' if realtime else 'Serverless'} (serverless={'False' if realtime else 'True'})")
    print(f"DynamicTraining: {dt}")
    print(f"Promote: {promote}")
    print(f"Bucket: {workbench_bucket}")
    sqs = AWSAccountClamp().boto3_session.client("sqs")
    script_name = script_file.name

    # List Workbench queues
    print("\nListing Workbench SQS queues...")
    try:
        queues = sqs.list_queues(QueueNamePrefix="workbench-")
        queue_urls = queues.get("QueueUrls", [])
        if queue_urls:
            print(f"Found {len(queue_urls)} workbench queue(s):")
            for url in queue_urls:
                queue_name = url.split("/")[-1]
                print(f"  • {queue_name}")
        else:
            print("No workbench queues found")
    except Exception as e:
        print(f"Error listing queues: {e}")

    # Upload script to S3
    s3_path = f"s3://{workbench_bucket}/batch-jobs/{script_name}"
    print("\nUploading script to S3...")
    print(f"  Source: {script_path}")
    print(f"  Destination: {s3_path}")

    try:
        upload_content_to_s3(script_file.read_text(), s3_path)
        print("Script uploaded successfully")
    except Exception as e:
        print(f"Upload failed: {e}")
        raise
    # Get queue URL and info
    queue_name = "workbench-ml-pipeline-queue.fifo"
    print("\nGetting queue information...")
    print(f"  Queue name: {queue_name}")

    try:
        queue_url = sqs.get_queue_url(QueueName=queue_name)["QueueUrl"]
        print(f"  Queue URL: {queue_url}")

        # Get queue attributes for additional info
        attrs = sqs.get_queue_attributes(
            QueueUrl=queue_url, AttributeNames=["ApproximateNumberOfMessages", "ApproximateNumberOfMessagesNotVisible"]
        )
        messages_available = attrs["Attributes"].get("ApproximateNumberOfMessages", "0")
        messages_in_flight = attrs["Attributes"].get("ApproximateNumberOfMessagesNotVisible", "0")
        print(f"  Messages in queue: {messages_available}")
        print(f"  Messages in flight: {messages_in_flight}")

    except Exception as e:
        print(f"Error accessing queue: {e}")
        raise

    # Prepare message
    message = {"script_path": s3_path, "size": size}

    # Set environment variables
    message["environment"] = {
        "SERVERLESS": "False" if realtime else "True",
        "DT": str(dt),
        "PROMOTE": str(promote),
    }

    # Send the message to SQS
    try:
        print("\nSending message to SQS...")
        response = sqs.send_message(
            QueueUrl=queue_url,
            MessageBody=json.dumps(message, indent=2),
            MessageGroupId="ml-pipeline-jobs",  # Required for FIFO
        )
        message_id = response["MessageId"]
        print("Message sent successfully!")
        print(f"  Message ID: {message_id}")
    except Exception as e:
        print(f"Failed to send message: {e}")
        raise

    # Success summary
    print(f"\n{'=' * 60}")
    print("JOB SUBMISSION COMPLETE")
    print(f"{'=' * 60}")
    print(f"Script: {script_name}")
    print(f"Size: {size}")
    print(f"Mode: {'Real-time' if realtime else 'Serverless'} (SERVERLESS={'False' if realtime else 'True'})")
    print(f"DynamicTraining: {dt}")
    print(f"Promote: {promote}")
    print(f"Message ID: {message_id}")
    print("\nMONITORING LOCATIONS:")
    print(f"  • SQS Queue: AWS Console → SQS → {queue_name}")
    print("  • Lambda Logs: AWS Console → Lambda → Functions")
    print("  • Batch Jobs: AWS Console → Batch → Jobs")
    print("  • CloudWatch: AWS Console → CloudWatch → Log groups")
    print("\nYour job should start processing soon...")


def main():
    """CLI entry point for submitting ML pipelines via SQS."""
    parser = argparse.ArgumentParser(description="Submit ML pipeline to SQS queue for Batch processing")
    parser.add_argument("script_file", help="Local path to ML pipeline script")
    parser.add_argument(
        "--size", default="small", choices=["small", "medium", "large"], help="Job size tier (default: small)"
    )
    parser.add_argument(
        "--realtime",
        action="store_true",
        help="Create realtime endpoints (default is serverless)",
    )
    parser.add_argument(
        "--dt",
        action="store_true",
        help="Set DT=True (models and endpoints will have '-dt' suffix)",
    )
    parser.add_argument(
        "--promote",
        action="store_true",
        help="Set Promote=True (models and endpoints will use promoted naming)",
    )
    args = parser.parse_args()
    try:
        submit_to_sqs(
            args.script_file,
            args.size,
            realtime=args.realtime,
            dt=args.dt,
            promote=args.promote,
        )
    except Exception as e:
        print(f"\nERROR: {e}")
        log.error(f"Error: {e}")
        exit(1)


if __name__ == "__main__":
    main()
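
Note: the SQS message produced here carries the S3 path of the uploaded script, the size tier, and an environment map (SERVERLESS, DT, PROMOTE). The consumer that turns these messages into Batch jobs is not part of this diff; the sketch below shows one plausible shape for such a consumer, reusing the job queue and job definition naming from ml_pipeline_batch.py, with everything else (the Lambda itself, its job name) assumed.

# Hypothetical SQS-triggered Lambda that forwards queue messages to AWS Batch.
import json
import boto3


def lambda_handler(event, context):
    batch = boto3.client("batch")
    for record in event["Records"]:  # standard SQS event shape
        msg = json.loads(record["body"])
        env = [{"name": k, "value": str(v)} for k, v in msg.get("environment", {}).items()]
        env.append({"name": "ML_PIPELINE_S3_PATH", "value": msg["script_path"]})
        batch.submit_job(
            jobName="workbench-ml-pipeline",  # placeholder job name
            jobQueue="workbench-job-queue",  # queue name from ml_pipeline_batch.py
            jobDefinition=f"workbench-batch-{msg.get('size', 'small')}",
            containerOverrides={"environment": env},
        )
    return {"statusCode": 200}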