workbench 0.8.161__py3-none-any.whl → 0.8.192__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (96)
  1. workbench/algorithms/dataframe/proximity.py +143 -102
  2. workbench/algorithms/graph/light/proximity_graph.py +2 -1
  3. workbench/api/compound.py +1 -1
  4. workbench/api/endpoint.py +12 -0
  5. workbench/api/feature_set.py +4 -4
  6. workbench/api/meta.py +5 -2
  7. workbench/api/model.py +16 -12
  8. workbench/api/monitor.py +1 -16
  9. workbench/core/artifacts/artifact.py +11 -3
  10. workbench/core/artifacts/data_capture_core.py +355 -0
  11. workbench/core/artifacts/endpoint_core.py +168 -78
  12. workbench/core/artifacts/feature_set_core.py +72 -13
  13. workbench/core/artifacts/model_core.py +50 -15
  14. workbench/core/artifacts/monitor_core.py +33 -248
  15. workbench/core/cloud_platform/aws/aws_account_clamp.py +50 -1
  16. workbench/core/cloud_platform/aws/aws_meta.py +12 -5
  17. workbench/core/cloud_platform/aws/aws_session.py +4 -4
  18. workbench/core/transforms/data_to_features/light/molecular_descriptors.py +4 -4
  19. workbench/core/transforms/features_to_model/features_to_model.py +9 -4
  20. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +36 -6
  21. workbench/core/transforms/pandas_transforms/pandas_to_features.py +27 -0
  22. workbench/core/views/training_view.py +49 -53
  23. workbench/core/views/view.py +51 -1
  24. workbench/core/views/view_utils.py +4 -4
  25. workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +483 -0
  26. workbench/model_scripts/custom_models/chem_info/mol_standardize.py +450 -0
  27. workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +7 -9
  28. workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +3 -5
  29. workbench/model_scripts/custom_models/proximity/proximity.py +143 -102
  30. workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
  31. workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +10 -17
  32. workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
  33. workbench/model_scripts/custom_models/uq_models/meta_uq.template +156 -58
  34. workbench/model_scripts/custom_models/uq_models/ngboost.template +20 -14
  35. workbench/model_scripts/custom_models/uq_models/proximity.py +143 -102
  36. workbench/model_scripts/custom_models/uq_models/requirements.txt +1 -3
  37. workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +5 -13
  38. workbench/model_scripts/pytorch_model/pytorch.template +19 -20
  39. workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
  40. workbench/model_scripts/script_generation.py +7 -2
  41. workbench/model_scripts/uq_models/mapie.template +492 -0
  42. workbench/model_scripts/uq_models/requirements.txt +1 -0
  43. workbench/model_scripts/xgb_model/xgb_model.template +31 -40
  44. workbench/repl/workbench_shell.py +11 -6
  45. workbench/scripts/lambda_launcher.py +63 -0
  46. workbench/scripts/ml_pipeline_batch.py +137 -0
  47. workbench/scripts/ml_pipeline_sqs.py +186 -0
  48. workbench/scripts/monitor_cloud_watch.py +20 -100
  49. workbench/utils/aws_utils.py +4 -3
  50. workbench/utils/chem_utils/__init__.py +0 -0
  51. workbench/utils/chem_utils/fingerprints.py +134 -0
  52. workbench/utils/chem_utils/misc.py +194 -0
  53. workbench/utils/chem_utils/mol_descriptors.py +483 -0
  54. workbench/utils/chem_utils/mol_standardize.py +450 -0
  55. workbench/utils/chem_utils/mol_tagging.py +348 -0
  56. workbench/utils/chem_utils/projections.py +209 -0
  57. workbench/utils/chem_utils/salts.py +256 -0
  58. workbench/utils/chem_utils/sdf.py +292 -0
  59. workbench/utils/chem_utils/toxicity.py +250 -0
  60. workbench/utils/chem_utils/vis.py +253 -0
  61. workbench/utils/cloudwatch_handler.py +1 -1
  62. workbench/utils/cloudwatch_utils.py +137 -0
  63. workbench/utils/config_manager.py +3 -7
  64. workbench/utils/endpoint_utils.py +5 -7
  65. workbench/utils/license_manager.py +2 -6
  66. workbench/utils/model_utils.py +76 -30
  67. workbench/utils/monitor_utils.py +44 -62
  68. workbench/utils/pandas_utils.py +3 -3
  69. workbench/utils/shap_utils.py +10 -2
  70. workbench/utils/workbench_logging.py +0 -3
  71. workbench/utils/workbench_sqs.py +1 -1
  72. workbench/utils/xgboost_model_utils.py +283 -145
  73. workbench/web_interface/components/plugins/dashboard_status.py +3 -1
  74. workbench/web_interface/components/plugins/generated_compounds.py +1 -1
  75. workbench/web_interface/components/plugins/scatter_plot.py +3 -3
  76. {workbench-0.8.161.dist-info → workbench-0.8.192.dist-info}/METADATA +4 -4
  77. {workbench-0.8.161.dist-info → workbench-0.8.192.dist-info}/RECORD +81 -76
  78. {workbench-0.8.161.dist-info → workbench-0.8.192.dist-info}/entry_points.txt +3 -0
  79. workbench/model_scripts/custom_models/chem_info/local_utils.py +0 -769
  80. workbench/model_scripts/custom_models/chem_info/tautomerize.py +0 -83
  81. workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
  82. workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -393
  83. workbench/model_scripts/custom_models/uq_models/mapie_xgb.template +0 -203
  84. workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
  85. workbench/model_scripts/pytorch_model/generated_model_script.py +0 -565
  86. workbench/model_scripts/quant_regression/quant_regression.template +0 -279
  87. workbench/model_scripts/quant_regression/requirements.txt +0 -1
  88. workbench/model_scripts/scikit_learn/generated_model_script.py +0 -307
  89. workbench/model_scripts/xgb_model/generated_model_script.py +0 -477
  90. workbench/utils/chem_utils.py +0 -1556
  91. workbench/utils/execution_environment.py +0 -211
  92. workbench/utils/fast_inference.py +0 -167
  93. workbench/utils/resource_utils.py +0 -39
  94. {workbench-0.8.161.dist-info → workbench-0.8.192.dist-info}/WHEEL +0 -0
  95. {workbench-0.8.161.dist-info → workbench-0.8.192.dist-info}/licenses/LICENSE +0 -0
  96. {workbench-0.8.161.dist-info → workbench-0.8.192.dist-info}/top_level.txt +0 -0
workbench/repl/workbench_shell.py

@@ -1,4 +1,6 @@
+import IPython
 from IPython import start_ipython
+from distutils.version import LooseVersion
 from IPython.terminal.prompts import Prompts
 from IPython.terminal.ipapp import load_default_config
 from pygments.token import Token
@@ -39,7 +41,7 @@ from workbench.cached.cached_meta import CachedMeta
 try:
     import rdkit  # noqa
     import mordred  # noqa
-    from workbench.utils import chem_utils
+    from workbench.utils.chem_utils import vis

     HAVE_CHEM_UTILS = True
 except ImportError:
@@ -70,7 +72,7 @@ if not ConfigManager().config_okay():

     # Set the log level to important
     log = logging.getLogger("workbench")
-    log.setLevel(IMPORTANT_LEVEL_NUM)
+    log.setLevel(logging.INFO)
     log.addFilter(
         lambda record: not (
             record.getMessage().startswith("Async: Metadata") or record.getMessage().startswith("Updated Metadata")
@@ -176,12 +178,12 @@ class WorkbenchShell:

         # Add cheminformatics utils if available
         if HAVE_CHEM_UTILS:
-            self.commands["show"] = chem_utils.show
+            self.commands["show"] = vis.show

     def start(self):
         """Start the Workbench IPython shell"""
         cprint("magenta", "\nWelcome to Workbench!")
-        if self.aws_status is False:
+        if not self.aws_status:
             cprint("red", "AWS Account Connection Failed...Review/Fix the Workbench Config:")
             cprint("red", f"Path: {self.cm.site_config_path}")
             self.show_config()
@@ -202,7 +204,10 @@ class WorkbenchShell:

         # Start IPython with the config and commands in the namespace
         try:
-            ipython_argv = ["--no-tip", "--theme", "linux"]
+            if LooseVersion(IPython.__version__) >= LooseVersion("9.0.0"):
+                ipython_argv = ["--no-tip", "--theme", "linux"]
+            else:
+                ipython_argv = []
             start_ipython(ipython_argv, user_ns=locs, config=config)
         finally:
             spinner = self.spinner_start("Goodbye to AWS:")
@@ -555,7 +560,7 @@ class WorkbenchShell:
         from workbench.web_interface.components.plugin_unit_test import PluginUnitTest

         # Get kwargs
-        theme = kwargs.get("theme", "dark")
+        theme = kwargs.get("theme", "midnight_blue")

         plugin_test = PluginUnitTest(plugin_class, theme=theme, input_data=data, **kwargs)

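The start() change above passes the --no-tip/--theme flags only on IPython 9+, which the code treats as the first version that accepts them; older versions get an empty argv. Note that distutils.version.LooseVersion is deprecated (distutils was removed from the standard library in Python 3.12), so a minimal sketch of the same gate with packaging.version (an alternative, not what this wheel ships; it assumes packaging is installed) would be:

import IPython
from packaging.version import Version  # assumption: packaging is available

# Only pass the --no-tip / --theme CLI flags where IPython supports them (9.0+)
if Version(IPython.__version__) >= Version("9.0.0"):
    ipython_argv = ["--no-tip", "--theme", "linux"]
else:
    ipython_argv = []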
workbench/scripts/lambda_launcher.py

@@ -0,0 +1,63 @@
+import sys
+import os
+import json
+import importlib.util
+
+
+def main():
+    if len(sys.argv) != 2:
+        print("Usage: lambda_launcher <handler_module_name>")
+        print("\nOptional: testing/event.json with test event")
+        print("Optional: testing/env.json with environment variables")
+        sys.exit(1)
+
+    handler_file = sys.argv[1]
+
+    # Add .py if not present
+    if not handler_file.endswith(".py"):
+        handler_file += ".py"
+
+    # Check if file exists
+    if not os.path.exists(handler_file):
+        print(f"Error: File '{handler_file}' not found")
+        sys.exit(1)
+
+    # Load environment variables from env.json if it exists
+    if os.path.exists("testing/env.json"):
+        print("Loading environment variables from testing/env.json")
+        with open("testing/env.json") as f:
+            env_vars = json.load(f)
+        for key, value in env_vars.items():
+            os.environ[key] = value
+            print(f"  Set {key} = {value}")
+        print()
+
+    # Load event configuration
+    if os.path.exists("testing/event.json"):
+        print("Loading event from testing/event.json")
+        with open("testing/event.json") as f:
+            event = json.load(f)
+    else:
+        print("No testing/event.json found, using empty event")
+        event = {}
+
+    # Load the module dynamically
+    spec = importlib.util.spec_from_file_location("lambda_module", handler_file)
+    lambda_module = importlib.util.module_from_spec(spec)
+    spec.loader.exec_module(lambda_module)
+
+    # Call the lambda_handler
+    print(f"Invoking lambda_handler from {handler_file}...")
+    print("-" * 50)
+    print(f"Event: {json.dumps(event, indent=2)}")
+    print("-" * 50)
+
+    result = lambda_module.lambda_handler(event, {})
+
+    print("-" * 50)
+    print("Result:")
+    print(json.dumps(result, indent=2))
+
+
+if __name__ == "__main__":
+    main()
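A usage sketch for the launcher above; the handler file and the testing/ JSON fixtures are hypothetical examples, not files shipped with the package:

# my_handler.py -- run locally via the launcher: lambda_launcher my_handler
def lambda_handler(event, context):
    name = event.get("name", "world")
    return {"statusCode": 200, "body": f"hello {name}"}

# Optional fixtures the launcher picks up if present:
#   testing/event.json  ->  {"name": "workbench"}
#   testing/env.json    ->  {"WORKBENCH_BUCKET": "my-bucket"}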
workbench/scripts/ml_pipeline_batch.py

@@ -0,0 +1,137 @@
+import argparse
+import logging
+import time
+from datetime import datetime
+from pathlib import Path
+
+# Workbench Imports
+from workbench.core.cloud_platform.aws.aws_account_clamp import AWSAccountClamp
+from workbench.utils.config_manager import ConfigManager
+from workbench.utils.s3_utils import upload_content_to_s3
+from workbench.utils.cloudwatch_utils import get_cloudwatch_logs_url
+
+log = logging.getLogger("workbench")
+cm = ConfigManager()
+workbench_bucket = cm.get_config("WORKBENCH_BUCKET")
+
+
+def get_ecr_image_uri() -> str:
+    """Get the ECR image URI for the current region."""
+    region = AWSAccountClamp().region
+    return f"507740646243.dkr.ecr.{region}.amazonaws.com/aws-ml-images/py312-ml-pipelines:0.1"
+
+
+def get_batch_role_arn() -> str:
+    """Get the Batch execution role ARN."""
+    account_id = AWSAccountClamp().account_id
+    return f"arn:aws:iam::{account_id}:role/Workbench-BatchRole"
+
+
+def _log_cloudwatch_link(job: dict, message_prefix: str = "View logs") -> None:
+    """
+    Helper method to log CloudWatch logs link with clickable URL and full URL display.
+
+    Args:
+        job: Batch job description dictionary
+        message_prefix: Prefix for the log message (default: "View logs")
+    """
+    log_stream = job.get("container", {}).get("logStreamName")
+    logs_url = get_cloudwatch_logs_url(log_group="/aws/batch/job", log_stream=log_stream)
+    if logs_url:
+        clickable_url = f"\033]8;;{logs_url}\033\\{logs_url}\033]8;;\033\\"
+        log.info(f"{message_prefix}: {clickable_url}")
+    else:
+        log.info("Check AWS Batch console for logs")
+
+
+def run_batch_job(script_path: str, size: str = "small") -> int:
+    """
+    Submit and monitor an AWS Batch job for ML pipeline execution.
+
+    Uploads script to S3, submits Batch job, monitors until completion or 2 minutes of RUNNING.
+
+    Args:
+        script_path: Local path to the ML pipeline script
+        size: Job size tier - "small" (default), "medium", or "large"
+            - small: 2 vCPU, 4GB RAM for lightweight processing
+            - medium: 4 vCPU, 8GB RAM for standard ML workloads
+            - large: 8 vCPU, 16GB RAM for heavy training/inference
+
+    Returns:
+        Exit code (0 for success/disconnected, non-zero for failure)
+    """
+    if size not in ["small", "medium", "large"]:
+        raise ValueError(f"Invalid size '{size}'. Must be 'small', 'medium', or 'large'")
+
+    batch = AWSAccountClamp().boto3_session.client("batch")
+    script_name = Path(script_path).stem
+
+    # Upload script to S3
+    s3_path = f"s3://{workbench_bucket}/batch-jobs/{Path(script_path).name}"
+    log.info(f"Uploading script to {s3_path}")
+    upload_content_to_s3(Path(script_path).read_text(), s3_path)
+
+    # Submit job
+    job_name = f"workbench_{script_name}_{datetime.now():%Y%m%d_%H%M%S}"
+    response = batch.submit_job(
+        jobName=job_name,
+        jobQueue="workbench-job-queue",
+        jobDefinition=f"workbench-batch-{size}",
+        containerOverrides={
+            "environment": [
+                {"name": "ML_PIPELINE_S3_PATH", "value": s3_path},
+                {"name": "WORKBENCH_BUCKET", "value": workbench_bucket},
+            ]
+        },
+    )
+    job_id = response["jobId"]
+    log.info(f"Submitted job: {job_name} ({job_id}) using {size} tier")
+
+    # Monitor job
+    last_status, running_start = None, None
+    while True:
+        job = batch.describe_jobs(jobs=[job_id])["jobs"][0]
+        status = job["status"]
+
+        if status != last_status:
+            log.info(f"Job status: {status}")
+            last_status = status
+            if status == "RUNNING":
+                running_start = time.time()
+
+        # Disconnect after 2 minutes of running
+        if status == "RUNNING" and running_start and (time.time() - running_start >= 120):
+            log.info("✅ ML Pipeline is running successfully!")
+            _log_cloudwatch_link(job, "📊 Monitor logs")
+            return 0
+
+        # Handle completion
+        if status in ["SUCCEEDED", "FAILED"]:
+            exit_code = job.get("attempts", [{}])[-1].get("exitCode", 1)
+            msg = (
+                "Job completed successfully"
+                if status == "SUCCEEDED"
+                else f"Job failed: {job.get('statusReason', 'Unknown')}"
+            )
+            log.info(msg) if status == "SUCCEEDED" else log.error(msg)
+            _log_cloudwatch_link(job)
+            return exit_code
+
+        time.sleep(10)
+
+
+def main():
+    """CLI entry point for running ML pipelines on AWS Batch."""
+    parser = argparse.ArgumentParser(description="Run ML pipeline script on AWS Batch")
+    parser.add_argument("script_file", help="Local path to ML pipeline script")
+    args = parser.parse_args()
+    try:
+        exit_code = run_batch_job(args.script_file)
+        exit(exit_code)
+    except Exception as e:
+        log.error(f"Error: {e}")
+        exit(1)
+
+
+if __name__ == "__main__":
+    main()
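A usage sketch for run_batch_job based on the signature above (the pipeline script name is illustrative; the console-script name added in entry_points.txt is not shown in this diff):

from workbench.scripts.ml_pipeline_batch import run_batch_job

# Submit a hypothetical pipeline on the medium tier (4 vCPU / 8GB RAM).
# Returns 0 once the job has been RUNNING for 2 minutes (monitoring detaches),
# or the job's exit code on SUCCEEDED/FAILED.
exit_code = run_batch_job("my_pipeline.py", size="medium")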
workbench/scripts/ml_pipeline_sqs.py

@@ -0,0 +1,186 @@
+import argparse
+import logging
+import json
+from pathlib import Path
+
+# Workbench Imports
+from workbench.core.cloud_platform.aws.aws_account_clamp import AWSAccountClamp
+from workbench.utils.config_manager import ConfigManager
+from workbench.utils.s3_utils import upload_content_to_s3
+
+log = logging.getLogger("workbench")
+cm = ConfigManager()
+workbench_bucket = cm.get_config("WORKBENCH_BUCKET")
+
+
+def submit_to_sqs(
+    script_path: str,
+    size: str = "small",
+    realtime: bool = False,
+    dt: bool = False,
+    promote: bool = False,
+) -> None:
+    """
+    Upload script to S3 and submit message to SQS queue for processing.
+
+    Args:
+        script_path: Local path to the ML pipeline script
+        size: Job size tier - "small" (default), "medium", or "large"
+        realtime: If True, sets serverless=False for real-time processing (default: False)
+        dt: If True, sets DT=True in environment (default: False)
+        promote: If True, sets PROMOTE=True in environment (default: False)
+
+    Raises:
+        ValueError: If size is invalid or script file not found
+    """
+    print(f"\n{'=' * 60}")
+    print("🚀 SUBMITTING ML PIPELINE JOB")
+    print(f"{'=' * 60}")
+    if size not in ["small", "medium", "large"]:
+        raise ValueError(f"Invalid size '{size}'. Must be 'small', 'medium', or 'large'")
+
+    # Validate script exists
+    script_file = Path(script_path)
+    if not script_file.exists():
+        raise FileNotFoundError(f"Script not found: {script_path}")
+
+    print(f"📄 Script: {script_file.name}")
+    print(f"📏 Size tier: {size}")
+    print(f"⚡ Mode: {'Real-time' if realtime else 'Serverless'} (serverless={'False' if realtime else 'True'})")
+    print(f"🔄 DynamicTraining: {dt}")
+    print(f"🆕 Promote: {promote}")
+    print(f"🪣 Bucket: {workbench_bucket}")
+    sqs = AWSAccountClamp().boto3_session.client("sqs")
+    script_name = script_file.name
+
+    # List Workbench queues
+    print("\n📋 Listing Workbench SQS queues...")
+    try:
+        queues = sqs.list_queues(QueueNamePrefix="workbench-")
+        queue_urls = queues.get("QueueUrls", [])
+        if queue_urls:
+            print(f"✅ Found {len(queue_urls)} workbench queue(s):")
+            for url in queue_urls:
+                queue_name = url.split("/")[-1]
+                print(f"   • {queue_name}")
+        else:
+            print("⚠️ No workbench queues found")
+    except Exception as e:
+        print(f"❌ Error listing queues: {e}")
+
+    # Upload script to S3
+    s3_path = f"s3://{workbench_bucket}/batch-jobs/{script_name}"
+    print("\n📤 Uploading script to S3...")
+    print(f"   Source: {script_path}")
+    print(f"   Destination: {s3_path}")
+
+    try:
+        upload_content_to_s3(script_file.read_text(), s3_path)
+        print("✅ Script uploaded successfully")
+    except Exception as e:
+        print(f"❌ Upload failed: {e}")
+        raise
+    # Get queue URL and info
+    queue_name = "workbench-ml-pipeline-queue.fifo"
+    print("\n🎯 Getting queue information...")
+    print(f"   Queue name: {queue_name}")
+
+    try:
+        queue_url = sqs.get_queue_url(QueueName=queue_name)["QueueUrl"]
+        print(f"   Queue URL: {queue_url}")
+
+        # Get queue attributes for additional info
+        attrs = sqs.get_queue_attributes(
+            QueueUrl=queue_url, AttributeNames=["ApproximateNumberOfMessages", "ApproximateNumberOfMessagesNotVisible"]
+        )
+        messages_available = attrs["Attributes"].get("ApproximateNumberOfMessages", "0")
+        messages_in_flight = attrs["Attributes"].get("ApproximateNumberOfMessagesNotVisible", "0")
+        print(f"   Messages in queue: {messages_available}")
+        print(f"   Messages in flight: {messages_in_flight}")
+
+    except Exception as e:
+        print(f"❌ Error accessing queue: {e}")
+        raise
+
+    # Prepare message
+    message = {"script_path": s3_path, "size": size}
+
+    # Set environment variables
+    message["environment"] = {
+        "SERVERLESS": "False" if realtime else "True",
+        "DT": str(dt),
+        "PROMOTE": str(promote),
+    }
+
+    # Send the message to SQS
+    try:
+        print("\n📨 Sending message to SQS...")
+        response = sqs.send_message(
+            QueueUrl=queue_url,
+            MessageBody=json.dumps(message, indent=2),
+            MessageGroupId="ml-pipeline-jobs",  # Required for FIFO
+        )
+        message_id = response["MessageId"]
+        print("✅ Message sent successfully!")
+        print(f"   Message ID: {message_id}")
+    except Exception as e:
+        print(f"❌ Failed to send message: {e}")
+        raise
+
+    # Success summary
+    print(f"\n{'=' * 60}")
+    print("✅ JOB SUBMISSION COMPLETE")
+    print(f"{'=' * 60}")
+    print(f"📄 Script: {script_name}")
+    print(f"📏 Size: {size}")
+    print(f"⚡ Mode: {'Real-time' if realtime else 'Serverless'} (SERVERLESS={'False' if realtime else 'True'})")
+    print(f"🔄 DynamicTraining: {dt}")
+    print(f"🆕 Promote: {promote}")
+    print(f"🆔 Message ID: {message_id}")
+    print("\n🔍 MONITORING LOCATIONS:")
+    print(f"   • SQS Queue: AWS Console → SQS → {queue_name}")
+    print("   • Lambda Logs: AWS Console → Lambda → Functions")
+    print("   • Batch Jobs: AWS Console → Batch → Jobs")
+    print("   • CloudWatch: AWS Console → CloudWatch → Log groups")
+    print("\n⏳ Your job should start processing soon...")
+
+
+def main():
+    """CLI entry point for submitting ML pipelines via SQS."""
+    parser = argparse.ArgumentParser(description="Submit ML pipeline to SQS queue for Batch processing")
+    parser.add_argument("script_file", help="Local path to ML pipeline script")
+    parser.add_argument(
+        "--size", default="small", choices=["small", "medium", "large"], help="Job size tier (default: small)"
+    )
+    parser.add_argument(
+        "--realtime",
+        action="store_true",
+        help="Create realtime endpoints (default is serverless)",
+    )
+    parser.add_argument(
+        "--dt",
+        action="store_true",
+        help="Set DT=True (models and endpoints will have '-dt' suffix)",
+    )
+    parser.add_argument(
+        "--promote",
+        action="store_true",
+        help="Set Promote=True (models and endpoints will use promoted naming)",
+    )
+    args = parser.parse_args()
+    try:
+        submit_to_sqs(
+            args.script_file,
+            args.size,
+            realtime=args.realtime,
+            dt=args.dt,
+            promote=args.promote,
+        )
+    except Exception as e:
+        print(f"\n❌ ERROR: {e}")
+        log.error(f"Error: {e}")
+        exit(1)
+
+
+if __name__ == "__main__":
+    main()
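For reference, the body this script places on the FIFO queue has the following shape (values illustrative, mirroring the construction above; a downstream consumer would json.loads it):

message = {
    "script_path": "s3://my-workbench-bucket/batch-jobs/my_pipeline.py",  # illustrative bucket/script
    "size": "small",
    "environment": {
        "SERVERLESS": "True",   # "False" when --realtime is passed
        "DT": "False",          # "True" when --dt is passed
        "PROMOTE": "False",     # "True" when --promote is passed
    },
}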
workbench/scripts/monitor_cloud_watch.py

@@ -4,8 +4,10 @@ import sys
 import time
 import argparse
 from datetime import datetime, timedelta, timezone
-from workbench.core.cloud_platform.aws.aws_account_clamp import AWSAccountClamp
+
+# Workbench Imports
 from workbench.utils.repl_utils import cprint, Spinner
+from workbench.utils.cloudwatch_utils import get_cloudwatch_client, get_active_log_streams, stream_log_events

 # Define the log levels to include all log levels above the specified level
 log_level_map = {
@@ -33,64 +35,6 @@ def date_display(dt):
     return dt.strftime("%Y-%m-%d %I:%M%p") + "(UTC)"


-def get_cloudwatch_client():
-    """Get the CloudWatch Logs client using the Workbench assumed role session."""
-    session = AWSAccountClamp().boto3_session
-    return session.client("logs")
-
-
-def get_active_log_streams(client, log_group_name, start_time_ms, stream_filter=None):
-    """Retrieve log streams that have events after the specified start time."""
-
-    # Get all the streams in the log group
-    active_streams = []
-    stream_params = {
-        "logGroupName": log_group_name,
-        "orderBy": "LastEventTime",
-        "descending": True,
-    }
-
-    # Loop to retrieve all log streams (maximum 50 per call)
-    while True:
-        response = client.describe_log_streams(**stream_params)
-        log_streams = response.get("logStreams", [])
-
-        for log_stream in log_streams:
-            log_stream_name = log_stream["logStreamName"]
-            last_event_timestamp = log_stream.get("lastEventTimestamp")
-
-            # Include streams with events since the specified start time
-            # Note: There's some issue where the last event timestamp is 'off'
-            # so we're going to add 60 minutes from the last event timestamp
-            last_event_timestamp += 60 * 60 * 1000
-            if last_event_timestamp >= start_time_ms:
-                active_streams.append(log_stream_name)
-            else:
-                break  # Stop if we reach streams older than the start time
-
-        # Check if there are more streams to retrieve
-        if "nextToken" in response:
-            stream_params["nextToken"] = response["nextToken"]
-        else:
-            break
-
-    # Sort and report the active log streams
-    active_streams.sort()
-    if active_streams:
-        print("Active log streams:", len(active_streams))
-
-    # Filter the active streams by a substring if provided
-    if stream_filter and active_streams:
-        print(f"Filtering active log streams by '{stream_filter}'...")
-        active_streams = [stream for stream in active_streams if stream_filter in stream]
-
-    for stream in active_streams:
-        print(f"\t - {stream}")
-
-    # Return the active log streams
-    return active_streams
-
-
 def get_latest_log_events(client, log_group_name, start_time, end_time=None, stream_filter=None):
     """Retrieve the latest log events from the active/filtered log streams in a CloudWatch Logs group."""
@@ -99,11 +43,15 @@ def get_latest_log_events(client, log_group_name, start_time, end_time=None, stream_filter=None):
         get_latest_log_events.first_run = True

     log_events = []
-    start_time_ms = int(start_time.timestamp() * 1000)  # Convert start_time to milliseconds
+    start_time_ms = int(start_time.timestamp() * 1000)
+
+    # Use the util function to get active streams
+    active_streams = get_active_log_streams(log_group_name, start_time_ms, stream_filter, client)

-    # Get the active log streams with events since start_time
-    active_streams = get_active_log_streams(client, log_group_name, start_time_ms, stream_filter)
     if active_streams:
+        print(f"Active log streams: {len(active_streams)}")
+        for stream in active_streams:
+            print(f"\t - {stream}")
         print(f"Processing log events from {date_display(start_time)} on {len(active_streams)} active log streams...")
         get_latest_log_events.first_run = False
     else:
@@ -114,50 +62,22 @@ def get_latest_log_events(client, log_group_name, start_time, end_time=None, stream_filter=None):
         print("Monitoring for new events...")
         return log_events

-    # Iterate over the active streams and fetch log events
+    # Use the util function to stream events from each log stream
     for log_stream_name in active_streams:
-        params = {
-            "logGroupName": log_group_name,
-            "logStreamName": log_stream_name,
-            "startTime": start_time_ms,  # Use start_time in milliseconds
-            "startFromHead": True,  # Start from the nearest event to start_time
-        }
-        next_event_token = None
-        if end_time is not None:
-            params["endTime"] = int(end_time.timestamp() * 1000)
-
-        # Process the log events from this log stream
         spinner = Spinner("lightpurple", f"Pulling events from {log_stream_name}:")
         spinner.start()
         log_stream_events = 0

-        # Get the log events for the active log stream
-        while True:
-            if next_event_token:
-                params["nextToken"] = next_event_token
-                params.pop("startTime", None)  # Remove startTime when using nextToken
+        # Stream events using the util function
+        for event in stream_log_events(
+            log_group_name, log_stream_name, start_time, end_time, follow=False, client=client
+        ):
+            log_stream_events += 1
+            log_events.append(event)

-            # Fetch the log events (this call takes a while: optimize if we can)
-            events_response = client.get_log_events(**params)
+        spinner.stop()
+        print(f"Processed {log_stream_events} events from {log_stream_name} (Total: {len(log_events)})")

-            events = events_response.get("events", [])
-            for event in events:
-                event["logStreamName"] = log_stream_name
-
-            # Add the log stream events to our list of all log events
-            log_stream_events += len(events)
-            log_events.extend(events)
-
-            # Handle pagination for log events
-            next_event_token = events_response.get("nextForwardToken")
-
-            # Break the loop if there are no more events to fetch
-            if not next_event_token or next_event_token == params.get("nextToken"):
-                spinner.stop()
-                print(f"Processed {log_stream_events} events from {log_stream_name} (Total: {len(log_events)})")
-                break
-
-    # Return the log events
     return log_events

@@ -206,6 +126,7 @@ def monitor_log_group(
     print(f"Monitoring log group: {log_group_name} from {date_display(start_time)}")
     print(f"Log levels: {log_levels}")
     print(f"Search terms: {search_terms}")
+
     while True:
         # Get the latest log events with stream filtering if provided
         all_log_events = get_latest_log_events(client, log_group_name, start_time, end_time, stream_filter)
@@ -218,7 +139,6 @@ def monitor_log_group(

             # Check the search terms
             if not search_terms or any(term in event["message"].lower() for term in search_terms):
-
                 # Calculate the start and end index for this match
                 start_index = max(i - before, 0)
                 end_index = min(i + after, len(all_log_events) - 1)
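The rewritten loop treats the new workbench.utils.cloudwatch_utils.stream_log_events as a generator of event dicts, each carrying at least a "message" key (per the search-term filtering in monitor_log_group above). A consumption sketch under that assumption, with illustrative group/stream names:

from datetime import datetime, timedelta, timezone
from workbench.utils.cloudwatch_utils import get_cloudwatch_client, stream_log_events

client = get_cloudwatch_client()
start = datetime.now(timezone.utc) - timedelta(hours=1)

# follow=False stops at the end of the stream instead of tailing it
for event in stream_log_events("/aws/batch/job", "my-log-stream", start, None, follow=False, client=client):
    print(event["message"])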
workbench/utils/aws_utils.py

@@ -55,7 +55,8 @@ def aws_throttle(func=None, retry_intervals=None):
     if func is None:
         return lambda f: aws_throttle(f, retry_intervals=retry_intervals)

-    service_hold_time = 2  # Seconds to wait before calling AWS function
+    # This is currently commented out (we might want to use it later)
+    # service_hold_time = 2  # Seconds to wait before calling AWS function
     default_intervals = [2**i for i in range(1, 9)]  # Default exponential backoff: 2, 4, 8... 256 seconds
     intervals = retry_intervals or default_intervals

@@ -64,8 +65,8 @@ def aws_throttle(func=None, retry_intervals=None):
         for attempt, delay in enumerate(intervals, start=1):
             try:
                 # Add sleep before calling AWS func if running as a service
-                if cm.running_as_service:
-                    time.sleep(service_hold_time)
+                # if cm.running_as_service:
+                #     time.sleep(service_hold_time)
                 return func(*args, **kwargs)
             except ClientError as e:
                 if e.response["Error"]["Code"] == "ThrottlingException":
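Since aws_throttle supports both bare and parameterized decoration (the func is None branch above returns a decorator factory), usage looks like this; the decorated functions are hypothetical examples:

from workbench.utils.aws_utils import aws_throttle

@aws_throttle  # default exponential backoff: 2, 4, 8, ... 256 seconds on ThrottlingException
def list_log_groups(client):
    return client.describe_log_groups()

@aws_throttle(retry_intervals=[1, 2, 4])  # custom backoff schedule
def describe_jobs(batch, job_ids):
    return batch.describe_jobs(jobs=job_ids)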