workbench-0.8.162-py3-none-any.whl → workbench-0.8.220-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


Files changed (147)
  1. workbench/algorithms/dataframe/__init__.py +1 -2
  2. workbench/algorithms/dataframe/compound_dataset_overlap.py +321 -0
  3. workbench/algorithms/dataframe/feature_space_proximity.py +168 -75
  4. workbench/algorithms/dataframe/fingerprint_proximity.py +422 -86
  5. workbench/algorithms/dataframe/projection_2d.py +44 -21
  6. workbench/algorithms/dataframe/proximity.py +259 -305
  7. workbench/algorithms/graph/light/proximity_graph.py +14 -12
  8. workbench/algorithms/models/cleanlab_model.py +382 -0
  9. workbench/algorithms/models/noise_model.py +388 -0
  10. workbench/algorithms/sql/outliers.py +3 -3
  11. workbench/api/__init__.py +5 -1
  12. workbench/api/compound.py +1 -1
  13. workbench/api/df_store.py +17 -108
  14. workbench/api/endpoint.py +18 -5
  15. workbench/api/feature_set.py +121 -15
  16. workbench/api/meta.py +5 -2
  17. workbench/api/meta_model.py +289 -0
  18. workbench/api/model.py +55 -21
  19. workbench/api/monitor.py +1 -16
  20. workbench/api/parameter_store.py +3 -52
  21. workbench/cached/cached_model.py +4 -4
  22. workbench/core/artifacts/__init__.py +11 -2
  23. workbench/core/artifacts/artifact.py +16 -8
  24. workbench/core/artifacts/data_capture_core.py +355 -0
  25. workbench/core/artifacts/df_store_core.py +114 -0
  26. workbench/core/artifacts/endpoint_core.py +382 -253
  27. workbench/core/artifacts/feature_set_core.py +249 -45
  28. workbench/core/artifacts/model_core.py +135 -80
  29. workbench/core/artifacts/monitor_core.py +33 -248
  30. workbench/core/artifacts/parameter_store_core.py +98 -0
  31. workbench/core/cloud_platform/aws/aws_account_clamp.py +50 -1
  32. workbench/core/cloud_platform/aws/aws_meta.py +12 -5
  33. workbench/core/cloud_platform/aws/aws_session.py +4 -4
  34. workbench/core/pipelines/pipeline_executor.py +1 -1
  35. workbench/core/transforms/data_to_features/light/molecular_descriptors.py +4 -4
  36. workbench/core/transforms/features_to_model/features_to_model.py +62 -40
  37. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +76 -15
  38. workbench/core/transforms/pandas_transforms/pandas_to_features.py +38 -2
  39. workbench/core/views/training_view.py +113 -42
  40. workbench/core/views/view.py +53 -3
  41. workbench/core/views/view_utils.py +4 -4
  42. workbench/model_script_utils/model_script_utils.py +339 -0
  43. workbench/model_script_utils/pytorch_utils.py +405 -0
  44. workbench/model_script_utils/uq_harness.py +278 -0
  45. workbench/model_scripts/chemprop/chemprop.template +649 -0
  46. workbench/model_scripts/chemprop/generated_model_script.py +649 -0
  47. workbench/model_scripts/chemprop/model_script_utils.py +339 -0
  48. workbench/model_scripts/chemprop/requirements.txt +3 -0
  49. workbench/model_scripts/custom_models/chem_info/fingerprints.py +175 -0
  50. workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +483 -0
  51. workbench/model_scripts/custom_models/chem_info/mol_standardize.py +450 -0
  52. workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +7 -9
  53. workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -1
  54. workbench/model_scripts/custom_models/proximity/feature_space_proximity.py +194 -0
  55. workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +8 -10
  56. workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
  57. workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +20 -21
  58. workbench/model_scripts/custom_models/uq_models/feature_space_proximity.py +194 -0
  59. workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
  60. workbench/model_scripts/custom_models/uq_models/ngboost.template +30 -18
  61. workbench/model_scripts/custom_models/uq_models/requirements.txt +1 -3
  62. workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +15 -17
  63. workbench/model_scripts/meta_model/generated_model_script.py +209 -0
  64. workbench/model_scripts/meta_model/meta_model.template +209 -0
  65. workbench/model_scripts/pytorch_model/generated_model_script.py +444 -500
  66. workbench/model_scripts/pytorch_model/model_script_utils.py +339 -0
  67. workbench/model_scripts/pytorch_model/pytorch.template +440 -496
  68. workbench/model_scripts/pytorch_model/pytorch_utils.py +405 -0
  69. workbench/model_scripts/pytorch_model/requirements.txt +1 -1
  70. workbench/model_scripts/pytorch_model/uq_harness.py +278 -0
  71. workbench/model_scripts/scikit_learn/generated_model_script.py +7 -12
  72. workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
  73. workbench/model_scripts/script_generation.py +20 -11
  74. workbench/model_scripts/uq_models/generated_model_script.py +248 -0
  75. workbench/model_scripts/xgb_model/generated_model_script.py +372 -404
  76. workbench/model_scripts/xgb_model/model_script_utils.py +339 -0
  77. workbench/model_scripts/xgb_model/uq_harness.py +278 -0
  78. workbench/model_scripts/xgb_model/xgb_model.template +369 -401
  79. workbench/repl/workbench_shell.py +28 -19
  80. workbench/resources/open_source_api.key +1 -1
  81. workbench/scripts/endpoint_test.py +162 -0
  82. workbench/scripts/lambda_test.py +73 -0
  83. workbench/scripts/meta_model_sim.py +35 -0
  84. workbench/scripts/ml_pipeline_batch.py +137 -0
  85. workbench/scripts/ml_pipeline_sqs.py +186 -0
  86. workbench/scripts/monitor_cloud_watch.py +20 -100
  87. workbench/scripts/training_test.py +85 -0
  88. workbench/utils/aws_utils.py +4 -3
  89. workbench/utils/chem_utils/__init__.py +0 -0
  90. workbench/utils/chem_utils/fingerprints.py +175 -0
  91. workbench/utils/chem_utils/misc.py +194 -0
  92. workbench/utils/chem_utils/mol_descriptors.py +483 -0
  93. workbench/utils/chem_utils/mol_standardize.py +450 -0
  94. workbench/utils/chem_utils/mol_tagging.py +348 -0
  95. workbench/utils/chem_utils/projections.py +219 -0
  96. workbench/utils/chem_utils/salts.py +256 -0
  97. workbench/utils/chem_utils/sdf.py +292 -0
  98. workbench/utils/chem_utils/toxicity.py +250 -0
  99. workbench/utils/chem_utils/vis.py +253 -0
  100. workbench/utils/chemprop_utils.py +141 -0
  101. workbench/utils/cloudwatch_handler.py +1 -1
  102. workbench/utils/cloudwatch_utils.py +137 -0
  103. workbench/utils/config_manager.py +3 -7
  104. workbench/utils/endpoint_utils.py +5 -7
  105. workbench/utils/license_manager.py +2 -6
  106. workbench/utils/meta_model_simulator.py +499 -0
  107. workbench/utils/metrics_utils.py +256 -0
  108. workbench/utils/model_utils.py +278 -79
  109. workbench/utils/monitor_utils.py +44 -62
  110. workbench/utils/pandas_utils.py +3 -3
  111. workbench/utils/pytorch_utils.py +87 -0
  112. workbench/utils/shap_utils.py +11 -57
  113. workbench/utils/workbench_logging.py +0 -3
  114. workbench/utils/workbench_sqs.py +1 -1
  115. workbench/utils/xgboost_local_crossfold.py +267 -0
  116. workbench/utils/xgboost_model_utils.py +127 -219
  117. workbench/web_interface/components/model_plot.py +14 -2
  118. workbench/web_interface/components/plugin_unit_test.py +5 -2
  119. workbench/web_interface/components/plugins/dashboard_status.py +3 -1
  120. workbench/web_interface/components/plugins/generated_compounds.py +1 -1
  121. workbench/web_interface/components/plugins/model_details.py +38 -74
  122. workbench/web_interface/components/plugins/scatter_plot.py +6 -10
  123. {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/METADATA +31 -9
  124. {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/RECORD +128 -96
  125. workbench-0.8.220.dist-info/entry_points.txt +11 -0
  126. {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/licenses/LICENSE +1 -1
  127. workbench/core/cloud_platform/aws/aws_df_store.py +0 -404
  128. workbench/core/cloud_platform/aws/aws_parameter_store.py +0 -280
  129. workbench/model_scripts/custom_models/chem_info/local_utils.py +0 -769
  130. workbench/model_scripts/custom_models/chem_info/tautomerize.py +0 -83
  131. workbench/model_scripts/custom_models/meta_endpoints/example.py +0 -53
  132. workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
  133. workbench/model_scripts/custom_models/proximity/proximity.py +0 -384
  134. workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -393
  135. workbench/model_scripts/custom_models/uq_models/mapie_xgb.template +0 -203
  136. workbench/model_scripts/custom_models/uq_models/meta_uq.template +0 -273
  137. workbench/model_scripts/custom_models/uq_models/proximity.py +0 -384
  138. workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
  139. workbench/model_scripts/quant_regression/quant_regression.template +0 -279
  140. workbench/model_scripts/quant_regression/requirements.txt +0 -1
  141. workbench/utils/chem_utils.py +0 -1556
  142. workbench/utils/execution_environment.py +0 -211
  143. workbench/utils/fast_inference.py +0 -167
  144. workbench/utils/resource_utils.py +0 -39
  145. workbench-0.8.162.dist-info/entry_points.txt +0 -5
  146. {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/WHEEL +0 -0
  147. {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/top_level.txt +0 -0
workbench/scripts/ml_pipeline_sqs.py
@@ -0,0 +1,186 @@
+ import argparse
+ import logging
+ import json
+ from pathlib import Path
+
+ # Workbench Imports
+ from workbench.core.cloud_platform.aws.aws_account_clamp import AWSAccountClamp
+ from workbench.utils.config_manager import ConfigManager
+ from workbench.utils.s3_utils import upload_content_to_s3
+
+ log = logging.getLogger("workbench")
+ cm = ConfigManager()
+ workbench_bucket = cm.get_config("WORKBENCH_BUCKET")
+
+
+ def submit_to_sqs(
+     script_path: str,
+     size: str = "small",
+     realtime: bool = False,
+     dt: bool = False,
+     promote: bool = False,
+ ) -> None:
+     """
+     Upload script to S3 and submit message to SQS queue for processing.
+
+     Args:
+         script_path: Local path to the ML pipeline script
+         size: Job size tier - "small" (default), "medium", or "large"
+         realtime: If True, sets serverless=False for real-time processing (default: False)
+         dt: If True, sets DT=True in environment (default: False)
+         promote: If True, sets PROMOTE=True in environment (default: False)
+
+     Raises:
+         ValueError: If size is invalid or script file not found
+     """
+     print(f"\n{'=' * 60}")
+     print("🚀 SUBMITTING ML PIPELINE JOB")
+     print(f"{'=' * 60}")
+     if size not in ["small", "medium", "large"]:
+         raise ValueError(f"Invalid size '{size}'. Must be 'small', 'medium', or 'large'")
+
+     # Validate script exists
+     script_file = Path(script_path)
+     if not script_file.exists():
+         raise FileNotFoundError(f"Script not found: {script_path}")
+
+     print(f"📄 Script: {script_file.name}")
+     print(f"📏 Size tier: {size}")
+     print(f"⚡ Mode: {'Real-time' if realtime else 'Serverless'} (serverless={'False' if realtime else 'True'})")
+     print(f"🔄 DynamicTraining: {dt}")
+     print(f"🆕 Promote: {promote}")
+     print(f"🪣 Bucket: {workbench_bucket}")
+     sqs = AWSAccountClamp().boto3_session.client("sqs")
+     script_name = script_file.name
+
+     # List Workbench queues
+     print("\n📋 Listing Workbench SQS queues...")
+     try:
+         queues = sqs.list_queues(QueueNamePrefix="workbench-")
+         queue_urls = queues.get("QueueUrls", [])
+         if queue_urls:
+             print(f"✅ Found {len(queue_urls)} workbench queue(s):")
+             for url in queue_urls:
+                 queue_name = url.split("/")[-1]
+                 print(f" • {queue_name}")
+         else:
+             print("⚠️ No workbench queues found")
+     except Exception as e:
+         print(f"❌ Error listing queues: {e}")
+
+     # Upload script to S3
+     s3_path = f"s3://{workbench_bucket}/batch-jobs/{script_name}"
+     print("\n📤 Uploading script to S3...")
+     print(f" Source: {script_path}")
+     print(f" Destination: {s3_path}")
+
+     try:
+         upload_content_to_s3(script_file.read_text(), s3_path)
+         print("✅ Script uploaded successfully")
+     except Exception as e:
+         print(f"❌ Upload failed: {e}")
+         raise
+     # Get queue URL and info
+     queue_name = "workbench-ml-pipeline-queue.fifo"
+     print("\n🎯 Getting queue information...")
+     print(f" Queue name: {queue_name}")
+
+     try:
+         queue_url = sqs.get_queue_url(QueueName=queue_name)["QueueUrl"]
+         print(f" Queue URL: {queue_url}")
+
+         # Get queue attributes for additional info
+         attrs = sqs.get_queue_attributes(
+             QueueUrl=queue_url, AttributeNames=["ApproximateNumberOfMessages", "ApproximateNumberOfMessagesNotVisible"]
+         )
+         messages_available = attrs["Attributes"].get("ApproximateNumberOfMessages", "0")
+         messages_in_flight = attrs["Attributes"].get("ApproximateNumberOfMessagesNotVisible", "0")
+         print(f" Messages in queue: {messages_available}")
+         print(f" Messages in flight: {messages_in_flight}")
+
+     except Exception as e:
+         print(f"❌ Error accessing queue: {e}")
+         raise
+
+     # Prepare message
+     message = {"script_path": s3_path, "size": size}
+
+     # Set environment variables
+     message["environment"] = {
+         "SERVERLESS": "False" if realtime else "True",
+         "DT": str(dt),
+         "PROMOTE": str(promote),
+     }
+
+     # Send the message to SQS
+     try:
+         print("\n📨 Sending message to SQS...")
+         response = sqs.send_message(
+             QueueUrl=queue_url,
+             MessageBody=json.dumps(message, indent=2),
+             MessageGroupId="ml-pipeline-jobs",  # Required for FIFO
+         )
+         message_id = response["MessageId"]
+         print("✅ Message sent successfully!")
+         print(f" Message ID: {message_id}")
+     except Exception as e:
+         print(f"❌ Failed to send message: {e}")
+         raise
+
+     # Success summary
+     print(f"\n{'=' * 60}")
+     print("✅ JOB SUBMISSION COMPLETE")
+     print(f"{'=' * 60}")
+     print(f"📄 Script: {script_name}")
+     print(f"📏 Size: {size}")
+     print(f"⚡ Mode: {'Real-time' if realtime else 'Serverless'} (SERVERLESS={'False' if realtime else 'True'})")
+     print(f"🔄 DynamicTraining: {dt}")
+     print(f"🆕 Promote: {promote}")
+     print(f"🆔 Message ID: {message_id}")
+     print("\n🔍 MONITORING LOCATIONS:")
+     print(f" • SQS Queue: AWS Console → SQS → {queue_name}")
+     print(" • Lambda Logs: AWS Console → Lambda → Functions")
+     print(" • Batch Jobs: AWS Console → Batch → Jobs")
+     print(" • CloudWatch: AWS Console → CloudWatch → Log groups")
+     print("\n⏳ Your job should start processing soon...")
+
+
+ def main():
+     """CLI entry point for submitting ML pipelines via SQS."""
+     parser = argparse.ArgumentParser(description="Submit ML pipeline to SQS queue for Batch processing")
+     parser.add_argument("script_file", help="Local path to ML pipeline script")
+     parser.add_argument(
+         "--size", default="small", choices=["small", "medium", "large"], help="Job size tier (default: small)"
+     )
+     parser.add_argument(
+         "--realtime",
+         action="store_true",
+         help="Create realtime endpoints (default is serverless)",
+     )
+     parser.add_argument(
+         "--dt",
+         action="store_true",
+         help="Set DT=True (models and endpoints will have '-dt' suffix)",
+     )
+     parser.add_argument(
+         "--promote",
+         action="store_true",
+         help="Set Promote=True (models and endpoints will use promoted naming)",
+     )
+     args = parser.parse_args()
+     try:
+         submit_to_sqs(
+             args.script_file,
+             args.size,
+             realtime=args.realtime,
+             dt=args.dt,
+             promote=args.promote,
+         )
+     except Exception as e:
+         print(f"\n❌ ERROR: {e}")
+         log.error(f"Error: {e}")
+         exit(1)
+
+
+ if __name__ == "__main__":
+     main()
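For context, here is a minimal sketch (not part of the package) of the SQS message body this new script produces, using the queue name and bucket layout shown in the hunk above; the bucket and script names are hypothetical:

    import json

    # Hypothetical values; submit_to_sqs() fills these in from WORKBENCH_BUCKET and the CLI flags
    message = {
        "script_path": "s3://my-workbench-bucket/batch-jobs/my_pipeline.py",
        "size": "small",
        "environment": {"SERVERLESS": "True", "DT": "False", "PROMOTE": "False"},
    }

    # This is the body sent to the 'workbench-ml-pipeline-queue.fifo' FIFO queue
    print(json.dumps(message, indent=2))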
workbench/scripts/monitor_cloud_watch.py
@@ -4,8 +4,10 @@ import sys
  import time
  import argparse
  from datetime import datetime, timedelta, timezone
- from workbench.core.cloud_platform.aws.aws_account_clamp import AWSAccountClamp
+
+ # Workbench Imports
  from workbench.utils.repl_utils import cprint, Spinner
+ from workbench.utils.cloudwatch_utils import get_cloudwatch_client, get_active_log_streams, stream_log_events

  # Define the log levels to include all log levels above the specified level
  log_level_map = {
@@ -33,64 +35,6 @@ def date_display(dt):
      return dt.strftime("%Y-%m-%d %I:%M%p") + "(UTC)"


- def get_cloudwatch_client():
-     """Get the CloudWatch Logs client using the Workbench assumed role session."""
-     session = AWSAccountClamp().boto3_session
-     return session.client("logs")
-
-
- def get_active_log_streams(client, log_group_name, start_time_ms, stream_filter=None):
-     """Retrieve log streams that have events after the specified start time."""
-
-     # Get all the streams in the log group
-     active_streams = []
-     stream_params = {
-         "logGroupName": log_group_name,
-         "orderBy": "LastEventTime",
-         "descending": True,
-     }
-
-     # Loop to retrieve all log streams (maximum 50 per call)
-     while True:
-         response = client.describe_log_streams(**stream_params)
-         log_streams = response.get("logStreams", [])
-
-         for log_stream in log_streams:
-             log_stream_name = log_stream["logStreamName"]
-             last_event_timestamp = log_stream.get("lastEventTimestamp")
-
-             # Include streams with events since the specified start time
-             # Note: There's some issue where the last event timestamp is 'off'
-             # so we're going to add 60 minutes from the last event timestamp
-             last_event_timestamp += 60 * 60 * 1000
-             if last_event_timestamp >= start_time_ms:
-                 active_streams.append(log_stream_name)
-             else:
-                 break  # Stop if we reach streams older than the start time
-
-         # Check if there are more streams to retrieve
-         if "nextToken" in response:
-             stream_params["nextToken"] = response["nextToken"]
-         else:
-             break
-
-     # Sort and report the active log streams
-     active_streams.sort()
-     if active_streams:
-         print("Active log streams:", len(active_streams))
-
-     # Filter the active streams by a substring if provided
-     if stream_filter and active_streams:
-         print(f"Filtering active log streams by '{stream_filter}'...")
-         active_streams = [stream for stream in active_streams if stream_filter in stream]
-
-     for stream in active_streams:
-         print(f"\t - {stream}")
-
-     # Return the active log streams
-     return active_streams
-
-
  def get_latest_log_events(client, log_group_name, start_time, end_time=None, stream_filter=None):
      """Retrieve the latest log events from the active/filtered log streams in a CloudWatch Logs group."""

@@ -99,11 +43,15 @@ def get_latest_log_events(client, log_group_name, start_time, end_time=None, str
          get_latest_log_events.first_run = True

      log_events = []
-     start_time_ms = int(start_time.timestamp() * 1000)  # Convert start_time to milliseconds
+     start_time_ms = int(start_time.timestamp() * 1000)
+
+     # Use the util function to get active streams
+     active_streams = get_active_log_streams(log_group_name, start_time_ms, stream_filter, client)

-     # Get the active log streams with events since start_time
-     active_streams = get_active_log_streams(client, log_group_name, start_time_ms, stream_filter)
      if active_streams:
+         print(f"Active log streams: {len(active_streams)}")
+         for stream in active_streams:
+             print(f"\t - {stream}")
          print(f"Processing log events from {date_display(start_time)} on {len(active_streams)} active log streams...")
          get_latest_log_events.first_run = False
      else:
@@ -114,50 +62,22 @@
          print("Monitoring for new events...")
          return log_events

-     # Iterate over the active streams and fetch log events
+     # Use the util function to stream events from each log stream
      for log_stream_name in active_streams:
-         params = {
-             "logGroupName": log_group_name,
-             "logStreamName": log_stream_name,
-             "startTime": start_time_ms,  # Use start_time in milliseconds
-             "startFromHead": True,  # Start from the nearest event to start_time
-         }
-         next_event_token = None
-         if end_time is not None:
-             params["endTime"] = int(end_time.timestamp() * 1000)
-
-         # Process the log events from this log stream
          spinner = Spinner("lightpurple", f"Pulling events from {log_stream_name}:")
          spinner.start()
         log_stream_events = 0

-         # Get the log events for the active log stream
-         while True:
-             if next_event_token:
-                 params["nextToken"] = next_event_token
-                 params.pop("startTime", None)  # Remove startTime when using nextToken
+         # Stream events using the util function
+         for event in stream_log_events(
+             log_group_name, log_stream_name, start_time, end_time, follow=False, client=client
+         ):
+             log_stream_events += 1
+             log_events.append(event)

-             # Fetch the log events (this call takes a while: optimize if we can)
-             events_response = client.get_log_events(**params)
+         spinner.stop()
+         print(f"Processed {log_stream_events} events from {log_stream_name} (Total: {len(log_events)})")

-             events = events_response.get("events", [])
-             for event in events:
-                 event["logStreamName"] = log_stream_name
-
-             # Add the log stream events to our list of all log events
-             log_stream_events += len(events)
-             log_events.extend(events)
-
-             # Handle pagination for log events
-             next_event_token = events_response.get("nextForwardToken")
-
-             # Break the loop if there are no more events to fetch
-             if not next_event_token or next_event_token == params.get("nextToken"):
-                 spinner.stop()
-                 print(f"Processed {log_stream_events} events from {log_stream_name} (Total: {len(log_events)})")
-                 break
-
-     # Return the log events
      return log_events


@@ -206,6 +126,7 @@
      print(f"Monitoring log group: {log_group_name} from {date_display(start_time)}")
      print(f"Log levels: {log_levels}")
      print(f"Search terms: {search_terms}")
+
      while True:
          # Get the latest log events with stream filtering if provided
          all_log_events = get_latest_log_events(client, log_group_name, start_time, end_time, stream_filter)
@@ -218,7 +139,6 @@

              # Check the search terms
              if not search_terms or any(term in event["message"].lower() for term in search_terms):
-
                  # Calculate the start and end index for this match
                  start_index = max(i - before, 0)
                  end_index = min(i + after, len(all_log_events) - 1)
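The refactor above swaps the script-local helpers for functions in the new workbench.utils.cloudwatch_utils module. A rough usage sketch based only on the call signatures visible in this hunk (the log group name is made up, and anything beyond the arguments shown is an assumption):

    from datetime import datetime, timedelta, timezone

    from workbench.utils.cloudwatch_utils import get_cloudwatch_client, get_active_log_streams, stream_log_events

    client = get_cloudwatch_client()
    start_time = datetime.now(timezone.utc) - timedelta(hours=1)
    start_time_ms = int(start_time.timestamp() * 1000)

    # Find streams with recent events, then pull their events (no stream filter, no follow)
    log_group = "/aws/sagemaker/TrainingJobs"  # hypothetical log group
    for stream_name in get_active_log_streams(log_group, start_time_ms, None, client):
        for event in stream_log_events(log_group, stream_name, start_time, None, follow=False, client=client):
            print(event["message"])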
workbench/scripts/training_test.py
@@ -0,0 +1,85 @@
+ """
+ Local test harness for SageMaker training scripts.
+
+ Usage:
+     python training_test.py <model_script.py> <featureset_name>
+
+ Example:
+     python training_test.py ../model_scripts/pytorch_model/generated_model_script.py caco2-class-features
+ """
+
+ import os
+ import shutil
+ import subprocess
+ import sys
+ import tempfile
+
+ import pandas as pd
+
+ from workbench.api import FeatureSet
+
+
+ def get_training_data(featureset_name: str) -> pd.DataFrame:
+     """Get training data from the FeatureSet."""
+     fs = FeatureSet(featureset_name)
+     return fs.pull_dataframe()
+
+
+ def main():
+     if len(sys.argv) < 3:
+         print("Usage: python training_test.py <model_script.py> <featureset_name>")
+         sys.exit(1)
+
+     script_path = sys.argv[1]
+     featureset_name = sys.argv[2]
+
+     if not os.path.exists(script_path):
+         print(f"Error: Script not found: {script_path}")
+         sys.exit(1)
+
+     # Create temp directories
+     model_dir = tempfile.mkdtemp(prefix="training_model_")
+     train_dir = tempfile.mkdtemp(prefix="training_data_")
+     output_dir = tempfile.mkdtemp(prefix="training_output_")
+
+     print(f"Model dir: {model_dir}")
+     print(f"Train dir: {train_dir}")
+
+     try:
+         # Get training data and save to CSV
+         print(f"Loading FeatureSet: {featureset_name}")
+         df = get_training_data(featureset_name)
+         print(f"Data shape: {df.shape}")
+
+         train_file = os.path.join(train_dir, "training_data.csv")
+         df.to_csv(train_file, index=False)
+
+         # Set up environment
+         env = os.environ.copy()
+         env["SM_MODEL_DIR"] = model_dir
+         env["SM_CHANNEL_TRAIN"] = train_dir
+         env["SM_OUTPUT_DATA_DIR"] = output_dir
+
+         print("\n" + "=" * 60)
+         print("Starting training...")
+         print("=" * 60 + "\n")
+
+         # Run the script
+         cmd = [sys.executable, script_path, "--model-dir", model_dir, "--train", train_dir]
+         result = subprocess.run(cmd, env=env)
+
+         print("\n" + "=" * 60)
+         if result.returncode == 0:
+             print("Training completed successfully!")
+         else:
+             print(f"Training failed with return code: {result.returncode}")
+         print("=" * 60)
+
+     finally:
+         shutil.rmtree(model_dir, ignore_errors=True)
+         shutil.rmtree(train_dir, ignore_errors=True)
+         shutil.rmtree(output_dir, ignore_errors=True)
+
+
+ if __name__ == "__main__":
+     main()
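The harness mimics the SageMaker training contract: training data lands in SM_CHANNEL_TRAIN, model artifacts are expected in SM_MODEL_DIR, and the script is invoked with --model-dir and --train. A minimal sketch of the script side of that contract (the argparse defaults are assumptions, not copied from the templates in this package):

    import argparse
    import os

    import pandas as pd

    if __name__ == "__main__":
        parser = argparse.ArgumentParser()
        parser.add_argument("--model-dir", default=os.environ.get("SM_MODEL_DIR", "/opt/ml/model"))
        parser.add_argument("--train", default=os.environ.get("SM_CHANNEL_TRAIN", "/opt/ml/input/data/train"))
        args = parser.parse_args()

        # The harness writes a single CSV into the train channel
        df = pd.read_csv(os.path.join(args.train, "training_data.csv"))
        # ... fit a model on df, then write artifacts into args.model_dir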
workbench/utils/aws_utils.py
@@ -55,7 +55,8 @@ def aws_throttle(func=None, retry_intervals=None):
      if func is None:
          return lambda f: aws_throttle(f, retry_intervals=retry_intervals)

-     service_hold_time = 2  # Seconds to wait before calling AWS function
+     # This is currently commented out (we might want to use it later)
+     # service_hold_time = 2  # Seconds to wait before calling AWS function
      default_intervals = [2**i for i in range(1, 9)]  # Default exponential backoff: 2, 4, 8... 256 seconds
      intervals = retry_intervals or default_intervals

@@ -64,8 +65,8 @@
          for attempt, delay in enumerate(intervals, start=1):
              try:
                  # Add sleep before calling AWS func if running as a service
-                 if cm.running_as_service:
-                     time.sleep(service_hold_time)
+                 # if cm.running_as_service:
+                 #     time.sleep(service_hold_time)
                  return func(*args, **kwargs)
              except ClientError as e:
                  if e.response["Error"]["Code"] == "ThrottlingException":
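From the surrounding code, aws_throttle works both as a bare decorator and as a decorator factory with custom retry intervals, retrying on ThrottlingException with exponential backoff. A small usage sketch (the decorated functions and client are made up for illustration):

    import boto3

    from workbench.utils.aws_utils import aws_throttle

    sm_client = boto3.client("sagemaker")

    @aws_throttle
    def list_models():
        return sm_client.list_models()

    @aws_throttle(retry_intervals=[1, 2, 4])  # retry after 1s, 2s, then 4s
    def describe_endpoint(name):
        return sm_client.describe_endpoint(EndpointName=name)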
workbench/utils/chem_utils/__init__.py (file without changes)
workbench/utils/chem_utils/fingerprints.py
@@ -0,0 +1,175 @@
+ """Molecular fingerprint computation utilities for ADMET modeling.
+
+ This module provides Morgan count fingerprints, the standard for ADMET prediction.
+ Count fingerprints outperform binary fingerprints for molecular property prediction.
+
+ References:
+     - Count vs Binary: https://pubs.acs.org/doi/10.1021/acs.est.3c02198
+     - ECFP/Morgan: https://pubs.acs.org/doi/10.1021/ci100050t
+ """
+
+ import logging
+
+ import numpy as np
+ import pandas as pd
+ from rdkit import Chem, RDLogger
+ from rdkit.Chem import AllChem
+ from rdkit.Chem.MolStandardize import rdMolStandardize
+
+ # Suppress RDKit warnings (e.g., "not removing hydrogen atom without neighbors")
+ # Keep errors enabled so we see actual problems
+ RDLogger.DisableLog("rdApp.warning")
+
+ # Set up the logger
+ log = logging.getLogger("workbench")
+
+
+ def compute_morgan_fingerprints(df: pd.DataFrame, radius: int = 2, n_bits: int = 2048) -> pd.DataFrame:
+     """Compute Morgan count fingerprints for ADMET modeling.
+
+     Generates true count fingerprints where each bit position contains the
+     number of times that substructure appears in the molecule (clamped to 0-255).
+     This is the recommended approach for ADMET prediction per 2025 research.
+
+     Args:
+         df: Input DataFrame containing SMILES strings.
+         radius: Radius for the Morgan fingerprint (default 2 = ECFP4 equivalent).
+         n_bits: Number of bits for the fingerprint (default 2048).
+
+     Returns:
+         pd.DataFrame: Input DataFrame with 'fingerprint' column added.
+             Values are comma-separated uint8 counts.
+
+     Note:
+         Count fingerprints outperform binary for ADMET prediction.
+         See: https://pubs.acs.org/doi/10.1021/acs.est.3c02198
+     """
+     delete_mol_column = False
+
+     # Check for the SMILES column (case-insensitive)
+     smiles_column = next((col for col in df.columns if col.lower() == "smiles"), None)
+     if smiles_column is None:
+         raise ValueError("Input DataFrame must have a 'smiles' column")
+
+     # Sanity check the molecule column (sometimes it gets serialized, which doesn't work)
+     if "molecule" in df.columns and df["molecule"].dtype == "string":
+         log.warning("Detected serialized molecules in 'molecule' column. Removing...")
+         del df["molecule"]
+
+     # Convert SMILES to RDKit molecule objects
+     if "molecule" not in df.columns:
+         log.info("Converting SMILES to RDKit Molecules...")
+         delete_mol_column = True
+         df["molecule"] = df[smiles_column].apply(Chem.MolFromSmiles)
+         # Make sure our molecules are not None
+         failed_smiles = df[df["molecule"].isnull()][smiles_column].tolist()
+         if failed_smiles:
+             log.warning(f"Failed to convert {len(failed_smiles)} SMILES to molecules ({failed_smiles})")
+         df = df.dropna(subset=["molecule"]).copy()
+
+     # If we have fragments in our compounds, get the largest fragment before computing fingerprints
+     largest_frags = df["molecule"].apply(
+         lambda mol: rdMolStandardize.LargestFragmentChooser().choose(mol) if mol else None
+     )
+
+     def mol_to_count_string(mol):
+         """Convert molecule to comma-separated count fingerprint string."""
+         if mol is None:
+             return pd.NA
+
+         # Get hashed Morgan fingerprint with counts
+         fp = AllChem.GetHashedMorganFingerprint(mol, radius, nBits=n_bits)
+
+         # Initialize array and populate with counts (clamped to uint8 range)
+         counts = np.zeros(n_bits, dtype=np.uint8)
+         for idx, count in fp.GetNonzeroElements().items():
+             counts[idx] = min(count, 255)
+
+         # Return as comma-separated string
+         return ",".join(map(str, counts))
+
+     # Compute Morgan count fingerprints
+     fingerprints = largest_frags.apply(mol_to_count_string)
+
+     # Add the fingerprints to the DataFrame
+     df["fingerprint"] = fingerprints
+
+     # Drop the intermediate 'molecule' column if it was added
+     if delete_mol_column:
+         del df["molecule"]
+
+     return df
+
+
+ if __name__ == "__main__":
+     print("Running Morgan count fingerprint tests...")
+
+     # Test molecules
+     test_molecules = {
+         "aspirin": "CC(=O)OC1=CC=CC=C1C(=O)O",
+         "caffeine": "CN1C=NC2=C1C(=O)N(C(=O)N2C)C",
+         "glucose": "C([C@@H]1[C@H]([C@@H]([C@H](C(O1)O)O)O)O)O",  # With stereochemistry
+         "sodium_acetate": "CC(=O)[O-].[Na+]",  # Salt (largest fragment used)
+         "benzene": "c1ccccc1",
+         "butene_e": "C/C=C/C",  # E-butene
+         "butene_z": "C/C=C\\C",  # Z-butene
+     }
+
+     # Test 1: Morgan Count Fingerprints (default parameters)
+     print("\n1. Testing Morgan fingerprint generation (radius=2, n_bits=2048)...")
+
+     test_df = pd.DataFrame({"SMILES": list(test_molecules.values()), "name": list(test_molecules.keys())})
+     fp_df = compute_morgan_fingerprints(test_df.copy())
+
+     print(" Fingerprint generation results:")
+     for _, row in fp_df.iterrows():
+         fp = row.get("fingerprint", "N/A")
+         if pd.notna(fp):
+             counts = [int(x) for x in fp.split(",")]
+             non_zero = sum(1 for c in counts if c > 0)
+             max_count = max(counts)
+             print(f" {row['name']:15} → {len(counts)} features, {non_zero} non-zero, max={max_count}")
+         else:
+             print(f" {row['name']:15} → N/A")
+
+     # Test 2: Different parameters
+     print("\n2. Testing with different parameters (radius=3, n_bits=1024)...")
+
+     fp_df_custom = compute_morgan_fingerprints(test_df.copy(), radius=3, n_bits=1024)
+
+     for _, row in fp_df_custom.iterrows():
+         fp = row.get("fingerprint", "N/A")
+         if pd.notna(fp):
+             counts = [int(x) for x in fp.split(",")]
+             non_zero = sum(1 for c in counts if c > 0)
+             print(f" {row['name']:15} → {len(counts)} features, {non_zero} non-zero")
+         else:
+             print(f" {row['name']:15} → N/A")
+
+     # Test 3: Edge cases
+     print("\n3. Testing edge cases...")
+
+     # Invalid SMILES
+     invalid_df = pd.DataFrame({"SMILES": ["INVALID", ""]})
+     fp_invalid = compute_morgan_fingerprints(invalid_df.copy())
+     print(f" ✓ Invalid SMILES handled: {len(fp_invalid)} rows returned")
+
+     # Test with pre-existing molecule column
+     mol_df = test_df.copy()
+     mol_df["molecule"] = mol_df["SMILES"].apply(Chem.MolFromSmiles)
+     fp_with_mol = compute_morgan_fingerprints(mol_df)
+     print(f" ✓ Pre-existing molecule column handled: {len(fp_with_mol)} fingerprints generated")
+
+     # Test 4: Verify count values are reasonable
+     print("\n4. Verifying count distribution...")
+     all_counts = []
+     for _, row in fp_df.iterrows():
+         fp = row.get("fingerprint", "N/A")
+         if pd.notna(fp):
+             counts = [int(x) for x in fp.split(",")]
+             all_counts.extend([c for c in counts if c > 0])
+
+     if all_counts:
+         print(f" Non-zero counts: min={min(all_counts)}, max={max(all_counts)}, mean={np.mean(all_counts):.2f}")
+
+     print("\n✅ All fingerprint tests completed!")
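Since the new 'fingerprint' column is a comma-separated string of uint8 counts, downstream model code has to expand it back into a numeric matrix before training. A minimal sketch of that step (the function and variable names are illustrative, not from the package):

    import numpy as np
    import pandas as pd

    def fingerprints_to_matrix(df: pd.DataFrame, col: str = "fingerprint") -> np.ndarray:
        """Expand comma-separated count strings into an (n_samples, n_bits) uint8 array."""
        rows = [np.array(fp.split(","), dtype=np.uint8) for fp in df[col].dropna()]
        return np.vstack(rows)

    # Example: X = fingerprints_to_matrix(compute_morgan_fingerprints(df))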