workbench-0.8.162-py3-none-any.whl → workbench-0.8.202-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (113)
  1. workbench/algorithms/dataframe/__init__.py +1 -2
  2. workbench/algorithms/dataframe/fingerprint_proximity.py +2 -2
  3. workbench/algorithms/dataframe/proximity.py +261 -235
  4. workbench/algorithms/graph/light/proximity_graph.py +10 -8
  5. workbench/api/__init__.py +2 -1
  6. workbench/api/compound.py +1 -1
  7. workbench/api/endpoint.py +11 -0
  8. workbench/api/feature_set.py +11 -8
  9. workbench/api/meta.py +5 -2
  10. workbench/api/model.py +16 -15
  11. workbench/api/monitor.py +1 -16
  12. workbench/core/artifacts/__init__.py +11 -2
  13. workbench/core/artifacts/artifact.py +11 -3
  14. workbench/core/artifacts/data_capture_core.py +355 -0
  15. workbench/core/artifacts/endpoint_core.py +256 -118
  16. workbench/core/artifacts/feature_set_core.py +265 -16
  17. workbench/core/artifacts/model_core.py +107 -60
  18. workbench/core/artifacts/monitor_core.py +33 -248
  19. workbench/core/cloud_platform/aws/aws_account_clamp.py +50 -1
  20. workbench/core/cloud_platform/aws/aws_meta.py +12 -5
  21. workbench/core/cloud_platform/aws/aws_parameter_store.py +18 -2
  22. workbench/core/cloud_platform/aws/aws_session.py +4 -4
  23. workbench/core/transforms/data_to_features/light/molecular_descriptors.py +4 -4
  24. workbench/core/transforms/features_to_model/features_to_model.py +42 -32
  25. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +36 -6
  26. workbench/core/transforms/pandas_transforms/pandas_to_features.py +27 -0
  27. workbench/core/views/training_view.py +113 -42
  28. workbench/core/views/view.py +53 -3
  29. workbench/core/views/view_utils.py +4 -4
  30. workbench/model_scripts/chemprop/chemprop.template +852 -0
  31. workbench/model_scripts/chemprop/generated_model_script.py +852 -0
  32. workbench/model_scripts/chemprop/requirements.txt +11 -0
  33. workbench/model_scripts/custom_models/chem_info/fingerprints.py +134 -0
  34. workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +483 -0
  35. workbench/model_scripts/custom_models/chem_info/mol_standardize.py +450 -0
  36. workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +7 -9
  37. workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -1
  38. workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +3 -5
  39. workbench/model_scripts/custom_models/proximity/proximity.py +261 -235
  40. workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
  41. workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +20 -21
  42. workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
  43. workbench/model_scripts/custom_models/uq_models/meta_uq.template +166 -62
  44. workbench/model_scripts/custom_models/uq_models/ngboost.template +30 -18
  45. workbench/model_scripts/custom_models/uq_models/proximity.py +261 -235
  46. workbench/model_scripts/custom_models/uq_models/requirements.txt +1 -3
  47. workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +15 -17
  48. workbench/model_scripts/pytorch_model/generated_model_script.py +373 -190
  49. workbench/model_scripts/pytorch_model/pytorch.template +370 -187
  50. workbench/model_scripts/scikit_learn/generated_model_script.py +7 -12
  51. workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
  52. workbench/model_scripts/script_generation.py +17 -9
  53. workbench/model_scripts/uq_models/generated_model_script.py +605 -0
  54. workbench/model_scripts/uq_models/mapie.template +605 -0
  55. workbench/model_scripts/uq_models/requirements.txt +1 -0
  56. workbench/model_scripts/xgb_model/generated_model_script.py +37 -46
  57. workbench/model_scripts/xgb_model/xgb_model.template +44 -46
  58. workbench/repl/workbench_shell.py +28 -14
  59. workbench/scripts/endpoint_test.py +162 -0
  60. workbench/scripts/lambda_test.py +73 -0
  61. workbench/scripts/ml_pipeline_batch.py +137 -0
  62. workbench/scripts/ml_pipeline_sqs.py +186 -0
  63. workbench/scripts/monitor_cloud_watch.py +20 -100
  64. workbench/utils/aws_utils.py +4 -3
  65. workbench/utils/chem_utils/__init__.py +0 -0
  66. workbench/utils/chem_utils/fingerprints.py +134 -0
  67. workbench/utils/chem_utils/misc.py +194 -0
  68. workbench/utils/chem_utils/mol_descriptors.py +483 -0
  69. workbench/utils/chem_utils/mol_standardize.py +450 -0
  70. workbench/utils/chem_utils/mol_tagging.py +348 -0
  71. workbench/utils/chem_utils/projections.py +209 -0
  72. workbench/utils/chem_utils/salts.py +256 -0
  73. workbench/utils/chem_utils/sdf.py +292 -0
  74. workbench/utils/chem_utils/toxicity.py +250 -0
  75. workbench/utils/chem_utils/vis.py +253 -0
  76. workbench/utils/chemprop_utils.py +760 -0
  77. workbench/utils/cloudwatch_handler.py +1 -1
  78. workbench/utils/cloudwatch_utils.py +137 -0
  79. workbench/utils/config_manager.py +3 -7
  80. workbench/utils/endpoint_utils.py +5 -7
  81. workbench/utils/license_manager.py +2 -6
  82. workbench/utils/model_utils.py +95 -34
  83. workbench/utils/monitor_utils.py +44 -62
  84. workbench/utils/pandas_utils.py +3 -3
  85. workbench/utils/pytorch_utils.py +526 -0
  86. workbench/utils/shap_utils.py +10 -2
  87. workbench/utils/workbench_logging.py +0 -3
  88. workbench/utils/workbench_sqs.py +1 -1
  89. workbench/utils/xgboost_model_utils.py +371 -156
  90. workbench/web_interface/components/model_plot.py +7 -1
  91. workbench/web_interface/components/plugin_unit_test.py +5 -2
  92. workbench/web_interface/components/plugins/dashboard_status.py +3 -1
  93. workbench/web_interface/components/plugins/generated_compounds.py +1 -1
  94. workbench/web_interface/components/plugins/model_details.py +9 -7
  95. workbench/web_interface/components/plugins/scatter_plot.py +3 -3
  96. {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/METADATA +27 -6
  97. {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/RECORD +101 -85
  98. {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/entry_points.txt +4 -0
  99. {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/licenses/LICENSE +1 -1
  100. workbench/model_scripts/custom_models/chem_info/local_utils.py +0 -769
  101. workbench/model_scripts/custom_models/chem_info/tautomerize.py +0 -83
  102. workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
  103. workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -393
  104. workbench/model_scripts/custom_models/uq_models/mapie_xgb.template +0 -203
  105. workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
  106. workbench/model_scripts/quant_regression/quant_regression.template +0 -279
  107. workbench/model_scripts/quant_regression/requirements.txt +0 -1
  108. workbench/utils/chem_utils.py +0 -1556
  109. workbench/utils/execution_environment.py +0 -211
  110. workbench/utils/fast_inference.py +0 -167
  111. workbench/utils/resource_utils.py +0 -39
  112. {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/WHEEL +0 -0
  113. {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/top_level.txt +0 -0
workbench/scripts/endpoint_test.py (new file)
@@ -0,0 +1,162 @@
+"""
+Local test harness for SageMaker model scripts.
+
+Usage:
+    python model_script_harness.py <local_script.py> <model_name>
+
+Example:
+    python model_script_harness.py pytorch.py aqsol-pytorch-reg
+
+This allows you to test LOCAL changes to a model script against deployed model artifacts.
+Evaluation data is automatically pulled from the FeatureSet (training = FALSE rows).
+
+Optional: testing/env.json with additional environment variables
+"""
+
+import os
+import sys
+import json
+import importlib.util
+import tempfile
+import shutil
+import pandas as pd
+import torch
+
+# Workbench Imports
+from workbench.api import Model, FeatureSet
+from workbench.utils.pytorch_utils import download_and_extract_model
+
+# Force CPU mode BEFORE any PyTorch imports to avoid MPS/CUDA issues on Mac
+os.environ["CUDA_VISIBLE_DEVICES"] = ""
+os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
+torch.set_default_device("cpu")
+# Disable MPS entirely
+if hasattr(torch.backends, "mps"):
+    torch.backends.mps.is_available = lambda: False
+
+
+def get_eval_data(workbench_model: Model) -> pd.DataFrame:
+    """Get evaluation data from the FeatureSet associated with this model."""
+    # Get the FeatureSet
+    fs_name = workbench_model.get_input()
+    fs = FeatureSet(fs_name)
+    if not fs.exists():
+        raise ValueError(f"No FeatureSet found: {fs_name}")
+
+    # Get evaluation data (training = FALSE)
+    table = workbench_model.training_view().table
+    print(f"Querying evaluation data from {table}...")
+    eval_df = fs.query(f'SELECT * FROM "{table}" WHERE training = FALSE')
+    print(f"Retrieved {len(eval_df)} evaluation rows")
+
+    return eval_df
+
+
+def load_model_script(script_path: str):
+    """Dynamically load the model script module."""
+    if not os.path.exists(script_path):
+        raise FileNotFoundError(f"Script not found: {script_path}")
+
+    spec = importlib.util.spec_from_file_location("model_script", script_path)
+    module = importlib.util.module_from_spec(spec)
+
+    # Add to sys.modules so imports within the script work
+    sys.modules["model_script"] = module
+
+    spec.loader.exec_module(module)
+    return module
+
+
+def main():
+    if len(sys.argv) < 3:
+        print("Usage: python model_script_harness.py <local_script.py> <model_name>")
+        print("\nArguments:")
+        print("  local_script.py - Path to your LOCAL model script to test")
+        print("  model_name      - Workbench model name (e.g., aqsol-pytorch-reg)")
+        print("\nOptional: testing/env.json with additional environment variables")
+        sys.exit(1)
+
+    script_path = sys.argv[1]
+    model_name = sys.argv[2]
+
+    # Validate local script exists
+    if not os.path.exists(script_path):
+        print(f"Error: Local script not found: {script_path}")
+        sys.exit(1)
+
+    # Initialize Workbench model
+    print(f"Loading Workbench model: {model_name}")
+    workbench_model = Model(model_name)
+    print(f"Model Framework: {workbench_model.model_framework}")
+    print()
+
+    # Create a temporary model directory
+    model_dir = tempfile.mkdtemp(prefix="model_harness_")
+    print(f"Using model directory: {model_dir}")
+
+    try:
+        # Load environment variables from env.json if it exists
+        if os.path.exists("testing/env.json"):
+            print("Loading environment variables from testing/env.json")
+            with open("testing/env.json") as f:
+                env_vars = json.load(f)
+            for key, value in env_vars.items():
+                os.environ[key] = value
+                print(f"  Set {key} = {value}")
+            print()
+
+        # Set up SageMaker environment variables
+        os.environ["SM_MODEL_DIR"] = model_dir
+        print(f"Set SM_MODEL_DIR = {model_dir}")
+
+        # Download and extract model artifacts
+        s3_uri = workbench_model.model_data_url()
+        download_and_extract_model(s3_uri, model_dir)
+        print()
+
+        # Load the LOCAL model script
+        print(f"Loading LOCAL model script: {script_path}")
+        module = load_model_script(script_path)
+        print()
+
+        # Check for required functions
+        if not hasattr(module, "model_fn"):
+            raise AttributeError("Model script must have a model_fn function")
+        if not hasattr(module, "predict_fn"):
+            raise AttributeError("Model script must have a predict_fn function")
+
+        # Load the model
+        print("Calling model_fn...")
+        print("-" * 50)
+        model = module.model_fn(model_dir)
+        print("-" * 50)
+        print(f"Model loaded: {type(model)}")
+        print()
+
+        # Get evaluation data from FeatureSet
+        print("Pulling evaluation data from FeatureSet...")
+        df = get_eval_data(workbench_model)
+        print(f"Input shape: {df.shape}")
+        print(f"Columns: {df.columns.tolist()}")
+        print()
+
+        print("Calling predict_fn...")
+        print("-" * 50)
+        result = module.predict_fn(df, model)
+        print("-" * 50)
+        print()
+
+        print("Prediction result:")
+        print(f"Output shape: {result.shape}")
+        print(f"Output columns: {result.columns.tolist()}")
+        print()
+        print(result.head(10).to_string())
+
+    finally:
+        # Cleanup
+        print(f"\nCleaning up model directory: {model_dir}")
+        shutil.rmtree(model_dir, ignore_errors=True)
+
+
+if __name__ == "__main__":
+    main()
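
Editor's note: a model script only needs the two entry points the harness checks for: model_fn(model_dir), called once with the extracted artifact directory, and predict_fn(df, model), which maps the evaluation DataFrame to an output DataFrame. A minimal sketch of a conforming script, assuming a joblib-serialized scikit-learn model named model.joblib and a single non-feature "id" column (both assumptions for illustration, not taken from this release):

import os

import joblib
import pandas as pd


def model_fn(model_dir: str):
    # Called once by the harness with SM_MODEL_DIR (assumed model.joblib artifact)
    return joblib.load(os.path.join(model_dir, "model.joblib"))


def predict_fn(df: pd.DataFrame, model) -> pd.DataFrame:
    # Assumed feature layout: every column except "id" feeds the model
    features = [c for c in df.columns if c != "id"]
    df["prediction"] = model.predict(df[features])
    return df

Running `python model_script_harness.py my_script.py aqsol-pytorch-reg` would then download the deployed artifacts, call model_fn, and run predict_fn over the FeatureSet's hold-out (training = FALSE) rows.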
workbench/scripts/lambda_test.py (new file)
@@ -0,0 +1,73 @@
+"""
+Local test harness for AWS Lambda scripts.
+
+Usage:
+    lambda_test <lambda_script.py>
+
+Optional: testing/event.json with the event definition
+Optional: testing/env.json with a set of ENV vars
+"""
+
+import sys
+import os
+import json
+import importlib.util
+
+
+def main():
+    if len(sys.argv) != 2:
+        print("Usage: lambda_test <lambda_script.py>")
+        print("\nOptional: testing/event.json with test event")
+        print("Optional: testing/env.json with environment variables")
+        sys.exit(1)
+
+    handler_file = sys.argv[1]
+
+    # Add .py if not present
+    if not handler_file.endswith(".py"):
+        handler_file += ".py"
+
+    # Check if file exists
+    if not os.path.exists(handler_file):
+        print(f"Error: File '{handler_file}' not found")
+        sys.exit(1)
+
+    # Load environment variables from env.json if it exists
+    if os.path.exists("testing/env.json"):
+        print("Loading environment variables from testing/env.json")
+        with open("testing/env.json") as f:
+            env_vars = json.load(f)
+        for key, value in env_vars.items():
+            os.environ[key] = value
+            print(f"  Set {key} = {value}")
+        print()
+
+    # Load event configuration
+    if os.path.exists("testing/event.json"):
+        print("Loading event from testing/event.json")
+        with open("testing/event.json") as f:
+            event = json.load(f)
+    else:
+        print("No testing/event.json found, using empty event")
+        event = {}
+
+    # Load the module dynamically
+    spec = importlib.util.spec_from_file_location("lambda_module", handler_file)
+    lambda_module = importlib.util.module_from_spec(spec)
+    spec.loader.exec_module(lambda_module)
+
+    # Call the lambda_handler
+    print(f"Invoking lambda_handler from {handler_file}...")
+    print("-" * 50)
+    print(f"Event: {json.dumps(event, indent=2)}")
+    print("-" * 50)
+
+    result = lambda_module.lambda_handler(event, {})
+
+    print("-" * 50)
+    print("Result:")
+    print(json.dumps(result, indent=2))
+
+
+if __name__ == "__main__":
+    main()
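
Editor's note: the harness imports the given file and calls its lambda_handler(event, {}) directly, so any standard handler works unmodified. A minimal sketch of a handler file it could drive (the file name and fixture contents are illustrative assumptions):

# my_lambda.py (hypothetical)
import json


def lambda_handler(event: dict, context) -> dict:
    # Echo the event back, as a trivial Lambda would
    return {"statusCode": 200, "body": json.dumps({"received": event})}

With testing/event.json containing, say, {"name": "smoke-test"} and testing/env.json containing {"LOG_LEVEL": "DEBUG"}, running `lambda_test my_lambda` (the .py suffix is appended automatically) sets the env vars, prints the event, and dumps the handler's result. Note the context argument is an empty dict rather than a real LambdaContext, so handlers that read context attributes need guarding.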
workbench/scripts/ml_pipeline_batch.py (new file)
@@ -0,0 +1,137 @@
+import argparse
+import logging
+import time
+from datetime import datetime
+from pathlib import Path
+
+# Workbench Imports
+from workbench.core.cloud_platform.aws.aws_account_clamp import AWSAccountClamp
+from workbench.utils.config_manager import ConfigManager
+from workbench.utils.s3_utils import upload_content_to_s3
+from workbench.utils.cloudwatch_utils import get_cloudwatch_logs_url
+
+log = logging.getLogger("workbench")
+cm = ConfigManager()
+workbench_bucket = cm.get_config("WORKBENCH_BUCKET")
+
+
+def get_ecr_image_uri() -> str:
+    """Get the ECR image URI for the current region."""
+    region = AWSAccountClamp().region
+    return f"507740646243.dkr.ecr.{region}.amazonaws.com/aws-ml-images/py312-ml-pipelines:0.1"
+
+
+def get_batch_role_arn() -> str:
+    """Get the Batch execution role ARN."""
+    account_id = AWSAccountClamp().account_id
+    return f"arn:aws:iam::{account_id}:role/Workbench-BatchRole"
+
+
+def _log_cloudwatch_link(job: dict, message_prefix: str = "View logs") -> None:
+    """
+    Helper method to log CloudWatch logs link with clickable URL and full URL display.
+
+    Args:
+        job: Batch job description dictionary
+        message_prefix: Prefix for the log message (default: "View logs")
+    """
+    log_stream = job.get("container", {}).get("logStreamName")
+    logs_url = get_cloudwatch_logs_url(log_group="/aws/batch/job", log_stream=log_stream)
+    if logs_url:
+        clickable_url = f"\033]8;;{logs_url}\033\\{logs_url}\033]8;;\033\\"
+        log.info(f"{message_prefix}: {clickable_url}")
+    else:
+        log.info("Check AWS Batch console for logs")
+
+
+def run_batch_job(script_path: str, size: str = "small") -> int:
+    """
+    Submit and monitor an AWS Batch job for ML pipeline execution.
+
+    Uploads script to S3, submits Batch job, monitors until completion or 2 minutes of RUNNING.
+
+    Args:
+        script_path: Local path to the ML pipeline script
+        size: Job size tier - "small" (default), "medium", or "large"
+            - small: 2 vCPU, 4GB RAM for lightweight processing
+            - medium: 4 vCPU, 8GB RAM for standard ML workloads
+            - large: 8 vCPU, 16GB RAM for heavy training/inference
+
+    Returns:
+        Exit code (0 for success/disconnected, non-zero for failure)
+    """
+    if size not in ["small", "medium", "large"]:
+        raise ValueError(f"Invalid size '{size}'. Must be 'small', 'medium', or 'large'")
+
+    batch = AWSAccountClamp().boto3_session.client("batch")
+    script_name = Path(script_path).stem
+
+    # Upload script to S3
+    s3_path = f"s3://{workbench_bucket}/batch-jobs/{Path(script_path).name}"
+    log.info(f"Uploading script to {s3_path}")
+    upload_content_to_s3(Path(script_path).read_text(), s3_path)
+
+    # Submit job
+    job_name = f"workbench_{script_name}_{datetime.now():%Y%m%d_%H%M%S}"
+    response = batch.submit_job(
+        jobName=job_name,
+        jobQueue="workbench-job-queue",
+        jobDefinition=f"workbench-batch-{size}",
+        containerOverrides={
+            "environment": [
+                {"name": "ML_PIPELINE_S3_PATH", "value": s3_path},
+                {"name": "WORKBENCH_BUCKET", "value": workbench_bucket},
+            ]
+        },
+    )
+    job_id = response["jobId"]
+    log.info(f"Submitted job: {job_name} ({job_id}) using {size} tier")
+
+    # Monitor job
+    last_status, running_start = None, None
+    while True:
+        job = batch.describe_jobs(jobs=[job_id])["jobs"][0]
+        status = job["status"]
+
+        if status != last_status:
+            log.info(f"Job status: {status}")
+            last_status = status
+            if status == "RUNNING":
+                running_start = time.time()
+
+        # Disconnect after 2 minutes of running
+        if status == "RUNNING" and running_start and (time.time() - running_start >= 120):
+            log.info("āœ… ML Pipeline is running successfully!")
+            _log_cloudwatch_link(job, "šŸ“Š Monitor logs")
+            return 0
+
+        # Handle completion
+        if status in ["SUCCEEDED", "FAILED"]:
+            exit_code = job.get("attempts", [{}])[-1].get("exitCode", 1)
+            msg = (
+                "Job completed successfully"
+                if status == "SUCCEEDED"
+                else f"Job failed: {job.get('statusReason', 'Unknown')}"
+            )
+            log.info(msg) if status == "SUCCEEDED" else log.error(msg)
+            _log_cloudwatch_link(job)
+            return exit_code
+
+        time.sleep(10)
+
+
+def main():
+    """CLI entry point for running ML pipelines on AWS Batch."""
+    parser = argparse.ArgumentParser(description="Run ML pipeline script on AWS Batch")
+    parser.add_argument("script_file", help="Local path to ML pipeline script")
+    args = parser.parse_args()
+    try:
+        exit_code = run_batch_job(args.script_file)
+        exit(exit_code)
+    except Exception as e:
+        log.error(f"Error: {e}")
+        exit(1)
+
+
+if __name__ == "__main__":
+    main()
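
Editor's note: the clickable log link in _log_cloudwatch_link is built with the OSC 8 terminal hyperlink escape sequence (ESC ] 8 ; ; URL ESC \ text ESC ] 8 ; ; ESC \). A standalone sketch of the same trick, with an arbitrary example URL:

def hyperlink(url: str, text: str) -> str:
    # OSC 8 escape: terminals with hyperlink support render `text` as a clickable
    # link; terminals without it fall back to printing the plain text
    return f"\033]8;;{url}\033\\{text}\033]8;;\033\\"


print(hyperlink("https://console.aws.amazon.com/cloudwatch/", "CloudWatch logs"))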
workbench/scripts/ml_pipeline_sqs.py (new file)
@@ -0,0 +1,186 @@
+import argparse
+import logging
+import json
+from pathlib import Path
+
+# Workbench Imports
+from workbench.core.cloud_platform.aws.aws_account_clamp import AWSAccountClamp
+from workbench.utils.config_manager import ConfigManager
+from workbench.utils.s3_utils import upload_content_to_s3
+
+log = logging.getLogger("workbench")
+cm = ConfigManager()
+workbench_bucket = cm.get_config("WORKBENCH_BUCKET")
+
+
+def submit_to_sqs(
+    script_path: str,
+    size: str = "small",
+    realtime: bool = False,
+    dt: bool = False,
+    promote: bool = False,
+) -> None:
+    """
+    Upload script to S3 and submit message to SQS queue for processing.
+
+    Args:
+        script_path: Local path to the ML pipeline script
+        size: Job size tier - "small" (default), "medium", or "large"
+        realtime: If True, sets serverless=False for real-time processing (default: False)
+        dt: If True, sets DT=True in environment (default: False)
+        promote: If True, sets PROMOTE=True in environment (default: False)
+
+    Raises:
+        ValueError: If size is invalid; FileNotFoundError: If the script file is not found
+    """
+    print(f"\n{'=' * 60}")
+    print("šŸš€ SUBMITTING ML PIPELINE JOB")
+    print(f"{'=' * 60}")
+    if size not in ["small", "medium", "large"]:
+        raise ValueError(f"Invalid size '{size}'. Must be 'small', 'medium', or 'large'")
+
+    # Validate script exists
+    script_file = Path(script_path)
+    if not script_file.exists():
+        raise FileNotFoundError(f"Script not found: {script_path}")
+
+    print(f"šŸ“„ Script: {script_file.name}")
+    print(f"šŸ“ Size tier: {size}")
+    print(f"⚔ Mode: {'Real-time' if realtime else 'Serverless'} (serverless={'False' if realtime else 'True'})")
+    print(f"šŸ”„ DynamicTraining: {dt}")
+    print(f"šŸ†• Promote: {promote}")
+    print(f"🪣 Bucket: {workbench_bucket}")
+    sqs = AWSAccountClamp().boto3_session.client("sqs")
+    script_name = script_file.name
+
+    # List Workbench queues
+    print("\nšŸ“‹ Listing Workbench SQS queues...")
+    try:
+        queues = sqs.list_queues(QueueNamePrefix="workbench-")
+        queue_urls = queues.get("QueueUrls", [])
+        if queue_urls:
+            print(f"āœ… Found {len(queue_urls)} workbench queue(s):")
+            for url in queue_urls:
+                queue_name = url.split("/")[-1]
+                print(f"   • {queue_name}")
+        else:
+            print("āš ļø No workbench queues found")
+    except Exception as e:
+        print(f"āŒ Error listing queues: {e}")
+
+    # Upload script to S3
+    s3_path = f"s3://{workbench_bucket}/batch-jobs/{script_name}"
+    print("\nšŸ“¤ Uploading script to S3...")
+    print(f"   Source: {script_path}")
+    print(f"   Destination: {s3_path}")
+
+    try:
+        upload_content_to_s3(script_file.read_text(), s3_path)
+        print("āœ… Script uploaded successfully")
+    except Exception as e:
+        print(f"āŒ Upload failed: {e}")
+        raise
+    # Get queue URL and info
+    queue_name = "workbench-ml-pipeline-queue.fifo"
+    print("\nšŸŽÆ Getting queue information...")
+    print(f"   Queue name: {queue_name}")
+
+    try:
+        queue_url = sqs.get_queue_url(QueueName=queue_name)["QueueUrl"]
+        print(f"   Queue URL: {queue_url}")
+
+        # Get queue attributes for additional info
+        attrs = sqs.get_queue_attributes(
+            QueueUrl=queue_url, AttributeNames=["ApproximateNumberOfMessages", "ApproximateNumberOfMessagesNotVisible"]
+        )
+        messages_available = attrs["Attributes"].get("ApproximateNumberOfMessages", "0")
+        messages_in_flight = attrs["Attributes"].get("ApproximateNumberOfMessagesNotVisible", "0")
+        print(f"   Messages in queue: {messages_available}")
+        print(f"   Messages in flight: {messages_in_flight}")
+
+    except Exception as e:
+        print(f"āŒ Error accessing queue: {e}")
+        raise
+
+    # Prepare message
+    message = {"script_path": s3_path, "size": size}
+
+    # Set environment variables
+    message["environment"] = {
+        "SERVERLESS": "False" if realtime else "True",
+        "DT": str(dt),
+        "PROMOTE": str(promote),
+    }
+
+    # Send the message to SQS
+    try:
+        print("\nšŸ“Ø Sending message to SQS...")
+        response = sqs.send_message(
+            QueueUrl=queue_url,
+            MessageBody=json.dumps(message, indent=2),
+            MessageGroupId="ml-pipeline-jobs",  # Required for FIFO
+        )
+        message_id = response["MessageId"]
+        print("āœ… Message sent successfully!")
+        print(f"   Message ID: {message_id}")
+    except Exception as e:
+        print(f"āŒ Failed to send message: {e}")
+        raise
+
+    # Success summary
+    print(f"\n{'=' * 60}")
+    print("āœ… JOB SUBMISSION COMPLETE")
+    print(f"{'=' * 60}")
+    print(f"šŸ“„ Script: {script_name}")
+    print(f"šŸ“ Size: {size}")
+    print(f"⚔ Mode: {'Real-time' if realtime else 'Serverless'} (SERVERLESS={'False' if realtime else 'True'})")
+    print(f"šŸ”„ DynamicTraining: {dt}")
+    print(f"šŸ†• Promote: {promote}")
+    print(f"šŸ†” Message ID: {message_id}")
+    print("\nšŸ” MONITORING LOCATIONS:")
+    print(f"   • SQS Queue: AWS Console → SQS → {queue_name}")
+    print("   • Lambda Logs: AWS Console → Lambda → Functions")
+    print("   • Batch Jobs: AWS Console → Batch → Jobs")
+    print("   • CloudWatch: AWS Console → CloudWatch → Log groups")
+    print("\nā³ Your job should start processing soon...")
+
+
+def main():
+    """CLI entry point for submitting ML pipelines via SQS."""
+    parser = argparse.ArgumentParser(description="Submit ML pipeline to SQS queue for Batch processing")
+    parser.add_argument("script_file", help="Local path to ML pipeline script")
+    parser.add_argument(
+        "--size", default="small", choices=["small", "medium", "large"], help="Job size tier (default: small)"
+    )
+    parser.add_argument(
+        "--realtime",
+        action="store_true",
+        help="Create realtime endpoints (default is serverless)",
+    )
+    parser.add_argument(
+        "--dt",
+        action="store_true",
+        help="Set DT=True (models and endpoints will have '-dt' suffix)",
+    )
+    parser.add_argument(
+        "--promote",
+        action="store_true",
+        help="Set Promote=True (models and endpoints will use promoted naming)",
+    )
+    args = parser.parse_args()
+    try:
+        submit_to_sqs(
+            args.script_file,
+            args.size,
+            realtime=args.realtime,
+            dt=args.dt,
+            promote=args.promote,
+        )
+    except Exception as e:
+        print(f"\nāŒ ERROR: {e}")
+        log.error(f"Error: {e}")
+        exit(1)
+
+
+if __name__ == "__main__":
+    main()
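
Editor's note: the consumer on the other side of the FIFO queue is not part of this diff, but the message body it receives has the shape built above: the S3 path of the uploaded script, the size tier, and the three environment flags. A hypothetical consumer-side parse, with an invented bucket and script name:

import json

body = (
    '{"script_path": "s3://my-workbench-bucket/batch-jobs/my_pipeline.py",'
    ' "size": "small",'
    ' "environment": {"SERVERLESS": "True", "DT": "False", "PROMOTE": "False"}}'
)
msg = json.loads(body)
# ml_pipeline_batch.py maps the same size tiers to job definitions named workbench-batch-<size>
assert msg["size"] in ("small", "medium", "large")
print(f"Run {msg['script_path']} on workbench-batch-{msg['size']} with env {msg['environment']}")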