workbench 0.8.161__py3-none-any.whl → 0.8.192__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- workbench/algorithms/dataframe/proximity.py +143 -102
- workbench/algorithms/graph/light/proximity_graph.py +2 -1
- workbench/api/compound.py +1 -1
- workbench/api/endpoint.py +12 -0
- workbench/api/feature_set.py +4 -4
- workbench/api/meta.py +5 -2
- workbench/api/model.py +16 -12
- workbench/api/monitor.py +1 -16
- workbench/core/artifacts/artifact.py +11 -3
- workbench/core/artifacts/data_capture_core.py +355 -0
- workbench/core/artifacts/endpoint_core.py +168 -78
- workbench/core/artifacts/feature_set_core.py +72 -13
- workbench/core/artifacts/model_core.py +50 -15
- workbench/core/artifacts/monitor_core.py +33 -248
- workbench/core/cloud_platform/aws/aws_account_clamp.py +50 -1
- workbench/core/cloud_platform/aws/aws_meta.py +12 -5
- workbench/core/cloud_platform/aws/aws_session.py +4 -4
- workbench/core/transforms/data_to_features/light/molecular_descriptors.py +4 -4
- workbench/core/transforms/features_to_model/features_to_model.py +9 -4
- workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +36 -6
- workbench/core/transforms/pandas_transforms/pandas_to_features.py +27 -0
- workbench/core/views/training_view.py +49 -53
- workbench/core/views/view.py +51 -1
- workbench/core/views/view_utils.py +4 -4
- workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +483 -0
- workbench/model_scripts/custom_models/chem_info/mol_standardize.py +450 -0
- workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +7 -9
- workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +3 -5
- workbench/model_scripts/custom_models/proximity/proximity.py +143 -102
- workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
- workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +10 -17
- workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
- workbench/model_scripts/custom_models/uq_models/meta_uq.template +156 -58
- workbench/model_scripts/custom_models/uq_models/ngboost.template +20 -14
- workbench/model_scripts/custom_models/uq_models/proximity.py +143 -102
- workbench/model_scripts/custom_models/uq_models/requirements.txt +1 -3
- workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +5 -13
- workbench/model_scripts/pytorch_model/pytorch.template +19 -20
- workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
- workbench/model_scripts/script_generation.py +7 -2
- workbench/model_scripts/uq_models/mapie.template +492 -0
- workbench/model_scripts/uq_models/requirements.txt +1 -0
- workbench/model_scripts/xgb_model/xgb_model.template +31 -40
- workbench/repl/workbench_shell.py +11 -6
- workbench/scripts/lambda_launcher.py +63 -0
- workbench/scripts/ml_pipeline_batch.py +137 -0
- workbench/scripts/ml_pipeline_sqs.py +186 -0
- workbench/scripts/monitor_cloud_watch.py +20 -100
- workbench/utils/aws_utils.py +4 -3
- workbench/utils/chem_utils/__init__.py +0 -0
- workbench/utils/chem_utils/fingerprints.py +134 -0
- workbench/utils/chem_utils/misc.py +194 -0
- workbench/utils/chem_utils/mol_descriptors.py +483 -0
- workbench/utils/chem_utils/mol_standardize.py +450 -0
- workbench/utils/chem_utils/mol_tagging.py +348 -0
- workbench/utils/chem_utils/projections.py +209 -0
- workbench/utils/chem_utils/salts.py +256 -0
- workbench/utils/chem_utils/sdf.py +292 -0
- workbench/utils/chem_utils/toxicity.py +250 -0
- workbench/utils/chem_utils/vis.py +253 -0
- workbench/utils/cloudwatch_handler.py +1 -1
- workbench/utils/cloudwatch_utils.py +137 -0
- workbench/utils/config_manager.py +3 -7
- workbench/utils/endpoint_utils.py +5 -7
- workbench/utils/license_manager.py +2 -6
- workbench/utils/model_utils.py +76 -30
- workbench/utils/monitor_utils.py +44 -62
- workbench/utils/pandas_utils.py +3 -3
- workbench/utils/shap_utils.py +10 -2
- workbench/utils/workbench_logging.py +0 -3
- workbench/utils/workbench_sqs.py +1 -1
- workbench/utils/xgboost_model_utils.py +283 -145
- workbench/web_interface/components/plugins/dashboard_status.py +3 -1
- workbench/web_interface/components/plugins/generated_compounds.py +1 -1
- workbench/web_interface/components/plugins/scatter_plot.py +3 -3
- {workbench-0.8.161.dist-info → workbench-0.8.192.dist-info}/METADATA +4 -4
- {workbench-0.8.161.dist-info → workbench-0.8.192.dist-info}/RECORD +81 -76
- {workbench-0.8.161.dist-info → workbench-0.8.192.dist-info}/entry_points.txt +3 -0
- workbench/model_scripts/custom_models/chem_info/local_utils.py +0 -769
- workbench/model_scripts/custom_models/chem_info/tautomerize.py +0 -83
- workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
- workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -393
- workbench/model_scripts/custom_models/uq_models/mapie_xgb.template +0 -203
- workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
- workbench/model_scripts/pytorch_model/generated_model_script.py +0 -565
- workbench/model_scripts/quant_regression/quant_regression.template +0 -279
- workbench/model_scripts/quant_regression/requirements.txt +0 -1
- workbench/model_scripts/scikit_learn/generated_model_script.py +0 -307
- workbench/model_scripts/xgb_model/generated_model_script.py +0 -477
- workbench/utils/chem_utils.py +0 -1556
- workbench/utils/execution_environment.py +0 -211
- workbench/utils/fast_inference.py +0 -167
- workbench/utils/resource_utils.py +0 -39
- {workbench-0.8.161.dist-info → workbench-0.8.192.dist-info}/WHEEL +0 -0
- {workbench-0.8.161.dist-info → workbench-0.8.192.dist-info}/licenses/LICENSE +0 -0
- {workbench-0.8.161.dist-info → workbench-0.8.192.dist-info}/top_level.txt +0 -0
|
@@ -1,4 +1,6 @@
|
|
|
1
|
+
import IPython
|
|
1
2
|
from IPython import start_ipython
|
|
3
|
+
from distutils.version import LooseVersion
|
|
2
4
|
from IPython.terminal.prompts import Prompts
|
|
3
5
|
from IPython.terminal.ipapp import load_default_config
|
|
4
6
|
from pygments.token import Token
|
|
@@ -39,7 +41,7 @@ from workbench.cached.cached_meta import CachedMeta
|
|
|
39
41
|
try:
|
|
40
42
|
import rdkit # noqa
|
|
41
43
|
import mordred # noqa
|
|
42
|
-
from workbench.utils import
|
|
44
|
+
from workbench.utils.chem_utils import vis
|
|
43
45
|
|
|
44
46
|
HAVE_CHEM_UTILS = True
|
|
45
47
|
except ImportError:
|
|
@@ -70,7 +72,7 @@ if not ConfigManager().config_okay():
|
|
|
70
72
|
|
|
71
73
|
# Set the log level to important
|
|
72
74
|
log = logging.getLogger("workbench")
|
|
73
|
-
log.setLevel(
|
|
75
|
+
log.setLevel(logging.INFO)
|
|
74
76
|
log.addFilter(
|
|
75
77
|
lambda record: not (
|
|
76
78
|
record.getMessage().startswith("Async: Metadata") or record.getMessage().startswith("Updated Metadata")
|
|
@@ -176,12 +178,12 @@ class WorkbenchShell:
|
|
|
176
178
|
|
|
177
179
|
# Add cheminformatics utils if available
|
|
178
180
|
if HAVE_CHEM_UTILS:
|
|
179
|
-
self.commands["show"] =
|
|
181
|
+
self.commands["show"] = vis.show
|
|
180
182
|
|
|
181
183
|
def start(self):
|
|
182
184
|
"""Start the Workbench IPython shell"""
|
|
183
185
|
cprint("magenta", "\nWelcome to Workbench!")
|
|
184
|
-
if self.aws_status
|
|
186
|
+
if not self.aws_status:
|
|
185
187
|
cprint("red", "AWS Account Connection Failed...Review/Fix the Workbench Config:")
|
|
186
188
|
cprint("red", f"Path: {self.cm.site_config_path}")
|
|
187
189
|
self.show_config()
|
|
@@ -202,7 +204,10 @@ class WorkbenchShell:
|
|
|
202
204
|
|
|
203
205
|
# Start IPython with the config and commands in the namespace
|
|
204
206
|
try:
|
|
205
|
-
|
|
207
|
+
if LooseVersion(IPython.__version__) >= LooseVersion("9.0.0"):
|
|
208
|
+
ipython_argv = ["--no-tip", "--theme", "linux"]
|
|
209
|
+
else:
|
|
210
|
+
ipython_argv = []
|
|
206
211
|
start_ipython(ipython_argv, user_ns=locs, config=config)
|
|
207
212
|
finally:
|
|
208
213
|
spinner = self.spinner_start("Goodbye to AWS:")
|
|
@@ -555,7 +560,7 @@ class WorkbenchShell:
|
|
|
555
560
|
from workbench.web_interface.components.plugin_unit_test import PluginUnitTest
|
|
556
561
|
|
|
557
562
|
# Get kwargs
|
|
558
|
-
theme = kwargs.get("theme", "
|
|
563
|
+
theme = kwargs.get("theme", "midnight_blue")
|
|
559
564
|
|
|
560
565
|
plugin_test = PluginUnitTest(plugin_class, theme=theme, input_data=data, **kwargs)
|
|
561
566
|
|
|
@@ -0,0 +1,63 @@
|
|
|
1
|
+
import sys
|
|
2
|
+
import os
|
|
3
|
+
import json
|
|
4
|
+
import importlib.util
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def main():
    """Run a Lambda handler module locally with an optional test event and environment.

    Usage: lambda_launcher <handler_module_name>

    The single command-line argument is the path of the handler file (".py" is appended
    when missing). Two optional files are read relative to the current directory:

        testing/env.json: JSON object whose items are copied into os.environ
        testing/event.json: JSON document passed to the handler as the event
                            (an empty dict is used when the file is absent)

    The module is loaded from the file, its lambda_handler(event, context) is called
    with an empty dict as the context, and the result is printed as JSON.

    Exits with status 1 on bad usage, a missing handler file, or a module that does
    not define lambda_handler.
    """
    if len(sys.argv) != 2:
        print("Usage: lambda_launcher <handler_module_name>")
        print("\nOptional: testing/event.json with test event")
        print("Optional: testing/env.json with environment variables")
        sys.exit(1)

    handler_file = sys.argv[1]

    # Add .py if not present
    if not handler_file.endswith(".py"):
        handler_file += ".py"

    # Check if file exists
    if not os.path.exists(handler_file):
        print(f"Error: File '{handler_file}' not found")
        sys.exit(1)

    # Load environment variables from env.json if it exists
    if os.path.exists("testing/env.json"):
        print("Loading environment variables from testing/env.json")
        with open("testing/env.json") as f:
            env_vars = json.load(f)
        for key, value in env_vars.items():
            # os.environ only accepts strings; JSON numbers/booleans would raise TypeError
            os.environ[key] = value if isinstance(value, str) else str(value)
            print(f" Set {key} = {value}")
        print()

    # Load event configuration
    if os.path.exists("testing/event.json"):
        print("Loading event from testing/event.json")
        with open("testing/event.json") as f:
            event = json.load(f)
    else:
        print("No testing/event.json found, using empty event")
        event = {}

    # Load the module dynamically
    spec = importlib.util.spec_from_file_location("lambda_module", handler_file)
    lambda_module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(lambda_module)

    # Fail with a readable message (like the missing-file case) instead of an AttributeError
    handler = getattr(lambda_module, "lambda_handler", None)
    if handler is None:
        print(f"Error: '{handler_file}' does not define lambda_handler")
        sys.exit(1)

    # Call the lambda_handler
    print(f"Invoking lambda_handler from {handler_file}...")
    print("-" * 50)
    print(f"Event: {json.dumps(event, indent=2)}")
    print("-" * 50)

    result = handler(event, {})

    print("-" * 50)
    print("Result:")
    # default=str is display-only: the handler already ran, so a non-JSON value
    # (datetime, Decimal, ...) in the result should not crash the launcher here
    print(json.dumps(result, indent=2, default=str))
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
if __name__ == "__main__":
|
|
63
|
+
main()
|
|
@@ -0,0 +1,137 @@
|
|
|
1
|
+
import argparse
|
|
2
|
+
import logging
|
|
3
|
+
import time
|
|
4
|
+
from datetime import datetime
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
|
|
7
|
+
# Workbench Imports
|
|
8
|
+
from workbench.core.cloud_platform.aws.aws_account_clamp import AWSAccountClamp
|
|
9
|
+
from workbench.utils.config_manager import ConfigManager
|
|
10
|
+
from workbench.utils.s3_utils import upload_content_to_s3
|
|
11
|
+
from workbench.utils.cloudwatch_utils import get_cloudwatch_logs_url
|
|
12
|
+
|
|
13
|
+
log = logging.getLogger("workbench")
|
|
14
|
+
cm = ConfigManager()
|
|
15
|
+
workbench_bucket = cm.get_config("WORKBENCH_BUCKET")
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def get_ecr_image_uri() -> str:
    """Build the ECR image URI of the ML pipeline container for the current region.

    Returns:
        str: Image URI whose region comes from AWSAccountClamp().region; the registry
             account, repository path, and tag are fixed in the format string below
    """
    clamp = AWSAccountClamp()
    region = clamp.region
    image_uri = f"507740646243.dkr.ecr.{region}.amazonaws.com/aws-ml-images/py312-ml-pipelines:0.1"
    return image_uri
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def get_batch_role_arn() -> str:
    """Build the ARN of the Workbench Batch execution role for the current account.

    Returns:
        str: IAM role ARN using the account id reported by AWSAccountClamp().account_id
    """
    clamp = AWSAccountClamp()
    account_id = clamp.account_id
    role_arn = f"arn:aws:iam::{account_id}:role/Workbench-BatchRole"
    return role_arn
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def _log_cloudwatch_link(job: dict, message_prefix: str = "View logs") -> None:
    """
    Log a terminal-clickable CloudWatch Logs link for a Batch job.

    Args:
        job: Batch job description dictionary (the stream name is read from
             job["container"]["logStreamName"] when present)
        message_prefix: Prefix for the log message (default: "View logs")
    """
    container_info = job.get("container", {})
    stream_name = container_info.get("logStreamName")
    url = get_cloudwatch_logs_url(log_group="/aws/batch/job", log_stream=stream_name)

    # No URL available: point the user at the console instead
    if not url:
        log.info("Check AWS Batch console for logs")
        return

    # OSC 8 hyperlink escape sequence: the URL is both the link target and the visible text
    hyperlink = f"\033]8;;{url}\033\\{url}\033]8;;\033\\"
    log.info(f"{message_prefix}: {hyperlink}")
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def run_batch_job(script_path: str, size: str = "small") -> int:
    """
    Submit and monitor an AWS Batch job for ML pipeline execution.

    Uploads script to S3, submits Batch job, then polls the job every 10 seconds
    until completion or 2 minutes of RUNNING.

    Args:
        script_path: Local path to the ML pipeline script
        size: Job size tier - "small" (default), "medium", or "large"
            - small: 2 vCPU, 4GB RAM for lightweight processing
            - medium: 4 vCPU, 8GB RAM for standard ML workloads
            - large: 8 vCPU, 16GB RAM for heavy training/inference

    Returns:
        Exit code (0 for success/disconnected, non-zero for failure)

    Raises:
        ValueError: If size is not one of the supported tiers
    """
    if size not in ["small", "medium", "large"]:
        raise ValueError(f"Invalid size '{size}'. Must be 'small', 'medium', or 'large'")

    batch = AWSAccountClamp().boto3_session.client("batch")
    script_file = Path(script_path)
    script_name = script_file.stem

    # Upload script to S3
    s3_path = f"s3://{workbench_bucket}/batch-jobs/{script_file.name}"
    log.info(f"Uploading script to {s3_path}")
    upload_content_to_s3(script_file.read_text(), s3_path)

    # Submit job
    job_name = f"workbench_{script_name}_{datetime.now():%Y%m%d_%H%M%S}"
    response = batch.submit_job(
        jobName=job_name,
        jobQueue="workbench-job-queue",
        jobDefinition=f"workbench-batch-{size}",
        containerOverrides={
            "environment": [
                {"name": "ML_PIPELINE_S3_PATH", "value": s3_path},
                {"name": "WORKBENCH_BUCKET", "value": workbench_bucket},
            ]
        },
    )
    job_id = response["jobId"]
    log.info(f"Submitted job: {job_name} ({job_id}) using {size} tier")

    # Monitor job
    last_status, running_start = None, None
    while True:
        job = batch.describe_jobs(jobs=[job_id])["jobs"][0]
        status = job["status"]

        if status != last_status:
            log.info(f"Job status: {status}")
            last_status = status
            if status == "RUNNING":
                # Start the disconnect timer on the transition into RUNNING
                running_start = time.time()

        # Disconnect after 2 minutes of running
        if status == "RUNNING" and running_start and (time.time() - running_start >= 120):
            log.info("✅ ML Pipeline is running successfully!")
            _log_cloudwatch_link(job, "📊 Monitor logs")
            return 0

        # Handle completion
        if status in ["SUCCEEDED", "FAILED"]:
            exit_code = _batch_exit_code(job, status)
            if status == "SUCCEEDED":
                log.info("Job completed successfully")
            else:
                log.error(f"Job failed: {job.get('statusReason', 'Unknown')}")
            _log_cloudwatch_link(job)
            return exit_code

        time.sleep(10)


def _batch_exit_code(job: dict, status: str) -> int:
    """Derive a process exit code from a finished Batch job description.

    Batch reports the container exit code under attempts[-1]["container"]["exitCode"]
    (and job["container"]["exitCode"]), not directly on the attempt, and "attempts"
    may be an empty list when the job failed before any attempt started.

    Args:
        job: Batch job description dictionary
        status: Terminal job status ("SUCCEEDED" or "FAILED")

    Returns:
        0 for a succeeded job without a reported code, the container exit code when
        one is reported, and always a non-zero value for a failed job
    """
    attempts = job.get("attempts") or [{}]
    container_code = attempts[-1].get("container", {}).get("exitCode")
    if container_code is None:
        container_code = job.get("container", {}).get("exitCode")

    if status == "SUCCEEDED":
        return 0 if container_code is None else container_code

    # FAILED: never report success, even if no (or a zero) container code is available
    return container_code or 1
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
def main():
    """CLI entry point for running ML pipelines on AWS Batch.

    Parses the script path and an optional --size tier, runs the job via
    run_batch_job, and terminates the process with the job's exit code.
    Any exception raised while submitting/monitoring is logged and mapped to exit code 1.
    """
    parser = argparse.ArgumentParser(description="Run ML pipeline script on AWS Batch")
    parser.add_argument("script_file", help="Local path to ML pipeline script")
    parser.add_argument(
        "--size", default="small", choices=["small", "medium", "large"], help="Job size tier (default: small)"
    )
    args = parser.parse_args()
    try:
        exit_code = run_batch_job(args.script_file, size=args.size)
    except Exception as e:
        log.error(f"Error: {e}")
        exit_code = 1

    # Raise SystemExit directly: the exit() builtin is an interactive helper installed
    # by the site module and is not guaranteed to exist in every runtime
    raise SystemExit(exit_code)
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
if __name__ == "__main__":
|
|
137
|
+
main()
|
|
@@ -0,0 +1,186 @@
|
|
|
1
|
+
import argparse
|
|
2
|
+
import logging
|
|
3
|
+
import json
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
|
|
6
|
+
# Workbench Imports
|
|
7
|
+
from workbench.core.cloud_platform.aws.aws_account_clamp import AWSAccountClamp
|
|
8
|
+
from workbench.utils.config_manager import ConfigManager
|
|
9
|
+
from workbench.utils.s3_utils import upload_content_to_s3
|
|
10
|
+
|
|
11
|
+
log = logging.getLogger("workbench")
|
|
12
|
+
cm = ConfigManager()
|
|
13
|
+
workbench_bucket = cm.get_config("WORKBENCH_BUCKET")
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def submit_to_sqs(
    script_path: str,
    size: str = "small",
    realtime: bool = False,
    dt: bool = False,
    promote: bool = False,
) -> None:
    """
    Upload an ML pipeline script to S3 and enqueue a job message on the Workbench FIFO queue.

    Progress is reported with print() at every step (banner, queue listing, upload,
    queue lookup, send, summary).

    Args:
        script_path: Local path to the ML pipeline script
        size: Job size tier - "small" (default), "medium", or "large"
        realtime: If True, the message environment carries SERVERLESS="False" (default: False)
        dt: If True, the message environment carries DT="True" (default: False)
        promote: If True, the message environment carries PROMOTE="True" (default: False)

    Raises:
        ValueError: If size is not one of the supported tiers
        FileNotFoundError: If script_path does not exist
        Exception: Errors from the S3 upload, queue lookup, or message send are
                   printed and then re-raised unchanged
    """
    rule = "=" * 60
    print(f"\n{rule}")
    print("🚀 SUBMITTING ML PIPELINE JOB")
    print(f"{rule}")
    if size not in ["small", "medium", "large"]:
        raise ValueError(f"Invalid size '{size}'. Must be 'small', 'medium', or 'large'")

    # Validate script exists
    script_file = Path(script_path)
    if not script_file.exists():
        raise FileNotFoundError(f"Script not found: {script_path}")

    # Both labels are derived from the realtime flag and reused in the final summary
    mode_label = "Real-time" if realtime else "Serverless"
    serverless_flag = "False" if realtime else "True"

    print(f"📄 Script: {script_file.name}")
    print(f"📏 Size tier: {size}")
    print(f"⚡ Mode: {mode_label} (serverless={serverless_flag})")
    print(f"🔄 DynamicTraining: {dt}")
    print(f"🆕 Promote: {promote}")
    print(f"🪣 Bucket: {workbench_bucket}")
    sqs = AWSAccountClamp().boto3_session.client("sqs")
    script_name = script_file.name

    # List Workbench queues (informational only: a failure here is reported, not raised)
    print("\n📋 Listing Workbench SQS queues...")
    try:
        listing = sqs.list_queues(QueueNamePrefix="workbench-")
        found_urls = listing.get("QueueUrls", [])
        if not found_urls:
            print("⚠️ No workbench queues found")
        else:
            print(f"✅ Found {len(found_urls)} workbench queue(s):")
            for found_url in found_urls:
                short_name = found_url.split("/")[-1]
                print(f" • {short_name}")
    except Exception as e:
        print(f"❌ Error listing queues: {e}")

    # Upload script to S3
    s3_path = f"s3://{workbench_bucket}/batch-jobs/{script_name}"
    print("\n📤 Uploading script to S3...")
    print(f" Source: {script_path}")
    print(f" Destination: {s3_path}")

    try:
        upload_content_to_s3(script_file.read_text(), s3_path)
        print("✅ Script uploaded successfully")
    except Exception as e:
        print(f"❌ Upload failed: {e}")
        raise

    # Get queue URL and info
    queue_name = "workbench-ml-pipeline-queue.fifo"
    print("\n🎯 Getting queue information...")
    print(f" Queue name: {queue_name}")

    try:
        queue_url = sqs.get_queue_url(QueueName=queue_name)["QueueUrl"]
        print(f" Queue URL: {queue_url}")

        # Get queue attributes for additional info
        counters = sqs.get_queue_attributes(
            QueueUrl=queue_url, AttributeNames=["ApproximateNumberOfMessages", "ApproximateNumberOfMessagesNotVisible"]
        )["Attributes"]
        waiting = counters.get("ApproximateNumberOfMessages", "0")
        in_flight = counters.get("ApproximateNumberOfMessagesNotVisible", "0")
        print(f" Messages in queue: {waiting}")
        print(f" Messages in flight: {in_flight}")

    except Exception as e:
        print(f"❌ Error accessing queue: {e}")
        raise

    # Prepare message (key order matches the serialized body: script_path, size, environment)
    message = {
        "script_path": s3_path,
        "size": size,
        "environment": {
            "SERVERLESS": serverless_flag,
            "DT": str(dt),
            "PROMOTE": str(promote),
        },
    }

    # Send the message to SQS
    try:
        print("\n📨 Sending message to SQS...")
        send_result = sqs.send_message(
            QueueUrl=queue_url,
            MessageBody=json.dumps(message, indent=2),
            MessageGroupId="ml-pipeline-jobs",  # Required for FIFO
        )
        message_id = send_result["MessageId"]
        print("✅ Message sent successfully!")
        print(f" Message ID: {message_id}")
    except Exception as e:
        print(f"❌ Failed to send message: {e}")
        raise

    # Success summary
    print(f"\n{rule}")
    print("✅ JOB SUBMISSION COMPLETE")
    print(f"{rule}")
    print(f"📄 Script: {script_name}")
    print(f"📏 Size: {size}")
    print(f"⚡ Mode: {mode_label} (SERVERLESS={serverless_flag})")
    print(f"🔄 DynamicTraining: {dt}")
    print(f"🆕 Promote: {promote}")
    print(f"🆔 Message ID: {message_id}")
    print("\n🔍 MONITORING LOCATIONS:")
    print(f" • SQS Queue: AWS Console → SQS → {queue_name}")
    print(" • Lambda Logs: AWS Console → Lambda → Functions")
    print(" • Batch Jobs: AWS Console → Batch → Jobs")
    print(" • CloudWatch: AWS Console → CloudWatch → Log groups")
    print("\n⏳ Your job should start processing soon...")
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
def main():
    """CLI entry point for submitting ML pipelines via SQS.

    Parses the script path plus the --size, --realtime, --dt and --promote options and
    forwards them to submit_to_sqs. Any exception is printed and logged, and the
    process then terminates with exit code 1.
    """
    parser = argparse.ArgumentParser(description="Submit ML pipeline to SQS queue for Batch processing")
    parser.add_argument("script_file", help="Local path to ML pipeline script")
    parser.add_argument(
        "--size", default="small", choices=["small", "medium", "large"], help="Job size tier (default: small)"
    )
    parser.add_argument(
        "--realtime",
        action="store_true",
        help="Create realtime endpoints (default is serverless)",
    )
    parser.add_argument(
        "--dt",
        action="store_true",
        help="Set DT=True (models and endpoints will have '-dt' suffix)",
    )
    parser.add_argument(
        "--promote",
        action="store_true",
        help="Set Promote=True (models and endpoints will use promoted naming)",
    )
    args = parser.parse_args()
    try:
        submit_to_sqs(
            args.script_file,
            args.size,
            realtime=args.realtime,
            dt=args.dt,
            promote=args.promote,
        )
    except Exception as e:
        print(f"\n❌ ERROR: {e}")
        log.error(f"Error: {e}")
        # Raise SystemExit directly: the exit() builtin is an interactive helper installed
        # by the site module and is not guaranteed to exist in every runtime
        raise SystemExit(1)
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
if __name__ == "__main__":
|
|
186
|
+
main()
|
|
@@ -4,8 +4,10 @@ import sys
|
|
|
4
4
|
import time
|
|
5
5
|
import argparse
|
|
6
6
|
from datetime import datetime, timedelta, timezone
|
|
7
|
-
|
|
7
|
+
|
|
8
|
+
# Workbench Imports
|
|
8
9
|
from workbench.utils.repl_utils import cprint, Spinner
|
|
10
|
+
from workbench.utils.cloudwatch_utils import get_cloudwatch_client, get_active_log_streams, stream_log_events
|
|
9
11
|
|
|
10
12
|
# Define the log levels to include all log levels above the specified level
|
|
11
13
|
log_level_map = {
|
|
@@ -33,64 +35,6 @@ def date_display(dt):
|
|
|
33
35
|
return dt.strftime("%Y-%m-%d %I:%M%p") + "(UTC)"
|
|
34
36
|
|
|
35
37
|
|
|
36
|
-
def get_cloudwatch_client():
|
|
37
|
-
"""Get the CloudWatch Logs client using the Workbench assumed role session."""
|
|
38
|
-
session = AWSAccountClamp().boto3_session
|
|
39
|
-
return session.client("logs")
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
def get_active_log_streams(client, log_group_name, start_time_ms, stream_filter=None):
|
|
43
|
-
"""Retrieve log streams that have events after the specified start time."""
|
|
44
|
-
|
|
45
|
-
# Get all the streams in the log group
|
|
46
|
-
active_streams = []
|
|
47
|
-
stream_params = {
|
|
48
|
-
"logGroupName": log_group_name,
|
|
49
|
-
"orderBy": "LastEventTime",
|
|
50
|
-
"descending": True,
|
|
51
|
-
}
|
|
52
|
-
|
|
53
|
-
# Loop to retrieve all log streams (maximum 50 per call)
|
|
54
|
-
while True:
|
|
55
|
-
response = client.describe_log_streams(**stream_params)
|
|
56
|
-
log_streams = response.get("logStreams", [])
|
|
57
|
-
|
|
58
|
-
for log_stream in log_streams:
|
|
59
|
-
log_stream_name = log_stream["logStreamName"]
|
|
60
|
-
last_event_timestamp = log_stream.get("lastEventTimestamp")
|
|
61
|
-
|
|
62
|
-
# Include streams with events since the specified start time
|
|
63
|
-
# Note: There's some issue where the last event timestamp is 'off'
|
|
64
|
-
# so we're going to add 60 minutes from the last event timestamp
|
|
65
|
-
last_event_timestamp += 60 * 60 * 1000
|
|
66
|
-
if last_event_timestamp >= start_time_ms:
|
|
67
|
-
active_streams.append(log_stream_name)
|
|
68
|
-
else:
|
|
69
|
-
break # Stop if we reach streams older than the start time
|
|
70
|
-
|
|
71
|
-
# Check if there are more streams to retrieve
|
|
72
|
-
if "nextToken" in response:
|
|
73
|
-
stream_params["nextToken"] = response["nextToken"]
|
|
74
|
-
else:
|
|
75
|
-
break
|
|
76
|
-
|
|
77
|
-
# Sort and report the active log streams
|
|
78
|
-
active_streams.sort()
|
|
79
|
-
if active_streams:
|
|
80
|
-
print("Active log streams:", len(active_streams))
|
|
81
|
-
|
|
82
|
-
# Filter the active streams by a substring if provided
|
|
83
|
-
if stream_filter and active_streams:
|
|
84
|
-
print(f"Filtering active log streams by '{stream_filter}'...")
|
|
85
|
-
active_streams = [stream for stream in active_streams if stream_filter in stream]
|
|
86
|
-
|
|
87
|
-
for stream in active_streams:
|
|
88
|
-
print(f"\t - {stream}")
|
|
89
|
-
|
|
90
|
-
# Return the active log streams
|
|
91
|
-
return active_streams
|
|
92
|
-
|
|
93
|
-
|
|
94
38
|
def get_latest_log_events(client, log_group_name, start_time, end_time=None, stream_filter=None):
|
|
95
39
|
"""Retrieve the latest log events from the active/filtered log streams in a CloudWatch Logs group."""
|
|
96
40
|
|
|
@@ -99,11 +43,15 @@ def get_latest_log_events(client, log_group_name, start_time, end_time=None, str
|
|
|
99
43
|
get_latest_log_events.first_run = True
|
|
100
44
|
|
|
101
45
|
log_events = []
|
|
102
|
-
start_time_ms = int(start_time.timestamp() * 1000)
|
|
46
|
+
start_time_ms = int(start_time.timestamp() * 1000)
|
|
47
|
+
|
|
48
|
+
# Use the util function to get active streams
|
|
49
|
+
active_streams = get_active_log_streams(log_group_name, start_time_ms, stream_filter, client)
|
|
103
50
|
|
|
104
|
-
# Get the active log streams with events since start_time
|
|
105
|
-
active_streams = get_active_log_streams(client, log_group_name, start_time_ms, stream_filter)
|
|
106
51
|
if active_streams:
|
|
52
|
+
print(f"Active log streams: {len(active_streams)}")
|
|
53
|
+
for stream in active_streams:
|
|
54
|
+
print(f"\t - {stream}")
|
|
107
55
|
print(f"Processing log events from {date_display(start_time)} on {len(active_streams)} active log streams...")
|
|
108
56
|
get_latest_log_events.first_run = False
|
|
109
57
|
else:
|
|
@@ -114,50 +62,22 @@ def get_latest_log_events(client, log_group_name, start_time, end_time=None, str
|
|
|
114
62
|
print("Monitoring for new events...")
|
|
115
63
|
return log_events
|
|
116
64
|
|
|
117
|
-
#
|
|
65
|
+
# Use the util function to stream events from each log stream
|
|
118
66
|
for log_stream_name in active_streams:
|
|
119
|
-
params = {
|
|
120
|
-
"logGroupName": log_group_name,
|
|
121
|
-
"logStreamName": log_stream_name,
|
|
122
|
-
"startTime": start_time_ms, # Use start_time in milliseconds
|
|
123
|
-
"startFromHead": True, # Start from the nearest event to start_time
|
|
124
|
-
}
|
|
125
|
-
next_event_token = None
|
|
126
|
-
if end_time is not None:
|
|
127
|
-
params["endTime"] = int(end_time.timestamp() * 1000)
|
|
128
|
-
|
|
129
|
-
# Process the log events from this log stream
|
|
130
67
|
spinner = Spinner("lightpurple", f"Pulling events from {log_stream_name}:")
|
|
131
68
|
spinner.start()
|
|
132
69
|
log_stream_events = 0
|
|
133
70
|
|
|
134
|
-
#
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
71
|
+
# Stream events using the util function
|
|
72
|
+
for event in stream_log_events(
|
|
73
|
+
log_group_name, log_stream_name, start_time, end_time, follow=False, client=client
|
|
74
|
+
):
|
|
75
|
+
log_stream_events += 1
|
|
76
|
+
log_events.append(event)
|
|
139
77
|
|
|
140
|
-
|
|
141
|
-
|
|
78
|
+
spinner.stop()
|
|
79
|
+
print(f"Processed {log_stream_events} events from {log_stream_name} (Total: {len(log_events)})")
|
|
142
80
|
|
|
143
|
-
events = events_response.get("events", [])
|
|
144
|
-
for event in events:
|
|
145
|
-
event["logStreamName"] = log_stream_name
|
|
146
|
-
|
|
147
|
-
# Add the log stream events to our list of all log events
|
|
148
|
-
log_stream_events += len(events)
|
|
149
|
-
log_events.extend(events)
|
|
150
|
-
|
|
151
|
-
# Handle pagination for log events
|
|
152
|
-
next_event_token = events_response.get("nextForwardToken")
|
|
153
|
-
|
|
154
|
-
# Break the loop if there are no more events to fetch
|
|
155
|
-
if not next_event_token or next_event_token == params.get("nextToken"):
|
|
156
|
-
spinner.stop()
|
|
157
|
-
print(f"Processed {log_stream_events} events from {log_stream_name} (Total: {len(log_events)})")
|
|
158
|
-
break
|
|
159
|
-
|
|
160
|
-
# Return the log events
|
|
161
81
|
return log_events
|
|
162
82
|
|
|
163
83
|
|
|
@@ -206,6 +126,7 @@ def monitor_log_group(
|
|
|
206
126
|
print(f"Monitoring log group: {log_group_name} from {date_display(start_time)}")
|
|
207
127
|
print(f"Log levels: {log_levels}")
|
|
208
128
|
print(f"Search terms: {search_terms}")
|
|
129
|
+
|
|
209
130
|
while True:
|
|
210
131
|
# Get the latest log events with stream filtering if provided
|
|
211
132
|
all_log_events = get_latest_log_events(client, log_group_name, start_time, end_time, stream_filter)
|
|
@@ -218,7 +139,6 @@ def monitor_log_group(
|
|
|
218
139
|
|
|
219
140
|
# Check the search terms
|
|
220
141
|
if not search_terms or any(term in event["message"].lower() for term in search_terms):
|
|
221
|
-
|
|
222
142
|
# Calculate the start and end index for this match
|
|
223
143
|
start_index = max(i - before, 0)
|
|
224
144
|
end_index = min(i + after, len(all_log_events) - 1)
|
workbench/utils/aws_utils.py
CHANGED
|
@@ -55,7 +55,8 @@ def aws_throttle(func=None, retry_intervals=None):
|
|
|
55
55
|
if func is None:
|
|
56
56
|
return lambda f: aws_throttle(f, retry_intervals=retry_intervals)
|
|
57
57
|
|
|
58
|
-
|
|
58
|
+
# This is currently commented out (we might want to use it later)
|
|
59
|
+
# service_hold_time = 2 # Seconds to wait before calling AWS function
|
|
59
60
|
default_intervals = [2**i for i in range(1, 9)] # Default exponential backoff: 2, 4, 8... 256 seconds
|
|
60
61
|
intervals = retry_intervals or default_intervals
|
|
61
62
|
|
|
@@ -64,8 +65,8 @@ def aws_throttle(func=None, retry_intervals=None):
|
|
|
64
65
|
for attempt, delay in enumerate(intervals, start=1):
|
|
65
66
|
try:
|
|
66
67
|
# Add sleep before calling AWS func if running as a service
|
|
67
|
-
if cm.running_as_service:
|
|
68
|
-
|
|
68
|
+
# if cm.running_as_service:
|
|
69
|
+
# time.sleep(service_hold_time)
|
|
69
70
|
return func(*args, **kwargs)
|
|
70
71
|
except ClientError as e:
|
|
71
72
|
if e.response["Error"]["Code"] == "ThrottlingException":
|
|
File without changes
|