workbench-0.8.203-py3-none-any.whl → workbench-0.8.205-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- workbench/core/artifacts/endpoint_core.py +65 -76
- workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +11 -0
- workbench/model_scripts/chemprop/chemprop.template +18 -1
- workbench/model_scripts/chemprop/generated_model_script.py +24 -7
- workbench/model_scripts/uq_models/generated_model_script.py +3 -3
- {workbench-0.8.203.dist-info → workbench-0.8.205.dist-info}/METADATA +1 -1
- {workbench-0.8.203.dist-info → workbench-0.8.205.dist-info}/RECORD +11 -11
- {workbench-0.8.203.dist-info → workbench-0.8.205.dist-info}/WHEEL +0 -0
- {workbench-0.8.203.dist-info → workbench-0.8.205.dist-info}/entry_points.txt +0 -0
- {workbench-0.8.203.dist-info → workbench-0.8.205.dist-info}/licenses/LICENSE +0 -0
- {workbench-0.8.203.dist-info → workbench-0.8.205.dist-info}/top_level.txt +0 -0
workbench/core/artifacts/endpoint_core.py

@@ -436,19 +436,14 @@ class EndpointCore(Artifact):

         # Normalize targets to a list for iteration
         target_list = targets if isinstance(targets, list) else [targets]
+        primary_target = target_list[0]

-        # For
-        #
-
-        # Determine capture name: use prefix for multi-target, original name for single-target
-        if len(target_list) > 1:
-            prefix = "auto" if "auto" in capture_name else capture_name
-            target_capture_name = f"{prefix}_{target}"
-        else:
-            target_capture_name = capture_name
-
-        description = target_capture_name.replace("_", " ").title()
+        # For auto_inference, use shorter "auto_{target}" naming
+        # Otherwise use "{capture_name}_{target}"
+        prefix = "auto" if capture_name == "auto_inference" else capture_name

+        # Save results for each target, plus primary target with original capture_name
+        for target in target_list:
             # Drop rows with NaN target values for metrics/plots
             target_df = prediction_df.dropna(subset=[target])

@@ -460,6 +455,9 @@ class EndpointCore(Artifact):
             else:
                 target_metrics = pd.DataFrame()

+            # Save as {prefix}_{target}
+            target_capture_name = f"{prefix}_{target}"
+            description = target_capture_name.replace("_", " ").title()
             self._capture_inference_results(
                 target_capture_name,
                 target_df,
@@ -471,6 +469,19 @@ class EndpointCore(Artifact):
                 id_column,
             )

+            # Also save primary target with original capture_name for backward compatibility
+            if target == primary_target:
+                self._capture_inference_results(
+                    capture_name,
+                    target_df,
+                    target,
+                    model.model_type,
+                    target_metrics,
+                    capture_name.replace("_", " ").title(),
+                    features,
+                    id_column,
+                )
+
         # For UQ Models we also capture the uncertainty metrics
         if model.model_type in [ModelType.UQ_REGRESSOR]:
             metrics = uq_metrics(prediction_df, primary_target)
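Net effect of the hunks above: each target in a multi-target inference run is saved under "{prefix}_{target}", and the primary (first) target is additionally saved under the original capture_name for backward compatibility. A minimal sketch of the resulting naming scheme (the helper and the target names are illustrative, not part of workbench):

def capture_names(capture_name: str, targets: list) -> list:
    # "auto_inference" runs get the shorter "auto_{target}" names
    prefix = "auto" if capture_name == "auto_inference" else capture_name
    names = [f"{prefix}_{t}" for t in targets]
    names.append(capture_name)  # primary target also saved under the original name
    return names

print(capture_names("auto_inference", ["logd", "solubility"]))
# -> ['auto_logd', 'auto_solubility', 'auto_inference']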
@@ -561,13 +572,11 @@ class EndpointCore(Artifact):

         # Normalize targets to a list for iteration
         target_list = targets if isinstance(targets, list) else [targets]
+        primary_target = target_list[0]

-        #
-        #
+        # Save results for each target as cv_{target}
+        # Also save primary target as "full_cross_fold" for backward compatibility
         for target in target_list:
-            capture_name = f"cv_{target}"
-            description = capture_name.replace("_", " ").title()
-
             # Drop rows with NaN target values for metrics/plots
             target_df = out_of_fold_df.dropna(subset=[target])

@@ -579,6 +588,9 @@ class EndpointCore(Artifact):
             else:
                 target_metrics = pd.DataFrame()

+            # Save as cv_{target}
+            capture_name = f"cv_{target}"
+            description = capture_name.replace("_", " ").title()
             self._capture_inference_results(
                 capture_name,
                 target_df,
@@ -590,6 +602,19 @@ class EndpointCore(Artifact):
                 id_column=id_column,
             )

+            # Also save primary target as "full_cross_fold" for backward compatibility
+            if target == primary_target:
+                self._capture_inference_results(
+                    "full_cross_fold",
+                    target_df,
+                    target,
+                    model_type,
+                    target_metrics,
+                    "Full Cross Fold",
+                    features=additional_columns,
+                    id_column=id_column,
+                )
+
         return out_of_fold_df

     def fast_inference(self, eval_df: pd.DataFrame, threads: int = 4) -> pd.DataFrame:
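The cross-fold path mirrors the same pattern: each target's out-of-fold results are saved as "cv_{target}", and the primary target is also saved as "full_cross_fold" so existing consumers keep working. Sketch with hypothetical target names:

targets = ["logd", "solubility"]  # hypothetical
cv_captures = [f"cv_{t}" for t in targets] + ["full_cross_fold"]
# -> ['cv_logd', 'cv_solubility', 'full_cross_fold']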
@@ -795,30 +820,6 @@ class EndpointCore(Artifact):
         combined = row_hashes.values.tobytes()
         return hashlib.md5(combined).hexdigest()[:hash_length]

-    @staticmethod
-    def _find_prediction_column(df: pd.DataFrame, target_column: str) -> Optional[str]:
-        """Find the prediction column in a DataFrame.
-
-        Looks for 'prediction' column first, then '{target}_pred' pattern.
-
-        Args:
-            df: DataFrame to search
-            target_column: Name of the target column (used for {target}_pred pattern)
-
-        Returns:
-            Name of the prediction column, or None if not found
-        """
-        # Check for 'prediction' column first (legacy/standard format)
-        if "prediction" in df.columns:
-            return "prediction"
-
-        # Check for '{target}_pred' format (multi-target format)
-        target_pred_col = f"{target_column}_pred"
-        if target_pred_col in df.columns:
-            return target_pred_col
-
-        return None
-
     def _capture_inference_results(
         self,
         capture_name: str,
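The _find_prediction_column fallback can be deleted because the model scripts now guarantee a plain "prediction" column: the chemprop hunks further down alias the first target's "{target}_pred" to "prediction". A sketch of that contract, with hypothetical column names:

import pandas as pd

df = pd.DataFrame({"logd_pred": [1.2, 3.4], "logd_pred_std": [0.1, 0.2]})  # hypothetical output frame
first_target = "logd"  # target_columns[0]
df["prediction"] = df[f"{first_target}_pred"]
df["prediction_std"] = df[f"{first_target}_pred_std"]
# Downstream code can now index df["prediction"] unconditionally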
@@ -946,29 +947,23 @@ class EndpointCore(Artifact):
             self.log.warning("No predictions were made. Returning empty DataFrame.")
             return pd.DataFrame()

-        #
-
-
-            self.log.warning(f"No prediction column found for target '{target_column}'")
+        # Check for prediction column
+        if "prediction" not in prediction_df.columns:
+            self.log.warning("No 'prediction' column found in DataFrame")
             return pd.DataFrame()

         # Check for NaN values in target or prediction columns
-        if prediction_df[target_column].isnull().any() or prediction_df[
-            # Compute the number of NaN values in each column
+        if prediction_df[target_column].isnull().any() or prediction_df["prediction"].isnull().any():
             num_nan_target = prediction_df[target_column].isnull().sum()
-            num_nan_prediction = prediction_df[
-            self.log.warning(
-
-            )
-            self.log.warning(
-                "NaN values found in target or prediction columns. Dropping NaN rows for metric computation."
-            )
-            prediction_df = prediction_df.dropna(subset=[target_column, prediction_col])
+            num_nan_prediction = prediction_df["prediction"].isnull().sum()
+            self.log.warning(f"NaNs Found: {target_column} {num_nan_target} and prediction: {num_nan_prediction}.")
+            self.log.warning("Dropping NaN rows for metric computation.")
+            prediction_df = prediction_df.dropna(subset=[target_column, "prediction"])

         # Compute the metrics
         try:
             y_true = prediction_df[target_column]
-            y_pred = prediction_df[
+            y_pred = prediction_df["prediction"]

             mae = mean_absolute_error(y_true, y_pred)
             rmse = np.sqrt(mean_squared_error(y_true, y_pred))
@@ -1000,17 +995,13 @@ class EndpointCore(Artifact):
         Returns:
             pd.DataFrame: DataFrame with two new columns called 'residuals' and 'residuals_abs'
         """
-
-
-
-
-        # Find the prediction column: "prediction" or "{target}_pred"
-        prediction_col = self._find_prediction_column(prediction_df, target_column)
-        if prediction_col is None:
-            self.log.warning(f"No prediction column found for target '{target_column}'. Cannot compute residuals.")
+        # Check for prediction column
+        if "prediction" not in prediction_df.columns:
+            self.log.warning("No 'prediction' column found. Cannot compute residuals.")
             return prediction_df

-
+        y_true = prediction_df[target_column]
+        y_pred = prediction_df["prediction"]

         # Check for classification scenario
         if not pd.api.types.is_numeric_dtype(y_true) or not pd.api.types.is_numeric_dtype(y_pred):
@@ -1051,14 +1042,13 @@ class EndpointCore(Artifact):
         Returns:
             pd.DataFrame: DataFrame with the performance metrics
         """
-        #
-
-
-            self.log.warning(f"No prediction column found for target '{target_column}'")
+        # Check for prediction column
+        if "prediction" not in prediction_df.columns:
+            self.log.warning("No 'prediction' column found in DataFrame")
             return pd.DataFrame()

         # Drop rows with NaN predictions (can't compute metrics on missing predictions)
-        nan_mask = prediction_df[
+        nan_mask = prediction_df["prediction"].isna()
         if nan_mask.any():
             n_nan = nan_mask.sum()
             self.log.warning(f"Dropping {n_nan} rows with NaN predictions for metrics calculation")
@@ -1078,7 +1068,7 @@ class EndpointCore(Artifact):
         # Calculate precision, recall, f1, and support, handling zero division
         scores = precision_recall_fscore_support(
             prediction_df[target_column],
-            prediction_df[
+            prediction_df["prediction"],
             average=None,
             labels=class_labels,
             zero_division=0,
@@ -1126,21 +1116,20 @@ class EndpointCore(Artifact):
         Returns:
             pd.DataFrame: DataFrame with the confusion matrix
         """
-        #
-
-
-            self.log.warning(f"No prediction column found for target '{target_column}'")
+        # Check for prediction column
+        if "prediction" not in prediction_df.columns:
+            self.log.warning("No 'prediction' column found in DataFrame")
             return pd.DataFrame()

         # Drop rows with NaN predictions (can't include in confusion matrix)
-        nan_mask = prediction_df[
+        nan_mask = prediction_df["prediction"].isna()
         if nan_mask.any():
             n_nan = nan_mask.sum()
             self.log.warning(f"Dropping {n_nan} rows with NaN predictions for confusion matrix")
             prediction_df = prediction_df[~nan_mask].copy()

         y_true = prediction_df[target_column]
-        y_pred = prediction_df[
+        y_pred = prediction_df["prediction"]

         # Get model class labels
         model_class_labels = ModelCore(self.model_name).class_labels()
workbench/core/transforms/model_to_endpoint/model_to_endpoint.py

@@ -102,10 +102,21 @@ class ModelToEndpoint(Transform):
         # Is this a serverless deployment?
         serverless_config = None
         if self.serverless:
+            # For PyTorch or ChemProp we need at least 4GB of memory
+            from workbench.api import ModelFramework
+
+            self.log.info(f"Model Framework: {workbench_model.model_framework}")
+            if workbench_model.model_framework in [ModelFramework.PYTORCH_TABULAR, ModelFramework.CHEMPROP]:
+                if mem_size < 4096:
+                    self.log.important(
+                        f"{workbench_model.model_framework} needs at least 4GB of memory (setting to 4GB)"
+                    )
+                    mem_size = 4096
             serverless_config = ServerlessInferenceConfig(
                 memory_size_in_mb=mem_size,
                 max_concurrency=max_concurrency,
             )
+            self.log.important(f"Serverless Config: Memory={mem_size}MB, MaxConcurrency={max_concurrency}")

         # Configure data capture if requested (and not serverless)
         data_capture_config = None
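This enforces a 4GB memory floor for serverless endpoints whose containers load PyTorch. A standalone sketch of the floor logic (the framework strings below are assumptions standing in for the ModelFramework enum values):

MIN_TORCH_MEMORY_MB = 4096  # PyTorch/ChemProp containers need at least 4GB

def effective_memory(mem_size: int, framework: str) -> int:
    # Hypothetical string names mirroring ModelFramework.PYTORCH_TABULAR / CHEMPROP
    if framework in ("pytorch_tabular", "chemprop"):
        return max(mem_size, MIN_TORCH_MEMORY_MB)
    return mem_size

assert effective_memory(2048, "chemprop") == 4096
assert effective_memory(6144, "chemprop") == 6144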
workbench/model_scripts/chemprop/chemprop.template

@@ -25,6 +25,7 @@
 # - argparse, file loading, S3 writes
 # =============================

+import glob
 import os
 import argparse
 import json
@@ -185,7 +186,7 @@ def build_mpnn_model(
     # Model hyperparameters with defaults
     hidden_dim = hyperparameters.get("hidden_dim", 700)
     depth = hyperparameters.get("depth", 6)
-    dropout = hyperparameters.get("dropout", 0.
+    dropout = hyperparameters.get("dropout", 0.15)
     ffn_hidden_dim = hyperparameters.get("ffn_hidden_dim", 2000)
     ffn_num_layers = hyperparameters.get("ffn_num_layers", 2)
@@ -468,6 +469,11 @@ def predict_fn(df: pd.DataFrame, model_dict: dict) -> pd.DataFrame:
             df.loc[valid_mask, f"{tc}_pred"] = preds[:, t_idx]
             df.loc[valid_mask, f"{tc}_pred_std"] = preds_std[:, t_idx]

+    # Add prediction/prediction_std aliases for first target
+    first_target = target_columns[0]
+    df["prediction"] = df[f"{first_target}_pred"]
+    df["prediction_std"] = df[f"{first_target}_pred_std"]
+
     return df


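Note that rows which failed featurization keep NaN in "{target}_pred" (only valid_mask rows are assigned), so the new "prediction" alias propagates NaN; that is what the endpoint-side NaN dropping above is for. A small sketch with hypothetical values:

import numpy as np
import pandas as pd

df = pd.DataFrame({"logd_pred": [1.2, np.nan]})  # second row failed featurization
df["prediction"] = df["logd_pred"]
nan_mask = df["prediction"].isna()
print(df[~nan_mask])  # only the valid row reaches metric computation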
@@ -881,6 +887,11 @@ if __name__ == "__main__":
     else:
         df_val[f"{t_name}_pred_std"] = 0.0

+    # Add prediction/prediction_std aliases for first target
+    first_target = target_columns[0]
+    df_val["prediction"] = df_val[f"{first_target}_pred"]
+    df_val["prediction_std"] = df_val[f"{first_target}_pred_std"]
+
     # Save validation predictions to S3
     # Include id_column if it exists in df_val
     output_columns = []
@@ -890,6 +901,7 @@ if __name__ == "__main__":
     output_columns += target_columns
     output_columns += [f"{t}_pred" for t in target_columns]
     output_columns += [f"{t}_pred_std" for t in target_columns]
+    output_columns += ["prediction", "prediction_std"]
     # Add proba columns for classifiers
     output_columns += [col for col in df_val.columns if col.endswith("_proba")]
     # Filter to only columns that exist
@@ -906,6 +918,11 @@ if __name__ == "__main__":
         models.save_model(model_path, ens_model)
         print(f"Saved model {model_idx + 1} to {model_path}")

+    # Clean up checkpoint files (not needed for inference, reduces artifact size)
+    for ckpt_file in glob.glob(os.path.join(args.model_dir, "best_model_*.ckpt")):
+        os.remove(ckpt_file)
+        print(f"Removed checkpoint: {ckpt_file}")
+
     # Save ensemble metadata (n_ensemble = number of models for inference)
     n_ensemble = len(ensemble_models)
     ensemble_metadata = {
workbench/model_scripts/chemprop/generated_model_script.py

@@ -25,6 +25,7 @@
 # - argparse, file loading, S3 writes
 # =============================

+import glob
 import os
 import argparse
 import json
@@ -53,12 +54,12 @@ from chemprop import data, models, nn

 # Template Parameters
 TEMPLATE_PARAMS = {
-    "model_type": "
-    "targets": ['
-    "feature_list": ['smiles'],
-    "id_column": "
-    "model_metrics_s3_path": "s3://
-    "hyperparameters": {},
+    "model_type": "uq_regressor",
+    "targets": ['udm_asy_res_efflux_ratio'],  # List of target columns (single or multi-task)
+    "feature_list": ['smiles', 'smr_vsa4', 'tpsa', 'nhohcount', 'mollogp', 'peoe_vsa1', 'smr_vsa3', 'nitrogen_span', 'numhdonors', 'minpartialcharge', 'vsa_estate3', 'vsa_estate6', 'tertiary_amine_count', 'hba_hbd_ratio', 'peoe_vsa8', 'estate_vsa4', 'xc_4dv', 'vsa_estate2', 'molmr', 'xp_2dv', 'mi', 'molecular_axis_length', 'vsa_estate4', 'xp_6dv', 'qed', 'estate_vsa8', 'chi1v', 'asphericity', 'axp_1d', 'bcut2d_logphi', 'kappa3', 'axp_7d', 'num_s_centers', 'amphiphilic_moment', 'molecular_asymmetry', 'charge_centroid_distance', 'estate_vsa3', 'vsa_estate8', 'aromatic_interaction_score', 'molecular_volume_3d', 'axp_7dv', 'peoe_vsa3', 'smr_vsa6', 'bcut2d_mrhi', 'radius_of_gyration', 'xpc_4dv', 'minabsestateindex', 'axp_0dv', 'chi4n', 'balabanj', 'bcut2d_mwlow'],
+    "id_column": "udm_mol_bat_id",
+    "model_metrics_s3_path": "s3://ideaya-sageworks-bucket/models/caco2-er-chemprop-reg-hybrid/training",
+    "hyperparameters": {'n_folds': 5, 'hidden_dim': 700, 'depth': 6, 'dropout': 0.15, 'ffn_hidden_dim': 2000, 'ffn_num_layers': 2},
 }


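generated_model_script.py is chemprop.template with TEMPLATE_PARAMS filled in for a specific model build, which is why study-specific targets and an S3 path appear in the committed copy. A sketch of how the script consumes the baked-in parameters (the unpacking shown is illustrative; the names come from TEMPLATE_PARAMS above):

target_columns = TEMPLATE_PARAMS["targets"]  # e.g. ['udm_asy_res_efflux_ratio']
hyperparameters = TEMPLATE_PARAMS["hyperparameters"]
dropout = hyperparameters.get("dropout", 0.15)  # matches build_mpnn_model's default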
@@ -185,7 +186,7 @@ def build_mpnn_model(
     # Model hyperparameters with defaults
     hidden_dim = hyperparameters.get("hidden_dim", 700)
     depth = hyperparameters.get("depth", 6)
-    dropout = hyperparameters.get("dropout", 0.
+    dropout = hyperparameters.get("dropout", 0.15)
     ffn_hidden_dim = hyperparameters.get("ffn_hidden_dim", 2000)
     ffn_num_layers = hyperparameters.get("ffn_num_layers", 2)
@@ -468,6 +469,11 @@ def predict_fn(df: pd.DataFrame, model_dict: dict) -> pd.DataFrame:
             df.loc[valid_mask, f"{tc}_pred"] = preds[:, t_idx]
             df.loc[valid_mask, f"{tc}_pred_std"] = preds_std[:, t_idx]

+    # Add prediction/prediction_std aliases for first target
+    first_target = target_columns[0]
+    df["prediction"] = df[f"{first_target}_pred"]
+    df["prediction_std"] = df[f"{first_target}_pred_std"]
+
     return df


@@ -881,6 +887,11 @@ if __name__ == "__main__":
     else:
         df_val[f"{t_name}_pred_std"] = 0.0

+    # Add prediction/prediction_std aliases for first target
+    first_target = target_columns[0]
+    df_val["prediction"] = df_val[f"{first_target}_pred"]
+    df_val["prediction_std"] = df_val[f"{first_target}_pred_std"]
+
     # Save validation predictions to S3
     # Include id_column if it exists in df_val
     output_columns = []
@@ -890,6 +901,7 @@ if __name__ == "__main__":
     output_columns += target_columns
     output_columns += [f"{t}_pred" for t in target_columns]
     output_columns += [f"{t}_pred_std" for t in target_columns]
+    output_columns += ["prediction", "prediction_std"]
     # Add proba columns for classifiers
     output_columns += [col for col in df_val.columns if col.endswith("_proba")]
     # Filter to only columns that exist
@@ -906,6 +918,11 @@ if __name__ == "__main__":
         models.save_model(model_path, ens_model)
         print(f"Saved model {model_idx + 1} to {model_path}")

+    # Clean up checkpoint files (not needed for inference, reduces artifact size)
+    for ckpt_file in glob.glob(os.path.join(args.model_dir, "best_model_*.ckpt")):
+        os.remove(ckpt_file)
+        print(f"Removed checkpoint: {ckpt_file}")
+
     # Save ensemble metadata (n_ensemble = number of models for inference)
     n_ensemble = len(ensemble_models)
     ensemble_metadata = {
workbench/model_scripts/uq_models/generated_model_script.py

@@ -19,11 +19,11 @@ from typing import List, Tuple, Optional, Dict

 # Template Placeholders
 TEMPLATE_PARAMS = {
-    "target": "
-    "features": ['
+    "target": "udm_asy_res_efflux_ratio",
+    "features": ['smr_vsa4', 'tpsa', 'nhohcount', 'peoe_vsa1', 'mollogp', 'vsa_estate3', 'xc_4dv', 'smr_vsa3', 'tertiary_amine_count', 'peoe_vsa8', 'minpartialcharge', 'nitrogen_span', 'vsa_estate2', 'chi1v', 'hba_hbd_ratio', 'molecular_axis_length', 'molmr', 'vsa_estate4', 'num_s_centers', 'vsa_estate6', 'qed', 'numhdonors', 'mi', 'estate_vsa4', 'axp_7d', 'kappa3', 'asphericity', 'estate_vsa8', 'estate_vsa2', 'estate_vsa3', 'peoe_vsa3', 'xp_6dv', 'bcut2d_logphi', 'vsa_estate8', 'amphiphilic_moment', 'type_ii_pattern_count', 'minestateindex', 'charge_centroid_distance', 'molecular_asymmetry', 'molecular_volume_3d', 'bcut2d_mrlow', 'axp_1d', 'vsa_estate9', 'aromatic_interaction_score', 'xp_7dv', 'bcut2d_mwlow', 'axp_7dv', 'slogp_vsa1', 'maxestateindex', 'fr_al_oh', 'nbase', 'xp_2dv', 'radius_of_gyration', 'sps', 'xch_7d', 'bcut2d_mrhi', 'axp_0dv', 'vsa_estate5', 'hallkieralpha', 'xp_0dv', 'fr_nhpyrrole', 'smr_vsa1', 'smr_vsa6', 'chi2v', 'bcut2d_mwhi', 'estate_vsa6', 'bcut2d_logplow', 'peoe_vsa2', 'fractioncsp3', 'slogp_vsa2', 'c3sp3', 'peoe_vsa7', 'estate_vsa9', 'peoe_vsa9', 'avgipc', 'smr_vsa9', 'xpc_4dv', 'balabanj', 'axp_1dv', 'mv', 'minabsestateindex', 'bcut2d_chglo', 'fpdensitymorgan2', 'axp_4d', 'numsaturatedheterocycles', 'fpdensitymorgan1', 'axp_3dv', 'axp_5d', 'smr_vsa5', 'bcut2d_chghi', 'axp_3d', 'xpc_5dv', 'chi4n', 'peoe_vsa10', 'vsa_estate7', 'peoe_vsa11', 'estate_vsa10', 'xp_7d', 'slogp_vsa5', 'xch_7dv', 'vsa_estate10', 'labuteasa', 'estate_vsa5', 'xp_3d', 'chi1', 'xch_4dv', 'xp_6d', 'estate_vsa1', 'axp_4dv', 'phi', 'xp_3dv', 'xch_6dv', 'smr_vsa10', 'num_r_centers', 'xc_5d', 'maxpartialcharge', 'xc_3d', 'peoe_vsa6', 'fr_imidazole', 'axp_2d', 'slogp_vsa3', 'mz', 'axp_6dv', 'xch_6d', 'mm', 'numatomstereocenters', 'c1sp3', 'chi1n', 'fpdensitymorgan3', 'xp_5dv', 'chi3v', 'slogp_vsa4', 'fr_ether', 'xp_2d', 'chi3n', 'xch_5dv', 'axp_6d', 'xc_5dv', 'numheterocycles', 'mpe', 'fr_hoccn', 'xc_3dv', 'type_i_pattern_count', 'chi0v', 'xch_4d', 'numsaturatedcarbocycles', 'mp', 'xch_5d', 'maxabspartialcharge', 'axp_2dv', 'bertzct', 'sse', 'xpc_6dv', 'sv', 'xpc_4d', 'si', 'chi0n', 'mse', 'xpc_6d', 'peoe_vsa12', 'xpc_5d', 'kappa2', 'axp_5dv', 'kappa1', 'chi2n', 'intramolecular_hbond_potential', 'fr_nh0', 'numaliphaticheterocycles', 'smr_vsa7', 'mare', 'fr_priamide', 'vsa_estate1', 'num_stereocenters', 'fr_nh1', 'estate_vsa7', 'fr_piperzine', 'c1sp2', 'slogp_vsa6', 'xp_5d', 'fr_aryl_methyl', 'molwt', 'chi4v', 'xc_6dv', 'heavyatommolwt', 'xp_4d', 'sp', 'slogp_vsa7', 'numhacceptors', 'c2sp3', 'peoe_vsa4', 'slogp_vsa10', 'fr_morpholine', 'fr_methoxy', 'fr_aniline', 'xp_4dv', 'fr_urea', 'c3sp2', 'fr_pyridine', 'hybratio', 'fr_thiazole', 'minabspartialcharge', 'sm', 'axp_0d', 'numaromaticheterocycles', 'nocount', 'xc_4d', 'peoe_vsa13', 'fr_amide', 'num_defined_stereocenters', 'amide_count', 'xc_6d', 'numrotatablebonds', 'c2sp2', 'fr_piperdine', 'numvalenceelectrons', 'c1sp1', 'fr_nitrile', 'fr_phenol', 'c4sp3', 'spe', 'numheteroatoms', 'estate_vsa11', 'sz', 'chi0', 'smr_vsa2', 'fr_ketone_topliss', 'slogp_vsa11', 'fr_benzene', 'fr_ndealkylation2', 'peoe_vsa5', 'fr_c_o', 'numsaturatedrings', 'exactmolwt', 'sare', 'numaliphaticrings', 'fr_al_oh_notert', 'fr_imine', 'frac_defined_stereo', 'numunspecifiedatomstereocenters', 'fr_ar_n', 'fr_bicyclic', 'fr_c_o_nocoo', 'numspiroatoms', 'fr_sulfone', 'fr_ndealkylation1'],
     "compressed_features": [],
     "train_all_data": True,
-    "hyperparameters": {'
+    "hyperparameters": {'n_estimators': 500, 'max_depth': 6, 'learning_rate': 0.04},
 }


{workbench-0.8.203.dist-info → workbench-0.8.205.dist-info}/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: workbench
-Version: 0.8.203
+Version: 0.8.205
 Summary: Workbench: A Dashboard and Python API for creating and deploying AWS SageMaker Model Pipelines
 Author-email: SuperCowPowers LLC <support@supercowpowers.com>
 License: MIT License
{workbench-0.8.203.dist-info → workbench-0.8.205.dist-info}/RECORD

@@ -55,7 +55,7 @@ workbench/core/artifacts/data_capture_core.py,sha256=q8f79rRTYiZ7T4IQRWXl8ZvPpcv
 workbench/core/artifacts/data_source_abstract.py,sha256=5IRCzFVK-17cd4NXPMRfx99vQAmQ0WHE5jcm5RfsVTg,10619
 workbench/core/artifacts/data_source_factory.py,sha256=YL_tA5fsgubbB3dPF6T4tO0rGgz-6oo3ge4i_YXVC-M,2380
 workbench/core/artifacts/df_store_core.py,sha256=AueNr_JvuLLu_ByE7cb3u-isH9u0Q7cMP-UCgCX-Ctg,3536
-workbench/core/artifacts/endpoint_core.py,sha256=
+workbench/core/artifacts/endpoint_core.py,sha256=eyjEd8KXMkqUwI7rFuuT0cMZMMrdSBSj3moR-EagS8w,60244
 workbench/core/artifacts/feature_set_core.py,sha256=wZy-02WXWmSBet5t8mWXFRdv9O4MtW3hWqJuVv7Kok0,39330
 workbench/core/artifacts/model_core.py,sha256=QIgV5MJr8aDY63in83thdNc5-bzkWLn5f5vvsS4aNYo,52348
 workbench/core/artifacts/monitor_core.py,sha256=M307yz7tEzOEHgv-LmtVy9jKjSbM98fHW3ckmNYrwlU,27897
@@ -104,7 +104,7 @@ workbench/core/transforms/features_to_features/heavy/glue/Readme.md,sha256=TuyCa
 workbench/core/transforms/features_to_model/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 workbench/core/transforms/features_to_model/features_to_model.py,sha256=JdKKz3eKrKhicA1WxTfmb1IqQNCdHJE0CKDs66bLHYU,21071
 workbench/core/transforms/model_to_endpoint/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-workbench/core/transforms/model_to_endpoint/model_to_endpoint.py,sha256=
+workbench/core/transforms/model_to_endpoint/model_to_endpoint.py,sha256=QjfUY_Ay2-W8OszWw2vGtsKfnMY7VjiWQmnjuzLBITk,7020
 workbench/core/transforms/pandas_transforms/__init__.py,sha256=xL4MT8-fZ1SFqDbTLc8XyxjupHtB1YR6Ej0AC2nwd7I,894
 workbench/core/transforms/pandas_transforms/data_to_pandas.py,sha256=sJHPeuNF8Q8aQqgRnkdWkyvur5cbggdUVIwR-xF3Dlo,3621
 workbench/core/transforms/pandas_transforms/features_to_pandas.py,sha256=af6xdPt2V4zhh-SzQa_UYxdmNMzMLXbrbsznV5QoIJg,3441
@@ -123,8 +123,8 @@ workbench/core/views/view.py,sha256=DvmEA1xdvL980GET_cnbmHzqSy6IhlNaZcoQnVTtYis,
 workbench/core/views/view_utils.py,sha256=CwOlpqXpumCr6REi-ey7Qjz5_tpg-s4oWHmlOVu8POQ,12270
 workbench/core/views/storage/mdq_view.py,sha256=qf_ep1KwaXOIfO930laEwNIiCYP7VNOqjE3VdHfopRE,5195
 workbench/model_scripts/script_generation.py,sha256=_AhzM2qzjBuI7pIaXBRZ1YOOs2lwsKQGVM_ovL6T1bo,8135
-workbench/model_scripts/chemprop/chemprop.template,sha256=
-workbench/model_scripts/chemprop/generated_model_script.py,sha256=
+workbench/model_scripts/chemprop/chemprop.template,sha256=XcRBEz_JYS1Vjv9MI_5BalvrWL9v2vTq1eRlVpLAtPE,38883
+workbench/model_scripts/chemprop/generated_model_script.py,sha256=lSr5qHZljCzttxlq4YwypUYmYbIAl7flo5RT8nXt_vs,39755
 workbench/model_scripts/chemprop/requirements.txt,sha256=PIuUdPAeDUH3I2M_5nIrCnCfs3FL1l9V5kzHqgCcu7s,281
 workbench/model_scripts/custom_models/chem_info/Readme.md,sha256=mH1lxJ4Pb7F5nBnVXaiuxpi8zS_yjUw_LBJepVKXhlA,574
 workbench/model_scripts/custom_models/chem_info/fingerprints.py,sha256=Qvs8jaUwguWUq3Q3j695MY0t0Wk3BvroW-oWBwalMUo,5255
@@ -157,7 +157,7 @@ workbench/model_scripts/pytorch_model/requirements.txt,sha256=ICS5nW0wix44EJO2tJ
 workbench/model_scripts/scikit_learn/generated_model_script.py,sha256=xhQIglpAgPRCH9iwI3wI0N0V6p9AgqW0mVOMuSXzUCk,17187
 workbench/model_scripts/scikit_learn/requirements.txt,sha256=aVvwiJ3LgBUhM_PyFlb2gHXu_kpGPho3ANBzlOkfcvs,107
 workbench/model_scripts/scikit_learn/scikit_learn.template,sha256=QQvqx-eX9ZTbYmyupq6R6vIQwosmsmY_MRBPaHyfjdk,12586
-workbench/model_scripts/uq_models/generated_model_script.py,sha256=
+workbench/model_scripts/uq_models/generated_model_script.py,sha256=0HqH1bY3fXgZTQAFLxfnrPfBEQvTmeMus5C2z7HoeyU,26765
 workbench/model_scripts/uq_models/mapie.template,sha256=on3I40D7zyNfvfqBf5k8VXCFtmepcxKmqVWCH5Q9S84,23432
 workbench/model_scripts/uq_models/requirements.txt,sha256=fw7T7t_YJAXK3T6Ysbesxh_Agx_tv0oYx72cEBTqRDY,98
 workbench/model_scripts/xgb_model/generated_model_script.py,sha256=qUGg5R-boaswzXtgKp_J7JPxFzMdRNv51QeF-lMWL-4,19334
@@ -291,9 +291,9 @@ workbench/web_interface/page_views/main_page.py,sha256=X4-KyGTKLAdxR-Zk2niuLJB2Y
 workbench/web_interface/page_views/models_page_view.py,sha256=M0bdC7bAzLyIaE2jviY12FF4abdMFZmg6sFuOY_LaGI,2650
 workbench/web_interface/page_views/page_view.py,sha256=Gh6YnpOGlUejx-bHZAf5pzqoQ1H1R0OSwOpGhOBO06w,455
 workbench/web_interface/page_views/pipelines_page_view.py,sha256=v2pxrIbsHBcYiblfius3JK766NZ7ciD2yPx0t3E5IJo,2656
-workbench-0.8.
-workbench-0.8.
-workbench-0.8.
-workbench-0.8.
-workbench-0.8.
-workbench-0.8.
+workbench-0.8.205.dist-info/licenses/LICENSE,sha256=RTBoTMeEwTgEhS-n8vgQ-VUo5qig0PWVd8xFPKU6Lck,1080
+workbench-0.8.205.dist-info/METADATA,sha256=4fgPE_3_5UQK9Av-WuIaRPZW-nwcIJVekAXYPbyx5hU,10500
+workbench-0.8.205.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+workbench-0.8.205.dist-info/entry_points.txt,sha256=j02NCuno2Y_BuE4jEvw-IL73WZ9lkTpLwom29uKcLCw,458
+workbench-0.8.205.dist-info/top_level.txt,sha256=Dhy72zTxaA_o_yRkPZx5zw-fwumnjGaeGf0hBN3jc_w,10
+workbench-0.8.205.dist-info/RECORD,,
{workbench-0.8.203.dist-info → workbench-0.8.205.dist-info}/WHEEL: file without changes
{workbench-0.8.203.dist-info → workbench-0.8.205.dist-info}/entry_points.txt: file without changes
{workbench-0.8.203.dist-info → workbench-0.8.205.dist-info}/licenses/LICENSE: file without changes
{workbench-0.8.203.dist-info → workbench-0.8.205.dist-info}/top_level.txt: file without changes