workbench 0.8.234__py3-none-any.whl → 0.8.239__py3-none-any.whl
This diff covers publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
- workbench/algorithms/dataframe/smart_aggregator.py +17 -12
- workbench/api/endpoint.py +13 -4
- workbench/api/model.py +2 -2
- workbench/cached/cached_model.py +2 -2
- workbench/core/artifacts/athena_source.py +5 -3
- workbench/core/artifacts/endpoint_core.py +30 -5
- workbench/core/cloud_platform/aws/aws_meta.py +2 -1
- workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +27 -14
- workbench/model_script_utils/model_script_utils.py +225 -0
- workbench/model_script_utils/uq_harness.py +39 -21
- workbench/model_scripts/chemprop/chemprop.template +30 -15
- workbench/model_scripts/chemprop/generated_model_script.py +35 -18
- workbench/model_scripts/chemprop/model_script_utils.py +225 -0
- workbench/model_scripts/pytorch_model/generated_model_script.py +29 -15
- workbench/model_scripts/pytorch_model/model_script_utils.py +225 -0
- workbench/model_scripts/pytorch_model/pytorch.template +28 -14
- workbench/model_scripts/pytorch_model/uq_harness.py +39 -21
- workbench/model_scripts/xgb_model/generated_model_script.py +35 -22
- workbench/model_scripts/xgb_model/model_script_utils.py +225 -0
- workbench/model_scripts/xgb_model/uq_harness.py +39 -21
- workbench/model_scripts/xgb_model/xgb_model.template +29 -18
- workbench/scripts/ml_pipeline_batch.py +47 -2
- workbench/scripts/ml_pipeline_launcher.py +410 -0
- workbench/scripts/ml_pipeline_sqs.py +22 -2
- workbench/themes/dark/custom.css +29 -0
- workbench/themes/light/custom.css +29 -0
- workbench/themes/midnight_blue/custom.css +28 -0
- workbench/utils/model_utils.py +9 -0
- workbench/utils/theme_manager.py +95 -0
- workbench/web_interface/components/component_interface.py +3 -0
- workbench/web_interface/components/plugin_interface.py +26 -0
- workbench/web_interface/components/plugins/ag_table.py +4 -11
- workbench/web_interface/components/plugins/confusion_matrix.py +14 -8
- workbench/web_interface/components/plugins/model_plot.py +156 -0
- workbench/web_interface/components/plugins/scatter_plot.py +9 -2
- workbench/web_interface/components/plugins/shap_summary_plot.py +12 -4
- workbench/web_interface/components/settings_menu.py +10 -49
- {workbench-0.8.234.dist-info → workbench-0.8.239.dist-info}/METADATA +2 -2
- {workbench-0.8.234.dist-info → workbench-0.8.239.dist-info}/RECORD +43 -42
- {workbench-0.8.234.dist-info → workbench-0.8.239.dist-info}/WHEEL +1 -1
- {workbench-0.8.234.dist-info → workbench-0.8.239.dist-info}/entry_points.txt +1 -0
- workbench/web_interface/components/model_plot.py +0 -75
- {workbench-0.8.234.dist-info → workbench-0.8.239.dist-info}/licenses/LICENSE +0 -0
- {workbench-0.8.234.dist-info → workbench-0.8.239.dist-info}/top_level.txt +0 -0
workbench/algorithms/dataframe/smart_aggregator.py
CHANGED

@@ -55,25 +55,30 @@ def smart_aggregator(df: pd.DataFrame, target_rows: int = 1000, outlier_column:
         result["aggregation_count"] = 1
         return result.reset_index(drop=True)
 
-    # Handle NaN values - fill with column median
-
-
-
-
+    # Handle NaN values - fill with column median (use numpy for speed)
+    clustering_data = df[numeric_cols].values
+    col_medians = np.nanmedian(clustering_data, axis=0)
+    nan_mask = np.isnan(clustering_data)
+    clustering_data = np.where(nan_mask, col_medians, clustering_data)
+
+    # Normalize and cluster (n_init=1 since MiniBatchKMeans is already approximate)
+    X = StandardScaler().fit_transform(clustering_data)
     df["_cluster"] = MiniBatchKMeans(
-        n_clusters=min(target_rows, n_rows), random_state=42, batch_size=min(1024, n_rows), n_init=
+        n_clusters=min(target_rows, n_rows), random_state=42, batch_size=min(1024, n_rows), n_init=1
     ).fit_predict(X)
 
     # Post-process: give high-outlier rows their own unique clusters so they don't get aggregated
     if outlier_column and outlier_column in df.columns:
-        # Top 10% of outlier values get their own clusters, capped at
-        n_to_isolate = min(int(n_rows * 0.1),
-
-
+        # Top 10% of outlier values get their own clusters, capped at 20% of target_rows
+        n_to_isolate = min(int(n_rows * 0.1), int(target_rows * 0.2))
+        outlier_values = df[outlier_column].values
+        threshold = np.partition(outlier_values, -n_to_isolate)[-n_to_isolate]
+        high_outlier_mask = outlier_values >= threshold
         n_high_outliers = high_outlier_mask.sum()
-        # Assign unique cluster IDs starting after the max existing cluster
+        # Assign unique cluster IDs starting after the max existing cluster (match dtype to avoid warning)
         max_cluster = df["_cluster"].max()
-
+        new_cluster_ids = np.arange(max_cluster + 1, max_cluster + 1 + n_high_outliers, dtype=df["_cluster"].dtype)
+        df.loc[high_outlier_mask, "_cluster"] = new_cluster_ids
         log.info(f"smart_aggregator: Isolated {n_high_outliers} high-outlier rows (>= {threshold:.3f})")
     elif outlier_column:
         log.warning(f"smart_aggregator: outlier_column '{outlier_column}' not found in columns")
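For context on the NaN-handling change above: the new code computes per-column medians with numpy and fills missing values in one vectorized step before scaling and clustering. A minimal standalone sketch of that pattern (the toy DataFrame and column names are illustrative, not from the package):

```python
import numpy as np
import pandas as pd
from sklearn.cluster import MiniBatchKMeans
from sklearn.preprocessing import StandardScaler

# Toy data with NaNs in the numeric columns (illustrative only)
df = pd.DataFrame({"a": [1.0, np.nan, 3.0, 4.0], "b": [10.0, 20.0, np.nan, 40.0]})
numeric_cols = ["a", "b"]

# Fill NaNs with each column's median using numpy (no per-column Python loop)
data = df[numeric_cols].values
col_medians = np.nanmedian(data, axis=0)
data = np.where(np.isnan(data), col_medians, data)

# Normalize and cluster; n_init=1 because MiniBatchKMeans is already approximate
X = StandardScaler().fit_transform(data)
labels = MiniBatchKMeans(n_clusters=2, random_state=42, n_init=1).fit_predict(X)
print(labels)
```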
workbench/api/endpoint.py
CHANGED
@@ -29,7 +29,12 @@ class Endpoint(EndpointCore):
         return super().details(**kwargs)
 
     def inference(
-        self,
+        self,
+        eval_df: pd.DataFrame,
+        capture_name: str = None,
+        id_column: str = None,
+        drop_error_rows: bool = False,
+        include_quantiles: bool = False,
     ) -> pd.DataFrame:
         """Run inference on the Endpoint using the provided DataFrame
 

@@ -38,11 +43,12 @@ class Endpoint(EndpointCore):
             capture_name (str, optional): The Name of the capture to use (default: None)
             id_column (str, optional): The name of the column to use as the ID (default: None)
             drop_error_rows (bool): Whether to drop rows with errors (default: False)
+            include_quantiles (bool): Include q_* quantile columns in saved output (default: False)
 
         Returns:
             pd.DataFrame: The DataFrame with predictions
         """
-        return super().inference(eval_df, capture_name, id_column, drop_error_rows)
+        return super().inference(eval_df, capture_name, id_column, drop_error_rows, include_quantiles)
 
     def auto_inference(self) -> pd.DataFrame:
         """Run inference on the Endpoint using the test data from the model training view
 

@@ -75,13 +81,16 @@ class Endpoint(EndpointCore):
         """
         return super().fast_inference(eval_df, threads=threads)
 
-    def cross_fold_inference(self) -> pd.DataFrame:
+    def cross_fold_inference(self, include_quantiles: bool = False) -> pd.DataFrame:
         """Pull cross-fold inference from model associated with this Endpoint
 
+        Args:
+            include_quantiles (bool): Include q_* quantile columns in saved output (default: False)
+
         Returns:
             pd.DataFrame: A DataFrame with cross fold predictions
         """
-        return super().cross_fold_inference()
+        return super().cross_fold_inference(include_quantiles)
 
 
 if __name__ == "__main__":
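A hedged usage sketch of the new include_quantiles flag (the endpoint name, columns, and import path are assumptions for illustration, and a deployed Workbench endpoint is required):

```python
import pandas as pd
from workbench.api.endpoint import Endpoint  # assumed import path

# Hypothetical endpoint and evaluation data
end = Endpoint("my-regression-end")
eval_df = pd.DataFrame({"feature_1": [0.5, 1.2], "feature_2": [3.1, 0.7]})

# include_quantiles=True keeps the q_* quantile columns in the saved capture
preds = end.inference(eval_df, capture_name="eval_capture", include_quantiles=True)

# Cross-fold results can also retain the quantile columns
cv_preds = end.cross_fold_inference(include_quantiles=True)
```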
workbench/api/model.py
CHANGED
@@ -44,7 +44,7 @@ class Model(ModelCore):
         serverless: bool = True,
         mem_size: int = 2048,
         max_concurrency: int = 5,
-        instance: str =
+        instance: str = None,
         data_capture: bool = False,
     ) -> Endpoint:
         """Create an Endpoint from the Model.
 

@@ -55,7 +55,7 @@ class Model(ModelCore):
             serverless (bool): Set the endpoint to be serverless (default: True)
             mem_size (int): The memory size for the Endpoint in MB (default: 2048)
             max_concurrency (int): The maximum concurrency for the Endpoint (default: 5)
-            instance (str): The instance type
+            instance (str): The instance type for Realtime Endpoints (default: None = auto-select based on model)
             data_capture (bool): Enable data capture for the Endpoint (default: False)
 
         Returns:
workbench/cached/cached_model.py
CHANGED
@@ -86,13 +86,13 @@ class CachedModel(CachedArtifactMixin, ModelCore):
 
     @CachedArtifactMixin.cache_result
     def get_inference_predictions(
-        self, capture_name: str = "full_cross_fold", target_rows: int =
+        self, capture_name: str = "full_cross_fold", target_rows: int = 2000
     ) -> Union[pd.DataFrame, None]:
         """Retrieve the captured prediction results for this model
 
         Args:
             capture_name (str, optional): Specific capture_name (default: full_cross_fold)
-            target_rows (int, optional): Target number of rows to return (default:
+            target_rows (int, optional): Target number of rows to return (default: 2000)
 
         Returns:
             pd.DataFrame: DataFrame of the Captured Predictions (might be None)
workbench/core/artifacts/athena_source.py
CHANGED

@@ -258,7 +258,7 @@ class AthenaSource(DataSourceAbstract):
 
                 # Wait for the query to complete
                 wr.athena.wait_query(query_execution_id=query_execution_id, boto3_session=self.boto3_session)
-                self.log.debug(f"
+                self.log.debug(f"Query executed successfully: {query_execution_id}")
                 break  # If successful, exit the retry loop
             except wr.exceptions.QueryFailed as e:
                 if "AlreadyExistsException" in str(e):
 

@@ -271,11 +271,13 @@ class AthenaSource(DataSourceAbstract):
                     time.sleep(retry_delay)
                 else:
                     if not silence_errors:
-                        self.log.critical(f"Failed to execute
+                        self.log.critical(f"Failed to execute query after {max_retries} attempts: {query}")
+                        self.log.critical(f"Error: {e}")
                     raise
             else:
                 if not silence_errors:
-                    self.log.critical(f"Failed to execute
+                    self.log.critical(f"Failed to execute query: {query}")
+                    self.log.critical(f"Error: {e}")
                 raise
 
     def s3_storage_location(self) -> str:
workbench/core/artifacts/endpoint_core.py
CHANGED

@@ -370,7 +370,12 @@ class EndpointCore(Artifact):
         return self.inference(eval_df, "full_inference")
 
     def inference(
-        self,
+        self,
+        eval_df: pd.DataFrame,
+        capture_name: str = None,
+        id_column: str = None,
+        drop_error_rows: bool = False,
+        include_quantiles: bool = False,
     ) -> pd.DataFrame:
         """Run inference on the Endpoint using the provided DataFrame
 

@@ -379,6 +384,7 @@ class EndpointCore(Artifact):
             capture_name (str, optional): Name of the inference capture (default=None)
             id_column (str, optional): Name of the ID column (default=None)
             drop_error_rows (bool, optional): If True, drop rows that had endpoint errors/issues (default=False)
+            include_quantiles (bool): Include q_* quantile columns in saved output (default: False)
 
         Returns:
             pd.DataFrame: DataFrame with the inference results

@@ -478,6 +484,7 @@ class EndpointCore(Artifact):
                 description,
                 features,
                 id_column,
+                include_quantiles,
             )
 
         # Save primary target (or single target) with original capture_name

@@ -491,6 +498,7 @@ class EndpointCore(Artifact):
             capture_name.replace("_", " ").title(),
             features,
             id_column,
+            include_quantiles,
         )
 
         # Capture uncertainty metrics if prediction_std is available (UQ, ChemProp, etc.)

@@ -501,9 +509,12 @@ class EndpointCore(Artifact):
         # Return the prediction DataFrame
         return prediction_df
 
-    def cross_fold_inference(self) -> pd.DataFrame:
+    def cross_fold_inference(self, include_quantiles: bool = False) -> pd.DataFrame:
         """Pull cross-fold inference training results for this Endpoint's model
 
+        Args:
+            include_quantiles (bool): Include q_* quantile columns in saved output (default: False)
+
         Returns:
             pd.DataFrame: A DataFrame with cross fold predictions
         """

@@ -594,6 +605,7 @@ class EndpointCore(Artifact):
                 description,
                 features=additional_columns,
                 id_column=id_column,
+                include_quantiles=include_quantiles,
             )
 
         # Save primary target (or single target) as "full_cross_fold"

@@ -607,6 +619,7 @@ class EndpointCore(Artifact):
             "Full Cross Fold",
             features=additional_columns,
             id_column=id_column,
+            include_quantiles=include_quantiles,
         )
 
         return out_of_fold_df

@@ -824,6 +837,7 @@ class EndpointCore(Artifact):
         description: str,
         features: list,
         id_column: str = None,
+        include_quantiles: bool = False,
     ):
         """Internal: Capture the inference results and metrics to S3 for a single target
 

@@ -836,6 +850,7 @@ class EndpointCore(Artifact):
             description (str): Description of the inference results
             features (list): List of features to include in the inference results
             id_column (str, optional): Name of the ID column (default=None)
+            include_quantiles (bool): Include q_* quantile columns in output (default: False)
         """
 
         # Compute a dataframe hash (just use the last 8)

@@ -862,7 +877,7 @@ class EndpointCore(Artifact):
         wr.s3.to_csv(metrics, f"{inference_capture_path}/inference_metrics.csv", index=False)
 
         # Save the inference predictions for this target
-        self._save_target_inference(inference_capture_path, pred_results_df, target, id_column)
+        self._save_target_inference(inference_capture_path, pred_results_df, target, id_column, include_quantiles)
 
         # CLASSIFIER: Write the confusion matrix to our S3 Model Inference Folder
         if model_type == ModelType.CLASSIFIER:

@@ -882,6 +897,7 @@ class EndpointCore(Artifact):
         pred_results_df: pd.DataFrame,
         target: str,
         id_column: str = None,
+        include_quantiles: bool = False,
     ):
         """Save inference results for a single target.
 

@@ -890,6 +906,7 @@ class EndpointCore(Artifact):
             pred_results_df (pd.DataFrame): DataFrame with prediction results
             target (str): Target column name
             id_column (str, optional): Name of the ID column
+            include_quantiles (bool): Include q_* quantile columns in output (default: False)
         """
         cols = pred_results_df.columns
 

@@ -902,8 +919,16 @@ class EndpointCore(Artifact):
 
         output_columns += [c for c in ["prediction", "prediction_std"] if c in cols]
 
-        # Add
-
+        # Add confidence column (always include if present)
+        if "confidence" in cols:
+            output_columns.append("confidence")
+
+        # Add quantile columns (q_*) only if requested
+        if include_quantiles:
+            output_columns += [c for c in cols if c.startswith("q_")]
+
+        # Add proba columns for classifiers
+        output_columns += [c for c in cols if c.endswith("_proba")]
 
         # Add smiles column if present
         if "smiles" in cols:
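The _save_target_inference change above boils down to a small column filter: confidence is always kept when present, q_* quantile columns only when include_quantiles is set, and *_proba columns always. A self-contained illustration of that filter (the helper name and sample columns are invented for the example):

```python
def select_output_columns(cols, include_quantiles=False):
    # Keep prediction columns when present
    output = [c for c in ["prediction", "prediction_std"] if c in cols]
    # Confidence column is always included if present
    if "confidence" in cols:
        output.append("confidence")
    # Quantile columns (q_*) only when explicitly requested
    if include_quantiles:
        output += [c for c in cols if c.startswith("q_")]
    # Probability columns for classifiers are always included
    output += [c for c in cols if c.endswith("_proba")]
    return output

cols = ["prediction", "prediction_std", "confidence", "q_05", "q_95", "class_a_proba"]
print(select_output_columns(cols))                          # quantiles dropped
print(select_output_columns(cols, include_quantiles=True))  # quantiles kept
```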
workbench/core/cloud_platform/aws/aws_meta.py
CHANGED

@@ -245,7 +245,8 @@ class AWSMeta:
             "Model Group": model_group_name,
             "Health": health_tags,
             "Owner": aws_tags.get("workbench_owner", "-"),
-            "
+            "Type": aws_tags.get("workbench_model_type", "-"),
+            "Framework": aws_tags.get("workbench_model_framework", "-"),
             "Created": created,
             "Ver": model_details.get("ModelPackageVersion", "-"),
             "Input": aws_tags.get("workbench_input", "-"),
workbench/core/transforms/model_to_endpoint/model_to_endpoint.py
CHANGED

@@ -26,13 +26,13 @@ class ModelToEndpoint(Transform):
     ```
     """
 
-    def __init__(self, model_name: str, endpoint_name: str, serverless: bool = True, instance: str =
+    def __init__(self, model_name: str, endpoint_name: str, serverless: bool = True, instance: str = None):
         """ModelToEndpoint Initialization
         Args:
             model_name(str): The Name of the input Model
             endpoint_name(str): The Name of the output Endpoint
             serverless(bool): Deploy the Endpoint in serverless mode (default: True)
-            instance(str): The instance type
+            instance(str): The instance type for Realtime Endpoints (default: None = auto-select)
         """
         # Make sure the endpoint_name is a valid name
         Artifact.is_name_valid(endpoint_name, delimiter="-", lower_case=False)

@@ -42,7 +42,7 @@ class ModelToEndpoint(Transform):
 
         # Set up all my instance attributes
         self.serverless = serverless
-        self.
+        self.instance = instance
         self.input_type = TransformInput.MODEL
         self.output_type = TransformOutput.ENDPOINT
 

@@ -100,24 +100,37 @@ class ModelToEndpoint(Transform):
         # Get the metadata/tags to push into AWS
         aws_tags = self.get_aws_tags()
 
+        # Check the model framework for resource requirements
+        from workbench.api import ModelFramework
+
+        self.log.info(f"Model Framework: {workbench_model.model_framework}")
+        needs_more_resources = workbench_model.model_framework in [ModelFramework.PYTORCH, ModelFramework.CHEMPROP]
+
         # Is this a serverless deployment?
         serverless_config = None
+        instance_type = None
         if self.serverless:
             # For PyTorch or ChemProp we need at least 4GB of memory
-
-
-
-            if workbench_model.model_framework in [ModelFramework.PYTORCH, ModelFramework.CHEMPROP]:
-                if mem_size < 4096:
-                    self.log.important(
-                        f"{workbench_model.model_framework} needs at least 4GB of memory (setting to 4GB)"
-                    )
-                    mem_size = 4096
+            if needs_more_resources and mem_size < 4096:
+                self.log.important(f"{workbench_model.model_framework} needs at least 4GB of memory (setting to 4GB)")
+                mem_size = 4096
             serverless_config = ServerlessInferenceConfig(
                 memory_size_in_mb=mem_size,
                 max_concurrency=max_concurrency,
             )
+            instance_type = "serverless"
             self.log.important(f"Serverless Config: Memory={mem_size}MB, MaxConcurrency={max_concurrency}")
+        else:
+            # For realtime endpoints, use explicit instance if provided, otherwise auto-select
+            if self.instance:
+                instance_type = self.instance
+                self.log.important(f"Realtime Endpoint: Using specified instance type: {instance_type}")
+            elif needs_more_resources:
+                instance_type = "ml.c7i.xlarge"
+                self.log.important(f"{workbench_model.model_framework} needs more resources (using {instance_type})")
+            else:
+                instance_type = "ml.t2.medium"
+                self.log.important(f"Realtime Endpoint: Instance Type={instance_type}")
 
         # Configure data capture if requested (and not serverless)
         data_capture_config = None

@@ -141,7 +154,7 @@ class ModelToEndpoint(Transform):
         try:
             model_package.deploy(
                 initial_instance_count=1,
-                instance_type=
+                instance_type=instance_type,
                 serverless_inference_config=serverless_config,
                 endpoint_name=self.output_name,
                 serializer=CSVSerializer(),

@@ -158,7 +171,7 @@ class ModelToEndpoint(Transform):
             # Retry the deploy
             model_package.deploy(
                 initial_instance_count=1,
-                instance_type=
+                instance_type=instance_type,
                 serverless_inference_config=serverless_config,
                 endpoint_name=self.output_name,
                 serializer=CSVSerializer(),
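The deployment change reduces to a small decision tree for the SageMaker instance type. A standalone restatement of that logic (the helper function and plain-string framework values are illustrative, not the package's API):

```python
def choose_instance_type(serverless: bool, framework: str, explicit_instance: str | None = None) -> str:
    # PyTorch and ChemProp models need more memory/CPU than the lightweight default
    heavy = framework in ("pytorch", "chemprop")
    if serverless:
        return "serverless"            # a ServerlessInferenceConfig is used instead of an instance
    if explicit_instance:
        return explicit_instance       # caller-specified realtime instance wins
    if heavy:
        return "ml.c7i.xlarge"         # larger default for PyTorch/ChemProp realtime endpoints
    return "ml.t2.medium"              # lightweight default otherwise

print(choose_instance_type(serverless=False, framework="xgboost"))      # ml.t2.medium
print(choose_instance_type(serverless=False, framework="chemprop"))     # ml.c7i.xlarge
print(choose_instance_type(serverless=False, framework="xgboost", explicit_instance="ml.m5.large"))
```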
workbench/model_script_utils/model_script_utils.py
CHANGED

@@ -16,6 +16,7 @@ from sklearn.metrics import (
     r2_score,
     root_mean_squared_error,
 )
+from sklearn.model_selection import GroupKFold, GroupShuffleSplit
 from scipy.stats import spearmanr
 
 

@@ -367,3 +368,227 @@ def print_confusion_matrix(y_true: np.ndarray, y_pred: np.ndarray, label_names:
         for j, col_name in enumerate(label_names):
             value = conf_mtx[i, j]
             print(f"ConfusionMatrix:{row_name}:{col_name} {value}")
+
+
+# =============================================================================
+# Dataset Splitting Utilities for Molecular Data
+# =============================================================================
+def get_scaffold(smiles: str) -> str:
+    """Extract Bemis-Murcko scaffold from a SMILES string.
+
+    Args:
+        smiles: SMILES string of the molecule
+
+    Returns:
+        SMILES string of the scaffold, or empty string if molecule is invalid
+    """
+    from rdkit import Chem
+    from rdkit.Chem.Scaffolds import MurckoScaffold
+
+    mol = Chem.MolFromSmiles(smiles)
+    if mol is None:
+        return ""
+    try:
+        scaffold = MurckoScaffold.GetScaffoldForMol(mol)
+        return Chem.MolToSmiles(scaffold)
+    except Exception:
+        return ""
+
+
+def get_scaffold_groups(smiles_list: list[str]) -> np.ndarray:
+    """Assign each molecule to a scaffold group.
+
+    Args:
+        smiles_list: List of SMILES strings
+
+    Returns:
+        Array of group indices (same scaffold = same group)
+    """
+    scaffold_to_group = {}
+    groups = []
+
+    for smi in smiles_list:
+        scaffold = get_scaffold(smi)
+        if scaffold not in scaffold_to_group:
+            scaffold_to_group[scaffold] = len(scaffold_to_group)
+        groups.append(scaffold_to_group[scaffold])
+
+    n_scaffolds = len(scaffold_to_group)
+    print(f"Found {n_scaffolds} unique scaffolds from {len(smiles_list)} molecules")
+    return np.array(groups)
+
+
+def get_butina_clusters(smiles_list: list[str], cutoff: float = 0.4) -> np.ndarray:
+    """Cluster molecules using Butina algorithm on Morgan fingerprints.
+
+    Uses RDKit's Butina clustering with Tanimoto distance on Morgan fingerprints.
+    This is Pat Walters' recommended approach for creating diverse train/test splits.
+
+    Args:
+        smiles_list: List of SMILES strings
+        cutoff: Tanimoto distance cutoff for clustering (default 0.4)
+            Lower values = more clusters = more similar molecules per cluster
+
+    Returns:
+        Array of cluster indices
+    """
+    from rdkit import Chem, DataStructs
+    from rdkit.Chem.rdFingerprintGenerator import GetMorganGenerator
+    from rdkit.ML.Cluster import Butina
+
+    # Create Morgan fingerprint generator
+    fp_gen = GetMorganGenerator(radius=2, fpSize=2048)
+
+    # Generate Morgan fingerprints
+    fps = []
+    valid_indices = []
+    for i, smi in enumerate(smiles_list):
+        mol = Chem.MolFromSmiles(smi)
+        if mol is not None:
+            fp = fp_gen.GetFingerprint(mol)
+            fps.append(fp)
+            valid_indices.append(i)
+
+    if len(fps) == 0:
+        raise ValueError("No valid molecules found for clustering")
+
+    # Compute distance matrix (upper triangle only for efficiency)
+    n = len(fps)
+    dists = []
+    for i in range(1, n):
+        sims = DataStructs.BulkTanimotoSimilarity(fps[i], fps[:i])
+        dists.extend([1 - s for s in sims])
+
+    # Butina clustering
+    clusters = Butina.ClusterData(dists, n, cutoff, isDistData=True)
+
+    # Map back to original indices
+    cluster_labels = np.zeros(len(smiles_list), dtype=int)
+    for cluster_idx, cluster in enumerate(clusters):
+        for mol_idx in cluster:
+            original_idx = valid_indices[mol_idx]
+            cluster_labels[original_idx] = cluster_idx
+
+    # Assign invalid molecules to their own clusters
+    next_cluster = len(clusters)
+    for i in range(len(smiles_list)):
+        if i not in valid_indices:
+            cluster_labels[i] = next_cluster
+            next_cluster += 1
+
+    n_clusters = len(set(cluster_labels))
+    print(f"Butina clustering: {n_clusters} clusters from {len(smiles_list)} molecules (cutoff={cutoff})")
+    return cluster_labels
+
+
+def _find_smiles_column(columns: list[str]) -> str | None:
+    """Find SMILES column (case-insensitive match for 'smiles').
+
+    Args:
+        columns: List of column names
+
+    Returns:
+        The matching column name, or None if not found
+    """
+    return next((c for c in columns if c.lower() == "smiles"), None)
+
+
+def get_split_indices(
+    df: pd.DataFrame,
+    n_splits: int = 5,
+    strategy: str = "random",
+    smiles_column: str | None = None,
+    target_column: str | None = None,
+    test_size: float = 0.2,
+    random_state: int = 42,
+    butina_cutoff: float = 0.4,
+) -> list[tuple[np.ndarray, np.ndarray]]:
+    """Get train/validation split indices using various strategies.
+
+    This is a unified interface for generating splits that can be used across
+    all model templates (XGBoost, PyTorch, ChemProp).
+
+    Args:
+        df: DataFrame containing the data
+        n_splits: Number of CV folds (1 = single train/val split)
+        strategy: Split strategy - one of:
+            - "random": Standard random split (default sklearn behavior)
+            - "scaffold": Bemis-Murcko scaffold-based grouping
+            - "butina": Morgan fingerprint clustering (recommended for ADMET)
+        smiles_column: Column containing SMILES. If None, auto-detects 'smiles' (case-insensitive)
+        target_column: Column containing target values (for stratification, optional)
+        test_size: Fraction for validation set when n_splits=1 (default 0.2)
+        random_state: Random seed for reproducibility
+        butina_cutoff: Tanimoto distance cutoff for Butina clustering (default 0.4)
+
+    Returns:
+        List of (train_indices, val_indices) tuples
+
+    Note:
+        If scaffold/butina strategy is requested but no SMILES column is found,
+        automatically falls back to random split with a warning message.
+
+    Example:
+        >>> folds = get_split_indices(df, n_splits=5, strategy="scaffold")
+        >>> for train_idx, val_idx in folds:
+        ...     X_train, X_val = df.iloc[train_idx], df.iloc[val_idx]
+    """
+    from sklearn.model_selection import KFold, StratifiedKFold, train_test_split
+
+    n_samples = len(df)
+
+    # Random split (original behavior)
+    if strategy == "random":
+        if n_splits == 1:
+            indices = np.arange(n_samples)
+            train_idx, val_idx = train_test_split(indices, test_size=test_size, random_state=random_state)
+            return [(train_idx, val_idx)]
+        else:
+            if target_column and df[target_column].dtype in ["object", "category", "bool"]:
+                kfold = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=random_state)
+                return list(kfold.split(df, df[target_column]))
+            else:
+                kfold = KFold(n_splits=n_splits, shuffle=True, random_state=random_state)
+                return list(kfold.split(df))
+
+    # Scaffold or Butina split requires SMILES - auto-detect if not provided
+    if smiles_column is None:
+        smiles_column = _find_smiles_column(df.columns.tolist())
+
+    # Fall back to random split if no SMILES column available
+    if smiles_column is None or smiles_column not in df.columns:
+        print(f"No 'smiles' column found for strategy='{strategy}', falling back to random split")
+        return get_split_indices(
+            df,
+            n_splits=n_splits,
+            strategy="random",
+            target_column=target_column,
+            test_size=test_size,
+            random_state=random_state,
+        )
+
+    smiles_list = df[smiles_column].tolist()
+
+    # Get group assignments
+    if strategy == "scaffold":
+        groups = get_scaffold_groups(smiles_list)
+    elif strategy == "butina":
+        groups = get_butina_clusters(smiles_list, cutoff=butina_cutoff)
+    else:
+        raise ValueError(f"Unknown strategy: {strategy}. Use 'random', 'scaffold', or 'butina'")
+
+    # Generate splits using GroupKFold or GroupShuffleSplit
+    if n_splits == 1:
+        # Single split: use GroupShuffleSplit
+        splitter = GroupShuffleSplit(n_splits=1, test_size=test_size, random_state=random_state)
+        return list(splitter.split(df, groups=groups))
+    else:
+        # K-fold: use GroupKFold (ensures no group appears in both train and val)
+        # Note: GroupKFold doesn't shuffle, so we shuffle group order first
+        unique_groups = np.unique(groups)
+        rng = np.random.default_rng(random_state)
+        shuffled_group_map = {g: i for i, g in enumerate(rng.permutation(unique_groups))}
+        shuffled_groups = np.array([shuffled_group_map[g] for g in groups])
+
+        gkf = GroupKFold(n_splits=n_splits)
+        return list(gkf.split(df, groups=shuffled_groups))