workbench 0.8.162__py3-none-any.whl → 0.8.220__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of workbench might be problematic. Click here for more details.
- workbench/algorithms/dataframe/__init__.py +1 -2
- workbench/algorithms/dataframe/compound_dataset_overlap.py +321 -0
- workbench/algorithms/dataframe/feature_space_proximity.py +168 -75
- workbench/algorithms/dataframe/fingerprint_proximity.py +422 -86
- workbench/algorithms/dataframe/projection_2d.py +44 -21
- workbench/algorithms/dataframe/proximity.py +259 -305
- workbench/algorithms/graph/light/proximity_graph.py +14 -12
- workbench/algorithms/models/cleanlab_model.py +382 -0
- workbench/algorithms/models/noise_model.py +388 -0
- workbench/algorithms/sql/outliers.py +3 -3
- workbench/api/__init__.py +5 -1
- workbench/api/compound.py +1 -1
- workbench/api/df_store.py +17 -108
- workbench/api/endpoint.py +18 -5
- workbench/api/feature_set.py +121 -15
- workbench/api/meta.py +5 -2
- workbench/api/meta_model.py +289 -0
- workbench/api/model.py +55 -21
- workbench/api/monitor.py +1 -16
- workbench/api/parameter_store.py +3 -52
- workbench/cached/cached_model.py +4 -4
- workbench/core/artifacts/__init__.py +11 -2
- workbench/core/artifacts/artifact.py +16 -8
- workbench/core/artifacts/data_capture_core.py +355 -0
- workbench/core/artifacts/df_store_core.py +114 -0
- workbench/core/artifacts/endpoint_core.py +382 -253
- workbench/core/artifacts/feature_set_core.py +249 -45
- workbench/core/artifacts/model_core.py +135 -80
- workbench/core/artifacts/monitor_core.py +33 -248
- workbench/core/artifacts/parameter_store_core.py +98 -0
- workbench/core/cloud_platform/aws/aws_account_clamp.py +50 -1
- workbench/core/cloud_platform/aws/aws_meta.py +12 -5
- workbench/core/cloud_platform/aws/aws_session.py +4 -4
- workbench/core/pipelines/pipeline_executor.py +1 -1
- workbench/core/transforms/data_to_features/light/molecular_descriptors.py +4 -4
- workbench/core/transforms/features_to_model/features_to_model.py +62 -40
- workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +76 -15
- workbench/core/transforms/pandas_transforms/pandas_to_features.py +38 -2
- workbench/core/views/training_view.py +113 -42
- workbench/core/views/view.py +53 -3
- workbench/core/views/view_utils.py +4 -4
- workbench/model_script_utils/model_script_utils.py +339 -0
- workbench/model_script_utils/pytorch_utils.py +405 -0
- workbench/model_script_utils/uq_harness.py +278 -0
- workbench/model_scripts/chemprop/chemprop.template +649 -0
- workbench/model_scripts/chemprop/generated_model_script.py +649 -0
- workbench/model_scripts/chemprop/model_script_utils.py +339 -0
- workbench/model_scripts/chemprop/requirements.txt +3 -0
- workbench/model_scripts/custom_models/chem_info/fingerprints.py +175 -0
- workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +483 -0
- workbench/model_scripts/custom_models/chem_info/mol_standardize.py +450 -0
- workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +7 -9
- workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -1
- workbench/model_scripts/custom_models/proximity/feature_space_proximity.py +194 -0
- workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +8 -10
- workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
- workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +20 -21
- workbench/model_scripts/custom_models/uq_models/feature_space_proximity.py +194 -0
- workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
- workbench/model_scripts/custom_models/uq_models/ngboost.template +30 -18
- workbench/model_scripts/custom_models/uq_models/requirements.txt +1 -3
- workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +15 -17
- workbench/model_scripts/meta_model/generated_model_script.py +209 -0
- workbench/model_scripts/meta_model/meta_model.template +209 -0
- workbench/model_scripts/pytorch_model/generated_model_script.py +444 -500
- workbench/model_scripts/pytorch_model/model_script_utils.py +339 -0
- workbench/model_scripts/pytorch_model/pytorch.template +440 -496
- workbench/model_scripts/pytorch_model/pytorch_utils.py +405 -0
- workbench/model_scripts/pytorch_model/requirements.txt +1 -1
- workbench/model_scripts/pytorch_model/uq_harness.py +278 -0
- workbench/model_scripts/scikit_learn/generated_model_script.py +7 -12
- workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
- workbench/model_scripts/script_generation.py +20 -11
- workbench/model_scripts/uq_models/generated_model_script.py +248 -0
- workbench/model_scripts/xgb_model/generated_model_script.py +372 -404
- workbench/model_scripts/xgb_model/model_script_utils.py +339 -0
- workbench/model_scripts/xgb_model/uq_harness.py +278 -0
- workbench/model_scripts/xgb_model/xgb_model.template +369 -401
- workbench/repl/workbench_shell.py +28 -19
- workbench/resources/open_source_api.key +1 -1
- workbench/scripts/endpoint_test.py +162 -0
- workbench/scripts/lambda_test.py +73 -0
- workbench/scripts/meta_model_sim.py +35 -0
- workbench/scripts/ml_pipeline_batch.py +137 -0
- workbench/scripts/ml_pipeline_sqs.py +186 -0
- workbench/scripts/monitor_cloud_watch.py +20 -100
- workbench/scripts/training_test.py +85 -0
- workbench/utils/aws_utils.py +4 -3
- workbench/utils/chem_utils/__init__.py +0 -0
- workbench/utils/chem_utils/fingerprints.py +175 -0
- workbench/utils/chem_utils/misc.py +194 -0
- workbench/utils/chem_utils/mol_descriptors.py +483 -0
- workbench/utils/chem_utils/mol_standardize.py +450 -0
- workbench/utils/chem_utils/mol_tagging.py +348 -0
- workbench/utils/chem_utils/projections.py +219 -0
- workbench/utils/chem_utils/salts.py +256 -0
- workbench/utils/chem_utils/sdf.py +292 -0
- workbench/utils/chem_utils/toxicity.py +250 -0
- workbench/utils/chem_utils/vis.py +253 -0
- workbench/utils/chemprop_utils.py +141 -0
- workbench/utils/cloudwatch_handler.py +1 -1
- workbench/utils/cloudwatch_utils.py +137 -0
- workbench/utils/config_manager.py +3 -7
- workbench/utils/endpoint_utils.py +5 -7
- workbench/utils/license_manager.py +2 -6
- workbench/utils/meta_model_simulator.py +499 -0
- workbench/utils/metrics_utils.py +256 -0
- workbench/utils/model_utils.py +278 -79
- workbench/utils/monitor_utils.py +44 -62
- workbench/utils/pandas_utils.py +3 -3
- workbench/utils/pytorch_utils.py +87 -0
- workbench/utils/shap_utils.py +11 -57
- workbench/utils/workbench_logging.py +0 -3
- workbench/utils/workbench_sqs.py +1 -1
- workbench/utils/xgboost_local_crossfold.py +267 -0
- workbench/utils/xgboost_model_utils.py +127 -219
- workbench/web_interface/components/model_plot.py +14 -2
- workbench/web_interface/components/plugin_unit_test.py +5 -2
- workbench/web_interface/components/plugins/dashboard_status.py +3 -1
- workbench/web_interface/components/plugins/generated_compounds.py +1 -1
- workbench/web_interface/components/plugins/model_details.py +38 -74
- workbench/web_interface/components/plugins/scatter_plot.py +6 -10
- {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/METADATA +31 -9
- {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/RECORD +128 -96
- workbench-0.8.220.dist-info/entry_points.txt +11 -0
- {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/licenses/LICENSE +1 -1
- workbench/core/cloud_platform/aws/aws_df_store.py +0 -404
- workbench/core/cloud_platform/aws/aws_parameter_store.py +0 -280
- workbench/model_scripts/custom_models/chem_info/local_utils.py +0 -769
- workbench/model_scripts/custom_models/chem_info/tautomerize.py +0 -83
- workbench/model_scripts/custom_models/meta_endpoints/example.py +0 -53
- workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
- workbench/model_scripts/custom_models/proximity/proximity.py +0 -384
- workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -393
- workbench/model_scripts/custom_models/uq_models/mapie_xgb.template +0 -203
- workbench/model_scripts/custom_models/uq_models/meta_uq.template +0 -273
- workbench/model_scripts/custom_models/uq_models/proximity.py +0 -384
- workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
- workbench/model_scripts/quant_regression/quant_regression.template +0 -279
- workbench/model_scripts/quant_regression/requirements.txt +0 -1
- workbench/utils/chem_utils.py +0 -1556
- workbench/utils/execution_environment.py +0 -211
- workbench/utils/fast_inference.py +0 -167
- workbench/utils/resource_utils.py +0 -39
- workbench-0.8.162.dist-info/entry_points.txt +0 -5
- {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/WHEEL +0 -0
- {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,186 @@
|
|
|
1
|
+
import argparse
|
|
2
|
+
import logging
|
|
3
|
+
import json
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
|
|
6
|
+
# Workbench Imports
|
|
7
|
+
from workbench.core.cloud_platform.aws.aws_account_clamp import AWSAccountClamp
|
|
8
|
+
from workbench.utils.config_manager import ConfigManager
|
|
9
|
+
from workbench.utils.s3_utils import upload_content_to_s3
|
|
10
|
+
|
|
11
|
+
log = logging.getLogger("workbench")
|
|
12
|
+
cm = ConfigManager()
|
|
13
|
+
workbench_bucket = cm.get_config("WORKBENCH_BUCKET")
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def submit_to_sqs(
    script_path: str,
    size: str = "small",
    realtime: bool = False,
    dt: bool = False,
    promote: bool = False,
) -> None:
    """
    Upload script to S3 and submit message to SQS queue for processing.

    The script is uploaded to ``s3://<WORKBENCH_BUCKET>/batch-jobs/<script name>`` and a
    JSON message containing that S3 path, the size tier, and an ``environment`` dict
    (SERVERLESS / DT / PROMOTE as strings) is sent to the
    ``workbench-ml-pipeline-queue.fifo`` queue. Progress is reported with ``print``.

    Args:
        script_path: Local path to the ML pipeline script
        size: Job size tier - "small" (default), "medium", or "large"
        realtime: If True, sets serverless=False for real-time processing (default: False)
        dt: If True, sets DT=True in environment (default: False)
        promote: If True, sets PROMOTE=True in environment (default: False)

    Raises:
        ValueError: If size is not "small", "medium", or "large"
        FileNotFoundError: If script_path does not point to an existing file
        Exception: Errors from the S3 upload or the SQS queue lookup/send are
            printed and then re-raised unchanged
    """
    print(f"\n{'=' * 60}")
    print("🚀 SUBMITTING ML PIPELINE JOB")
    print(f"{'=' * 60}")
    if size not in ["small", "medium", "large"]:
        raise ValueError(f"Invalid size '{size}'. Must be 'small', 'medium', or 'large'")

    # Validate script exists; is_file() also rejects directories, which would
    # otherwise only fail later (inside the upload step) with a misleading message
    script_file = Path(script_path)
    if not script_file.is_file():
        raise FileNotFoundError(f"Script not found: {script_path}")

    print(f"📄 Script: {script_file.name}")
    print(f"📏 Size tier: {size}")
    print(f"⚡ Mode: {'Real-time' if realtime else 'Serverless'} (serverless={'False' if realtime else 'True'})")
    print(f"🔄 DynamicTraining: {dt}")
    print(f"🆕 Promote: {promote}")
    print(f"🪣 Bucket: {workbench_bucket}")
    sqs = AWSAccountClamp().boto3_session.client("sqs")
    script_name = script_file.name

    # List Workbench queues (informational only: a failure here is reported but
    # does not stop the submission)
    print("\n📋 Listing Workbench SQS queues...")
    try:
        queues = sqs.list_queues(QueueNamePrefix="workbench-")
        queue_urls = queues.get("QueueUrls", [])
        if queue_urls:
            print(f"✅ Found {len(queue_urls)} workbench queue(s):")
            for url in queue_urls:
                listed_name = url.split("/")[-1]
                print(f" • {listed_name}")
        else:
            print("⚠️ No workbench queues found")
    except Exception as e:
        print(f"❌ Error listing queues: {e}")

    # Upload script to S3
    s3_path = f"s3://{workbench_bucket}/batch-jobs/{script_name}"
    print("\n📤 Uploading script to S3...")
    print(f" Source: {script_path}")
    print(f" Destination: {s3_path}")

    try:
        # Read as UTF-8 explicitly so the upload does not depend on the platform's default encoding
        upload_content_to_s3(script_file.read_text(encoding="utf-8"), s3_path)
        print("✅ Script uploaded successfully")
    except Exception as e:
        print(f"❌ Upload failed: {e}")
        raise

    # Get queue URL and info
    queue_name = "workbench-ml-pipeline-queue.fifo"
    print("\n🎯 Getting queue information...")
    print(f" Queue name: {queue_name}")

    try:
        queue_url = sqs.get_queue_url(QueueName=queue_name)["QueueUrl"]
        print(f" Queue URL: {queue_url}")

        # Get queue attributes for additional info
        attrs = sqs.get_queue_attributes(
            QueueUrl=queue_url, AttributeNames=["ApproximateNumberOfMessages", "ApproximateNumberOfMessagesNotVisible"]
        )
        messages_available = attrs["Attributes"].get("ApproximateNumberOfMessages", "0")
        messages_in_flight = attrs["Attributes"].get("ApproximateNumberOfMessagesNotVisible", "0")
        print(f" Messages in queue: {messages_available}")
        print(f" Messages in flight: {messages_in_flight}")

    except Exception as e:
        print(f"❌ Error accessing queue: {e}")
        raise

    # Prepare message
    message = {"script_path": s3_path, "size": size}

    # Set environment variables (all values are sent as strings)
    message["environment"] = {
        "SERVERLESS": "False" if realtime else "True",
        "DT": str(dt),
        "PROMOTE": str(promote),
    }

    # Send the message to SQS
    try:
        print("\n📨 Sending message to SQS...")
        response = sqs.send_message(
            QueueUrl=queue_url,
            MessageBody=json.dumps(message, indent=2),
            MessageGroupId="ml-pipeline-jobs",  # Required for FIFO
        )
        message_id = response["MessageId"]
        print("✅ Message sent successfully!")
        print(f" Message ID: {message_id}")
    except Exception as e:
        print(f"❌ Failed to send message: {e}")
        raise

    # Success summary
    print(f"\n{'=' * 60}")
    print("✅ JOB SUBMISSION COMPLETE")
    print(f"{'=' * 60}")
    print(f"📄 Script: {script_name}")
    print(f"📏 Size: {size}")
    print(f"⚡ Mode: {'Real-time' if realtime else 'Serverless'} (SERVERLESS={'False' if realtime else 'True'})")
    print(f"🔄 DynamicTraining: {dt}")
    print(f"🆕 Promote: {promote}")
    print(f"🆔 Message ID: {message_id}")
    print("\n🔍 MONITORING LOCATIONS:")
    print(f" • SQS Queue: AWS Console → SQS → {queue_name}")
    print(" • Lambda Logs: AWS Console → Lambda → Functions")
    print(" • Batch Jobs: AWS Console → Batch → Jobs")
    print(" • CloudWatch: AWS Console → CloudWatch → Log groups")
    print("\n⏳ Your job should start processing soon...")
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
def main():
    """CLI entry point for submitting ML pipelines via SQS.

    Parses the command line (script path plus --size, --realtime, --dt and
    --promote options) and forwards it to submit_to_sqs(). Any exception raised
    by the submission is printed, logged, and turned into exit status 1.
    """
    import sys  # local import: sys.exit() is always available, unlike the site-provided exit()

    parser = argparse.ArgumentParser(description="Submit ML pipeline to SQS queue for Batch processing")
    parser.add_argument("script_file", help="Local path to ML pipeline script")
    parser.add_argument(
        "--size", default="small", choices=["small", "medium", "large"], help="Job size tier (default: small)"
    )
    parser.add_argument(
        "--realtime",
        action="store_true",
        help="Create realtime endpoints (default is serverless)",
    )
    parser.add_argument(
        "--dt",
        action="store_true",
        help="Set DT=True (models and endpoints will have '-dt' suffix)",
    )
    parser.add_argument(
        "--promote",
        action="store_true",
        help="Set Promote=True (models and endpoints will use promoted naming)",
    )
    args = parser.parse_args()
    try:
        submit_to_sqs(
            args.script_file,
            args.size,
            realtime=args.realtime,
            dt=args.dt,
            promote=args.promote,
        )
    except Exception as e:
        # Top-level boundary: report the failure and signal it through the exit status
        print(f"\n❌ ERROR: {e}")
        log.error(f"Error: {e}")
        sys.exit(1)
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
if __name__ == "__main__":
|
|
186
|
+
main()
|
|
@@ -4,8 +4,10 @@ import sys
|
|
|
4
4
|
import time
|
|
5
5
|
import argparse
|
|
6
6
|
from datetime import datetime, timedelta, timezone
|
|
7
|
-
|
|
7
|
+
|
|
8
|
+
# Workbench Imports
|
|
8
9
|
from workbench.utils.repl_utils import cprint, Spinner
|
|
10
|
+
from workbench.utils.cloudwatch_utils import get_cloudwatch_client, get_active_log_streams, stream_log_events
|
|
9
11
|
|
|
10
12
|
# Define the log levels to include all log levels above the specified level
|
|
11
13
|
log_level_map = {
|
|
@@ -33,64 +35,6 @@ def date_display(dt):
|
|
|
33
35
|
return dt.strftime("%Y-%m-%d %I:%M%p") + "(UTC)"
|
|
34
36
|
|
|
35
37
|
|
|
36
|
-
def get_cloudwatch_client():
|
|
37
|
-
"""Get the CloudWatch Logs client using the Workbench assumed role session."""
|
|
38
|
-
session = AWSAccountClamp().boto3_session
|
|
39
|
-
return session.client("logs")
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
def get_active_log_streams(client, log_group_name, start_time_ms, stream_filter=None):
|
|
43
|
-
"""Retrieve log streams that have events after the specified start time."""
|
|
44
|
-
|
|
45
|
-
# Get all the streams in the log group
|
|
46
|
-
active_streams = []
|
|
47
|
-
stream_params = {
|
|
48
|
-
"logGroupName": log_group_name,
|
|
49
|
-
"orderBy": "LastEventTime",
|
|
50
|
-
"descending": True,
|
|
51
|
-
}
|
|
52
|
-
|
|
53
|
-
# Loop to retrieve all log streams (maximum 50 per call)
|
|
54
|
-
while True:
|
|
55
|
-
response = client.describe_log_streams(**stream_params)
|
|
56
|
-
log_streams = response.get("logStreams", [])
|
|
57
|
-
|
|
58
|
-
for log_stream in log_streams:
|
|
59
|
-
log_stream_name = log_stream["logStreamName"]
|
|
60
|
-
last_event_timestamp = log_stream.get("lastEventTimestamp")
|
|
61
|
-
|
|
62
|
-
# Include streams with events since the specified start time
|
|
63
|
-
# Note: There's some issue where the last event timestamp is 'off'
|
|
64
|
-
# so we're going to add 60 minutes from the last event timestamp
|
|
65
|
-
last_event_timestamp += 60 * 60 * 1000
|
|
66
|
-
if last_event_timestamp >= start_time_ms:
|
|
67
|
-
active_streams.append(log_stream_name)
|
|
68
|
-
else:
|
|
69
|
-
break # Stop if we reach streams older than the start time
|
|
70
|
-
|
|
71
|
-
# Check if there are more streams to retrieve
|
|
72
|
-
if "nextToken" in response:
|
|
73
|
-
stream_params["nextToken"] = response["nextToken"]
|
|
74
|
-
else:
|
|
75
|
-
break
|
|
76
|
-
|
|
77
|
-
# Sort and report the active log streams
|
|
78
|
-
active_streams.sort()
|
|
79
|
-
if active_streams:
|
|
80
|
-
print("Active log streams:", len(active_streams))
|
|
81
|
-
|
|
82
|
-
# Filter the active streams by a substring if provided
|
|
83
|
-
if stream_filter and active_streams:
|
|
84
|
-
print(f"Filtering active log streams by '{stream_filter}'...")
|
|
85
|
-
active_streams = [stream for stream in active_streams if stream_filter in stream]
|
|
86
|
-
|
|
87
|
-
for stream in active_streams:
|
|
88
|
-
print(f"\t - {stream}")
|
|
89
|
-
|
|
90
|
-
# Return the active log streams
|
|
91
|
-
return active_streams
|
|
92
|
-
|
|
93
|
-
|
|
94
38
|
def get_latest_log_events(client, log_group_name, start_time, end_time=None, stream_filter=None):
|
|
95
39
|
"""Retrieve the latest log events from the active/filtered log streams in a CloudWatch Logs group."""
|
|
96
40
|
|
|
@@ -99,11 +43,15 @@ def get_latest_log_events(client, log_group_name, start_time, end_time=None, str
|
|
|
99
43
|
get_latest_log_events.first_run = True
|
|
100
44
|
|
|
101
45
|
log_events = []
|
|
102
|
-
start_time_ms = int(start_time.timestamp() * 1000)
|
|
46
|
+
start_time_ms = int(start_time.timestamp() * 1000)
|
|
47
|
+
|
|
48
|
+
# Use the util function to get active streams
|
|
49
|
+
active_streams = get_active_log_streams(log_group_name, start_time_ms, stream_filter, client)
|
|
103
50
|
|
|
104
|
-
# Get the active log streams with events since start_time
|
|
105
|
-
active_streams = get_active_log_streams(client, log_group_name, start_time_ms, stream_filter)
|
|
106
51
|
if active_streams:
|
|
52
|
+
print(f"Active log streams: {len(active_streams)}")
|
|
53
|
+
for stream in active_streams:
|
|
54
|
+
print(f"\t - {stream}")
|
|
107
55
|
print(f"Processing log events from {date_display(start_time)} on {len(active_streams)} active log streams...")
|
|
108
56
|
get_latest_log_events.first_run = False
|
|
109
57
|
else:
|
|
@@ -114,50 +62,22 @@ def get_latest_log_events(client, log_group_name, start_time, end_time=None, str
|
|
|
114
62
|
print("Monitoring for new events...")
|
|
115
63
|
return log_events
|
|
116
64
|
|
|
117
|
-
#
|
|
65
|
+
# Use the util function to stream events from each log stream
|
|
118
66
|
for log_stream_name in active_streams:
|
|
119
|
-
params = {
|
|
120
|
-
"logGroupName": log_group_name,
|
|
121
|
-
"logStreamName": log_stream_name,
|
|
122
|
-
"startTime": start_time_ms, # Use start_time in milliseconds
|
|
123
|
-
"startFromHead": True, # Start from the nearest event to start_time
|
|
124
|
-
}
|
|
125
|
-
next_event_token = None
|
|
126
|
-
if end_time is not None:
|
|
127
|
-
params["endTime"] = int(end_time.timestamp() * 1000)
|
|
128
|
-
|
|
129
|
-
# Process the log events from this log stream
|
|
130
67
|
spinner = Spinner("lightpurple", f"Pulling events from {log_stream_name}:")
|
|
131
68
|
spinner.start()
|
|
132
69
|
log_stream_events = 0
|
|
133
70
|
|
|
134
|
-
#
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
71
|
+
# Stream events using the util function
|
|
72
|
+
for event in stream_log_events(
|
|
73
|
+
log_group_name, log_stream_name, start_time, end_time, follow=False, client=client
|
|
74
|
+
):
|
|
75
|
+
log_stream_events += 1
|
|
76
|
+
log_events.append(event)
|
|
139
77
|
|
|
140
|
-
|
|
141
|
-
|
|
78
|
+
spinner.stop()
|
|
79
|
+
print(f"Processed {log_stream_events} events from {log_stream_name} (Total: {len(log_events)})")
|
|
142
80
|
|
|
143
|
-
events = events_response.get("events", [])
|
|
144
|
-
for event in events:
|
|
145
|
-
event["logStreamName"] = log_stream_name
|
|
146
|
-
|
|
147
|
-
# Add the log stream events to our list of all log events
|
|
148
|
-
log_stream_events += len(events)
|
|
149
|
-
log_events.extend(events)
|
|
150
|
-
|
|
151
|
-
# Handle pagination for log events
|
|
152
|
-
next_event_token = events_response.get("nextForwardToken")
|
|
153
|
-
|
|
154
|
-
# Break the loop if there are no more events to fetch
|
|
155
|
-
if not next_event_token or next_event_token == params.get("nextToken"):
|
|
156
|
-
spinner.stop()
|
|
157
|
-
print(f"Processed {log_stream_events} events from {log_stream_name} (Total: {len(log_events)})")
|
|
158
|
-
break
|
|
159
|
-
|
|
160
|
-
# Return the log events
|
|
161
81
|
return log_events
|
|
162
82
|
|
|
163
83
|
|
|
@@ -206,6 +126,7 @@ def monitor_log_group(
|
|
|
206
126
|
print(f"Monitoring log group: {log_group_name} from {date_display(start_time)}")
|
|
207
127
|
print(f"Log levels: {log_levels}")
|
|
208
128
|
print(f"Search terms: {search_terms}")
|
|
129
|
+
|
|
209
130
|
while True:
|
|
210
131
|
# Get the latest log events with stream filtering if provided
|
|
211
132
|
all_log_events = get_latest_log_events(client, log_group_name, start_time, end_time, stream_filter)
|
|
@@ -218,7 +139,6 @@ def monitor_log_group(
|
|
|
218
139
|
|
|
219
140
|
# Check the search terms
|
|
220
141
|
if not search_terms or any(term in event["message"].lower() for term in search_terms):
|
|
221
|
-
|
|
222
142
|
# Calculate the start and end index for this match
|
|
223
143
|
start_index = max(i - before, 0)
|
|
224
144
|
end_index = min(i + after, len(all_log_events) - 1)
|
|
@@ -0,0 +1,85 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Local test harness for SageMaker training scripts.
|
|
3
|
+
|
|
4
|
+
Usage:
|
|
5
|
+
python training_test.py <model_script.py> <featureset_name>
|
|
6
|
+
|
|
7
|
+
Example:
|
|
8
|
+
python training_test.py ../model_scripts/pytorch_model/generated_model_script.py caco2-class-features
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
import os
|
|
12
|
+
import shutil
|
|
13
|
+
import subprocess
|
|
14
|
+
import sys
|
|
15
|
+
import tempfile
|
|
16
|
+
|
|
17
|
+
import pandas as pd
|
|
18
|
+
|
|
19
|
+
from workbench.api import FeatureSet
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def get_training_data(featureset_name: str) -> pd.DataFrame:
    """Pull the DataFrame for the named FeatureSet.

    Args:
        featureset_name: Name used to construct the FeatureSet.

    Returns:
        pd.DataFrame: Whatever FeatureSet.pull_dataframe() returns for that FeatureSet.
    """
    return FeatureSet(featureset_name).pull_dataframe()
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def main():
    """Run a SageMaker-style training script locally against a FeatureSet.

    Reads ``<model_script.py> <featureset_name>`` from ``sys.argv``, writes the
    FeatureSet's DataFrame to ``training_data.csv`` in a temporary train directory,
    and runs the script in a subprocess with SM_MODEL_DIR, SM_CHANNEL_TRAIN and
    SM_OUTPUT_DATA_DIR pointing at temporary directories. The temporary
    directories are removed afterwards, whether or not training succeeded.

    Exits with status 1 on bad usage or a missing script, and with the training
    subprocess's return code when that is non-zero, so callers can detect failure.
    """
    if len(sys.argv) < 3:
        print("Usage: python training_test.py <model_script.py> <featureset_name>")
        sys.exit(1)

    script_path = sys.argv[1]
    featureset_name = sys.argv[2]

    if not os.path.exists(script_path):
        print(f"Error: Script not found: {script_path}")
        sys.exit(1)

    # Create temp directories
    model_dir = tempfile.mkdtemp(prefix="training_model_")
    train_dir = tempfile.mkdtemp(prefix="training_data_")
    output_dir = tempfile.mkdtemp(prefix="training_output_")

    print(f"Model dir: {model_dir}")
    print(f"Train dir: {train_dir}")

    try:
        # Get training data and save to CSV
        print(f"Loading FeatureSet: {featureset_name}")
        df = get_training_data(featureset_name)
        print(f"Data shape: {df.shape}")

        train_file = os.path.join(train_dir, "training_data.csv")
        df.to_csv(train_file, index=False)

        # Set up environment (copy so the parent process environment is untouched)
        env = os.environ.copy()
        env["SM_MODEL_DIR"] = model_dir
        env["SM_CHANNEL_TRAIN"] = train_dir
        env["SM_OUTPUT_DATA_DIR"] = output_dir

        print("\n" + "=" * 60)
        print("Starting training...")
        print("=" * 60 + "\n")

        # Run the script with the same interpreter that is running this harness
        cmd = [sys.executable, script_path, "--model-dir", model_dir, "--train", train_dir]
        result = subprocess.run(cmd, env=env)

        print("\n" + "=" * 60)
        if result.returncode == 0:
            print("Training completed successfully!")
        else:
            print(f"Training failed with return code: {result.returncode}")
        print("=" * 60)

    finally:
        shutil.rmtree(model_dir, ignore_errors=True)
        shutil.rmtree(train_dir, ignore_errors=True)
        shutil.rmtree(output_dir, ignore_errors=True)

    # Propagate a training failure through the exit status (after cleanup has run)
    if result.returncode != 0:
        sys.exit(result.returncode)
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
if __name__ == "__main__":
|
|
85
|
+
main()
|
workbench/utils/aws_utils.py
CHANGED
|
@@ -55,7 +55,8 @@ def aws_throttle(func=None, retry_intervals=None):
|
|
|
55
55
|
if func is None:
|
|
56
56
|
return lambda f: aws_throttle(f, retry_intervals=retry_intervals)
|
|
57
57
|
|
|
58
|
-
|
|
58
|
+
# This is currently commented out (we might want to use it later)
|
|
59
|
+
# service_hold_time = 2 # Seconds to wait before calling AWS function
|
|
59
60
|
default_intervals = [2**i for i in range(1, 9)] # Default exponential backoff: 2, 4, 8... 256 seconds
|
|
60
61
|
intervals = retry_intervals or default_intervals
|
|
61
62
|
|
|
@@ -64,8 +65,8 @@ def aws_throttle(func=None, retry_intervals=None):
|
|
|
64
65
|
for attempt, delay in enumerate(intervals, start=1):
|
|
65
66
|
try:
|
|
66
67
|
# Add sleep before calling AWS func if running as a service
|
|
67
|
-
if cm.running_as_service:
|
|
68
|
-
|
|
68
|
+
# if cm.running_as_service:
|
|
69
|
+
# time.sleep(service_hold_time)
|
|
69
70
|
return func(*args, **kwargs)
|
|
70
71
|
except ClientError as e:
|
|
71
72
|
if e.response["Error"]["Code"] == "ThrottlingException":
|
|
File without changes
|
|
@@ -0,0 +1,175 @@
|
|
|
1
|
+
"""Molecular fingerprint computation utilities for ADMET modeling.
|
|
2
|
+
|
|
3
|
+
This module provides Morgan count fingerprints, the standard for ADMET prediction.
|
|
4
|
+
Count fingerprints outperform binary fingerprints for molecular property prediction.
|
|
5
|
+
|
|
6
|
+
References:
|
|
7
|
+
- Count vs Binary: https://pubs.acs.org/doi/10.1021/acs.est.3c02198
|
|
8
|
+
- ECFP/Morgan: https://pubs.acs.org/doi/10.1021/ci100050t
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
import logging
|
|
12
|
+
|
|
13
|
+
import numpy as np
|
|
14
|
+
import pandas as pd
|
|
15
|
+
from rdkit import Chem, RDLogger
|
|
16
|
+
from rdkit.Chem import AllChem
|
|
17
|
+
from rdkit.Chem.MolStandardize import rdMolStandardize
|
|
18
|
+
|
|
19
|
+
# Suppress RDKit warnings (e.g., "not removing hydrogen atom without neighbors")
|
|
20
|
+
# Keep errors enabled so we see actual problems
|
|
21
|
+
RDLogger.DisableLog("rdApp.warning")
|
|
22
|
+
|
|
23
|
+
# Set up the logger
|
|
24
|
+
log = logging.getLogger("workbench")
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def compute_morgan_fingerprints(df: pd.DataFrame, radius: int = 2, n_bits: int = 2048) -> pd.DataFrame:
|
|
28
|
+
"""Compute Morgan count fingerprints for ADMET modeling.
|
|
29
|
+
|
|
30
|
+
Generates true count fingerprints where each bit position contains the
|
|
31
|
+
number of times that substructure appears in the molecule (clamped to 0-255).
|
|
32
|
+
This is the recommended approach for ADMET prediction per 2025 research.
|
|
33
|
+
|
|
34
|
+
Args:
|
|
35
|
+
df: Input DataFrame containing SMILES strings.
|
|
36
|
+
radius: Radius for the Morgan fingerprint (default 2 = ECFP4 equivalent).
|
|
37
|
+
n_bits: Number of bits for the fingerprint (default 2048).
|
|
38
|
+
|
|
39
|
+
Returns:
|
|
40
|
+
pd.DataFrame: Input DataFrame with 'fingerprint' column added.
|
|
41
|
+
Values are comma-separated uint8 counts.
|
|
42
|
+
|
|
43
|
+
Note:
|
|
44
|
+
Count fingerprints outperform binary for ADMET prediction.
|
|
45
|
+
See: https://pubs.acs.org/doi/10.1021/acs.est.3c02198
|
|
46
|
+
"""
|
|
47
|
+
delete_mol_column = False
|
|
48
|
+
|
|
49
|
+
# Check for the SMILES column (case-insensitive)
|
|
50
|
+
smiles_column = next((col for col in df.columns if col.lower() == "smiles"), None)
|
|
51
|
+
if smiles_column is None:
|
|
52
|
+
raise ValueError("Input DataFrame must have a 'smiles' column")
|
|
53
|
+
|
|
54
|
+
# Sanity check the molecule column (sometimes it gets serialized, which doesn't work)
|
|
55
|
+
if "molecule" in df.columns and df["molecule"].dtype == "string":
|
|
56
|
+
log.warning("Detected serialized molecules in 'molecule' column. Removing...")
|
|
57
|
+
del df["molecule"]
|
|
58
|
+
|
|
59
|
+
# Convert SMILES to RDKit molecule objects
|
|
60
|
+
if "molecule" not in df.columns:
|
|
61
|
+
log.info("Converting SMILES to RDKit Molecules...")
|
|
62
|
+
delete_mol_column = True
|
|
63
|
+
df["molecule"] = df[smiles_column].apply(Chem.MolFromSmiles)
|
|
64
|
+
# Make sure our molecules are not None
|
|
65
|
+
failed_smiles = df[df["molecule"].isnull()][smiles_column].tolist()
|
|
66
|
+
if failed_smiles:
|
|
67
|
+
log.warning(f"Failed to convert {len(failed_smiles)} SMILES to molecules ({failed_smiles})")
|
|
68
|
+
df = df.dropna(subset=["molecule"]).copy()
|
|
69
|
+
|
|
70
|
+
# If we have fragments in our compounds, get the largest fragment before computing fingerprints
|
|
71
|
+
largest_frags = df["molecule"].apply(
|
|
72
|
+
lambda mol: rdMolStandardize.LargestFragmentChooser().choose(mol) if mol else None
|
|
73
|
+
)
|
|
74
|
+
|
|
75
|
+
def mol_to_count_string(mol):
    """Serialize a molecule's hashed Morgan count fingerprint as a string.

    Args:
        mol: Molecule to fingerprint, or None when no molecule is available.

    Returns:
        A comma-separated string of ``n_bits`` counts (each capped at 255),
        or ``pd.NA`` when ``mol`` is None.
    """
    # Nothing to fingerprint: propagate a missing value
    if mol is None:
        return pd.NA

    # Hashed Morgan fingerprint with per-bit counts
    # (radius and n_bits are not defined here; they come from the enclosing scope)
    hashed_fp = AllChem.GetHashedMorganFingerprint(mol, radius, nBits=n_bits)

    # Dense uint8 vector; cap every count at 255 so it fits the dtype
    dense = np.zeros(n_bits, dtype=np.uint8)
    for bit, occurrences in hashed_fp.GetNonzeroElements().items():
        dense[bit] = min(occurrences, 255)

    # Emit the vector as comma-separated text
    return ",".join(str(value) for value in dense)
|
|
90
|
+
|
|
91
|
+
# Compute Morgan count fingerprints
|
|
92
|
+
fingerprints = largest_frags.apply(mol_to_count_string)
|
|
93
|
+
|
|
94
|
+
# Add the fingerprints to the DataFrame
|
|
95
|
+
df["fingerprint"] = fingerprints
|
|
96
|
+
|
|
97
|
+
# Drop the intermediate 'molecule' column if it was added
|
|
98
|
+
if delete_mol_column:
|
|
99
|
+
del df["molecule"]
|
|
100
|
+
|
|
101
|
+
return df
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
if __name__ == "__main__":
    # Smoke-test driver: runs compute_morgan_fingerprints on a few known
    # molecules and prints summary statistics for manual inspection.
    print("Running Morgan count fingerprint tests...")

    def _parse_counts(row):
        """Return the row's fingerprint as a list of ints, or None if it is missing/NA."""
        # Default to None (not a placeholder string) so a missing column is
        # reported as N/A instead of crashing in int() on non-numeric text.
        fp = row.get("fingerprint")
        if pd.isna(fp):
            return None
        return [int(x) for x in fp.split(",")]

    # Test molecules
    test_molecules = {
        "aspirin": "CC(=O)OC1=CC=CC=C1C(=O)O",
        "caffeine": "CN1C=NC2=C1C(=O)N(C(=O)N2C)C",
        "glucose": "C([C@@H]1[C@H]([C@@H]([C@H](C(O1)O)O)O)O)O",  # With stereochemistry
        "sodium_acetate": "CC(=O)[O-].[Na+]",  # Salt (largest fragment used)
        "benzene": "c1ccccc1",
        "butene_e": "C/C=C/C",  # E-butene
        "butene_z": "C/C=C\\C",  # Z-butene
    }

    # Test 1: Morgan Count Fingerprints (default parameters)
    print("\n1. Testing Morgan fingerprint generation (radius=2, n_bits=2048)...")

    test_df = pd.DataFrame({"SMILES": list(test_molecules.values()), "name": list(test_molecules.keys())})
    fp_df = compute_morgan_fingerprints(test_df.copy())

    print(" Fingerprint generation results:")
    for _, row in fp_df.iterrows():
        counts = _parse_counts(row)
        if counts is None:
            print(f" {row['name']:15} → N/A")
            continue
        non_zero = sum(1 for c in counts if c > 0)
        max_count = max(counts)
        print(f" {row['name']:15} → {len(counts)} features, {non_zero} non-zero, max={max_count}")

    # Test 2: Different parameters
    print("\n2. Testing with different parameters (radius=3, n_bits=1024)...")

    fp_df_custom = compute_morgan_fingerprints(test_df.copy(), radius=3, n_bits=1024)

    for _, row in fp_df_custom.iterrows():
        counts = _parse_counts(row)
        if counts is None:
            print(f" {row['name']:15} → N/A")
            continue
        non_zero = sum(1 for c in counts if c > 0)
        print(f" {row['name']:15} → {len(counts)} features, {non_zero} non-zero")

    # Test 3: Edge cases
    print("\n3. Testing edge cases...")

    # Invalid SMILES
    invalid_df = pd.DataFrame({"SMILES": ["INVALID", ""]})
    fp_invalid = compute_morgan_fingerprints(invalid_df.copy())
    print(f" ✓ Invalid SMILES handled: {len(fp_invalid)} rows returned")

    # Test with pre-existing molecule column
    mol_df = test_df.copy()
    mol_df["molecule"] = mol_df["SMILES"].apply(Chem.MolFromSmiles)
    fp_with_mol = compute_morgan_fingerprints(mol_df)
    print(f" ✓ Pre-existing molecule column handled: {len(fp_with_mol)} fingerprints generated")

    # Test 4: Verify count values are reasonable
    print("\n4. Verifying count distribution...")
    all_counts = []
    for _, row in fp_df.iterrows():
        counts = _parse_counts(row)
        if counts is not None:
            # Only the populated bits are of interest for the distribution
            all_counts.extend(c for c in counts if c > 0)

    if all_counts:
        print(f" Non-zero counts: min={min(all_counts)}, max={max(all_counts)}, mean={np.mean(all_counts):.2f}")

    print("\n✅ All fingerprint tests completed!")
|