workbench 0.8.162__py3-none-any.whl → 0.8.202__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of workbench might be problematic.
- workbench/algorithms/dataframe/__init__.py +1 -2
- workbench/algorithms/dataframe/fingerprint_proximity.py +2 -2
- workbench/algorithms/dataframe/proximity.py +261 -235
- workbench/algorithms/graph/light/proximity_graph.py +10 -8
- workbench/api/__init__.py +2 -1
- workbench/api/compound.py +1 -1
- workbench/api/endpoint.py +11 -0
- workbench/api/feature_set.py +11 -8
- workbench/api/meta.py +5 -2
- workbench/api/model.py +16 -15
- workbench/api/monitor.py +1 -16
- workbench/core/artifacts/__init__.py +11 -2
- workbench/core/artifacts/artifact.py +11 -3
- workbench/core/artifacts/data_capture_core.py +355 -0
- workbench/core/artifacts/endpoint_core.py +256 -118
- workbench/core/artifacts/feature_set_core.py +265 -16
- workbench/core/artifacts/model_core.py +107 -60
- workbench/core/artifacts/monitor_core.py +33 -248
- workbench/core/cloud_platform/aws/aws_account_clamp.py +50 -1
- workbench/core/cloud_platform/aws/aws_meta.py +12 -5
- workbench/core/cloud_platform/aws/aws_parameter_store.py +18 -2
- workbench/core/cloud_platform/aws/aws_session.py +4 -4
- workbench/core/transforms/data_to_features/light/molecular_descriptors.py +4 -4
- workbench/core/transforms/features_to_model/features_to_model.py +42 -32
- workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +36 -6
- workbench/core/transforms/pandas_transforms/pandas_to_features.py +27 -0
- workbench/core/views/training_view.py +113 -42
- workbench/core/views/view.py +53 -3
- workbench/core/views/view_utils.py +4 -4
- workbench/model_scripts/chemprop/chemprop.template +852 -0
- workbench/model_scripts/chemprop/generated_model_script.py +852 -0
- workbench/model_scripts/chemprop/requirements.txt +11 -0
- workbench/model_scripts/custom_models/chem_info/fingerprints.py +134 -0
- workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +483 -0
- workbench/model_scripts/custom_models/chem_info/mol_standardize.py +450 -0
- workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +7 -9
- workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -1
- workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +3 -5
- workbench/model_scripts/custom_models/proximity/proximity.py +261 -235
- workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
- workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +20 -21
- workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
- workbench/model_scripts/custom_models/uq_models/meta_uq.template +166 -62
- workbench/model_scripts/custom_models/uq_models/ngboost.template +30 -18
- workbench/model_scripts/custom_models/uq_models/proximity.py +261 -235
- workbench/model_scripts/custom_models/uq_models/requirements.txt +1 -3
- workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +15 -17
- workbench/model_scripts/pytorch_model/generated_model_script.py +373 -190
- workbench/model_scripts/pytorch_model/pytorch.template +370 -187
- workbench/model_scripts/scikit_learn/generated_model_script.py +7 -12
- workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
- workbench/model_scripts/script_generation.py +17 -9
- workbench/model_scripts/uq_models/generated_model_script.py +605 -0
- workbench/model_scripts/uq_models/mapie.template +605 -0
- workbench/model_scripts/uq_models/requirements.txt +1 -0
- workbench/model_scripts/xgb_model/generated_model_script.py +37 -46
- workbench/model_scripts/xgb_model/xgb_model.template +44 -46
- workbench/repl/workbench_shell.py +28 -14
- workbench/scripts/endpoint_test.py +162 -0
- workbench/scripts/lambda_test.py +73 -0
- workbench/scripts/ml_pipeline_batch.py +137 -0
- workbench/scripts/ml_pipeline_sqs.py +186 -0
- workbench/scripts/monitor_cloud_watch.py +20 -100
- workbench/utils/aws_utils.py +4 -3
- workbench/utils/chem_utils/__init__.py +0 -0
- workbench/utils/chem_utils/fingerprints.py +134 -0
- workbench/utils/chem_utils/misc.py +194 -0
- workbench/utils/chem_utils/mol_descriptors.py +483 -0
- workbench/utils/chem_utils/mol_standardize.py +450 -0
- workbench/utils/chem_utils/mol_tagging.py +348 -0
- workbench/utils/chem_utils/projections.py +209 -0
- workbench/utils/chem_utils/salts.py +256 -0
- workbench/utils/chem_utils/sdf.py +292 -0
- workbench/utils/chem_utils/toxicity.py +250 -0
- workbench/utils/chem_utils/vis.py +253 -0
- workbench/utils/chemprop_utils.py +760 -0
- workbench/utils/cloudwatch_handler.py +1 -1
- workbench/utils/cloudwatch_utils.py +137 -0
- workbench/utils/config_manager.py +3 -7
- workbench/utils/endpoint_utils.py +5 -7
- workbench/utils/license_manager.py +2 -6
- workbench/utils/model_utils.py +95 -34
- workbench/utils/monitor_utils.py +44 -62
- workbench/utils/pandas_utils.py +3 -3
- workbench/utils/pytorch_utils.py +526 -0
- workbench/utils/shap_utils.py +10 -2
- workbench/utils/workbench_logging.py +0 -3
- workbench/utils/workbench_sqs.py +1 -1
- workbench/utils/xgboost_model_utils.py +371 -156
- workbench/web_interface/components/model_plot.py +7 -1
- workbench/web_interface/components/plugin_unit_test.py +5 -2
- workbench/web_interface/components/plugins/dashboard_status.py +3 -1
- workbench/web_interface/components/plugins/generated_compounds.py +1 -1
- workbench/web_interface/components/plugins/model_details.py +9 -7
- workbench/web_interface/components/plugins/scatter_plot.py +3 -3
- {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/METADATA +27 -6
- {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/RECORD +101 -85
- {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/entry_points.txt +4 -0
- {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/licenses/LICENSE +1 -1
- workbench/model_scripts/custom_models/chem_info/local_utils.py +0 -769
- workbench/model_scripts/custom_models/chem_info/tautomerize.py +0 -83
- workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
- workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -393
- workbench/model_scripts/custom_models/uq_models/mapie_xgb.template +0 -203
- workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
- workbench/model_scripts/quant_regression/quant_regression.template +0 -279
- workbench/model_scripts/quant_regression/requirements.txt +0 -1
- workbench/utils/chem_utils.py +0 -1556
- workbench/utils/execution_environment.py +0 -211
- workbench/utils/fast_inference.py +0 -167
- workbench/utils/resource_utils.py +0 -39
- {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/WHEEL +0 -0
- {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/top_level.txt +0 -0
@@ -1,83 +0,0 @@
-# Model: tautomerization_processor
-#
-# Description: The tautomerization_processor model uses RDKit to perform tautomer enumeration
-# and canonicalization of chemical compounds. Tautomerization is the chemical process where
-# compounds can interconvert between structurally distinct forms, often affecting their
-# chemical properties and reactivity. This model provides a robust approach to identifying
-# and processing tautomers, crucial for improving molecular modeling and cheminformatics tasks
-# like virtual screening, QSAR modeling, and property prediction.
-#
-import argparse
-import os
-import joblib
-from io import StringIO
-import pandas as pd
-import json
-
-# Local imports
-from local_utils import tautomerize_smiles
-
-
-# TRAINING SECTION
-#
-# This section (__main__) is where SageMaker will execute the job and save the model artifacts.
-#
-if __name__ == "__main__":
-    # Script arguments for input/output directories
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--model-dir", type=str, default=os.environ.get("SM_MODEL_DIR", "/opt/ml/model"))
-    parser.add_argument("--train", type=str, default=os.environ.get("SM_CHANNEL_TRAIN", "/opt/ml/input/data/train"))
-    parser.add_argument(
-        "--output-data-dir", type=str, default=os.environ.get("SM_OUTPUT_DATA_DIR", "/opt/ml/output/data")
-    )
-    args = parser.parse_args()
-
-    # This model doesn't get trained; it's a feature processing 'model'
-
-    # Sagemaker expects a model artifact, so we'll save a placeholder
-    placeholder_model = {}
-    joblib.dump(placeholder_model, os.path.join(args.model_dir, "model.joblib"))
-
-
-# Model loading and prediction functions
-def model_fn(model_dir):
-    return joblib.load(os.path.join(model_dir, "model.joblib"))
-
-
-def input_fn(input_data, content_type):
-    """Parse input data and return a DataFrame."""
-    if not input_data:
-        raise ValueError("Empty input data is not supported!")
-
-    # Decode bytes to string if necessary
-    if isinstance(input_data, bytes):
-        input_data = input_data.decode("utf-8")
-
-    if "text/csv" in content_type:
-        return pd.read_csv(StringIO(input_data))
-    elif "application/json" in content_type:
-        return pd.DataFrame(json.loads(input_data))  # Assumes JSON array of records
-    else:
-        raise ValueError(f"{content_type} not supported!")
-
-
-def output_fn(output_df, accept_type):
-    """Supports both CSV and JSON output formats."""
-    use_explicit_na = False
-    if "text/csv" in accept_type:
-        if use_explicit_na:
-            csv_output = output_df.fillna("N/A").to_csv(index=False)  # CSV with N/A for missing values
-        else:
-            csv_output = output_df.to_csv(index=False)
-        return csv_output, "text/csv"
-    elif "application/json" in accept_type:
-        return output_df.to_json(orient="records"), "application/json"  # JSON array of records (NaNs -> null)
-    else:
-        raise RuntimeError(f"{accept_type} accept type is not supported by this script.")
-
-
-# Prediction function
-def predict_fn(df, model):
-    # Perform Tautomerization
-    df = tautomerize_smiles(df)
-    return df
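
Note: the deleted tautomerization script above imports tautomerize_smiles from local_utils.py, which is also removed in this release (see custom_models/chem_info/local_utils.py +0 -769 in the file list). As a rough, illustrative sketch only (not the removed Workbench helper), such a function can be built on RDKit's rdMolStandardize.TautomerEnumerator; the "smiles" input column and the added "tautomeric_smiles" output column are assumptions here.

# Illustrative sketch under the assumptions above; not the workbench implementation.
import pandas as pd
from rdkit import Chem
from rdkit.Chem.MolStandardize import rdMolStandardize


def tautomerize_smiles(df: pd.DataFrame, smiles_column: str = "smiles") -> pd.DataFrame:
    """Add a canonical-tautomer SMILES column using RDKit's TautomerEnumerator."""
    enumerator = rdMolStandardize.TautomerEnumerator()

    def canonicalize(smiles: str):
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return None  # unparseable SMILES stays missing
        return Chem.MolToSmiles(enumerator.Canonicalize(mol))

    df = df.copy()
    df["tautomeric_smiles"] = df[smiles_column].apply(canonicalize)
    return df
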
@@ -1,138 +0,0 @@
-# Model: feature_space_proximity
-#
-# Description: The feature_space_proximity model computes nearest neighbors for the given feature space
-#
-
-# Template Placeholders
-TEMPLATE_PARAMS = {
-    "id_column": "udm_mol_bat_id",
-    "features": ['bcut2d_logplow', 'numradicalelectrons', 'smr_vsa5', 'fr_lactam', 'fr_morpholine', 'fr_aldehyde', 'slogp_vsa1', 'fr_amidine', 'bpol', 'fr_ester', 'fr_azo', 'kappa3', 'peoe_vsa5', 'fr_ketone_topliss', 'vsa_estate9', 'estate_vsa9', 'bcut2d_mrhi', 'fr_ndealkylation1', 'numrotatablebonds', 'minestateindex', 'fr_quatn', 'peoe_vsa3', 'fr_epoxide', 'fr_aniline', 'minpartialcharge', 'fr_nitroso', 'fpdensitymorgan2', 'fr_oxime', 'fr_sulfone', 'smr_vsa1', 'kappa1', 'fr_pyridine', 'numaromaticrings', 'vsa_estate6', 'molmr', 'estate_vsa1', 'fr_dihydropyridine', 'vsa_estate10', 'fr_alkyl_halide', 'chi2n', 'fr_thiocyan', 'fpdensitymorgan1', 'fr_unbrch_alkane', 'slogp_vsa9', 'chi4n', 'fr_nitro_arom', 'fr_al_oh', 'fr_furan', 'fr_c_s', 'peoe_vsa8', 'peoe_vsa14', 'numheteroatoms', 'fr_ndealkylation2', 'maxabspartialcharge', 'vsa_estate2', 'peoe_vsa7', 'apol', 'numhacceptors', 'fr_tetrazole', 'vsa_estate1', 'peoe_vsa9', 'naromatom', 'bcut2d_chghi', 'fr_sh', 'fr_halogen', 'slogp_vsa4', 'fr_benzodiazepine', 'molwt', 'fr_isocyan', 'fr_prisulfonamd', 'maxabsestateindex', 'minabsestateindex', 'peoe_vsa11', 'slogp_vsa12', 'estate_vsa5', 'numaliphaticcarbocycles', 'bcut2d_mwlow', 'slogp_vsa7', 'fr_allylic_oxid', 'fr_methoxy', 'fr_nh0', 'fr_coo2', 'fr_phenol', 'nacid', 'nbase', 'chi3v', 'fr_ar_nh', 'fr_nitrile', 'fr_imidazole', 'fr_urea', 'bcut2d_mrlow', 'chi1', 'smr_vsa6', 'fr_aryl_methyl', 'narombond', 'fr_alkyl_carbamate', 'fr_piperzine', 'exactmolwt', 'qed', 'chi0n', 'fr_sulfonamd', 'fr_thiazole', 'numvalenceelectrons', 'fr_phos_acid', 'peoe_vsa12', 'fr_nh1', 'fr_hdrzine', 'fr_c_o_nocoo', 'fr_lactone', 'estate_vsa6', 'bcut2d_logphi', 'vsa_estate7', 'peoe_vsa13', 'numsaturatedcarbocycles', 'fr_nitro', 'fr_phenol_noorthohbond', 'rotratio', 'fr_barbitur', 'fr_isothiocyan', 'balabanj', 'fr_arn', 'fr_imine', 'maxpartialcharge', 'fr_sulfide', 'slogp_vsa11', 'fr_hoccn', 'fr_n_o', 'peoe_vsa1', 'slogp_vsa6', 'heavyatommolwt', 'fractioncsp3', 'estate_vsa8', 'peoe_vsa10', 'numaliphaticrings', 'fr_thiophene', 'maxestateindex', 'smr_vsa10', 'labuteasa', 'smr_vsa2', 'fpdensitymorgan3', 'smr_vsa9', 'slogp_vsa10', 'numaromaticheterocycles', 'fr_nh2', 'fr_diazo', 'chi3n', 'fr_ar_coo', 'slogp_vsa5', 'fr_bicyclic', 'fr_amide', 'estate_vsa10', 'fr_guanido', 'chi1n', 'numsaturatedrings', 'fr_piperdine', 'fr_term_acetylene', 'estate_vsa4', 'slogp_vsa3', 'fr_coo', 'fr_ether', 'estate_vsa7', 'bcut2d_chglo', 'fr_oxazole', 'peoe_vsa6', 'hallkieralpha', 'peoe_vsa2', 'chi2v', 'nocount', 'vsa_estate5', 'fr_nhpyrrole', 'fr_al_coo', 'bertzct', 'estate_vsa11', 'minabspartialcharge', 'slogp_vsa8', 'fr_imide', 'kappa2', 'numaliphaticheterocycles', 'numsaturatedheterocycles', 'fr_hdrzone', 'smr_vsa4', 'fr_ar_n', 'nrot', 'smr_vsa8', 'slogp_vsa2', 'chi4v', 'fr_phos_ester', 'fr_para_hydroxylation', 'smr_vsa3', 'nhohcount', 'estate_vsa2', 'mollogp', 'tpsa', 'fr_azide', 'peoe_vsa4', 'numhdonors', 'fr_al_oh_notert', 'fr_c_o', 'chi0', 'fr_nitro_arom_nonortho', 'vsa_estate3', 'fr_benzene', 'fr_ketone', 'vsa_estate8', 'smr_vsa7', 'fr_ar_oh', 'fr_priamide', 'ringcount', 'estate_vsa3', 'numaromaticcarbocycles', 'bcut2d_mwhi', 'chi1v', 'heavyatomcount', 'vsa_estate4', 'chi0v'],
-    "target": "udm_asy_res_value",
-    "track_columns": None
-}
-
-from io import StringIO
-import json
-import argparse
-import os
-import pandas as pd
-
-# Local Imports
-from proximity import Proximity
-
-
-# Function to check if dataframe is empty
-def check_dataframe(df: pd.DataFrame, df_name: str) -> None:
-    """Check if the DataFrame is empty and raise an error if so."""
-    if df.empty:
-        msg = f"*** The training data {df_name} has 0 rows! ***STOPPING***"
-        print(msg)
-        raise ValueError(msg)
-
-
-# Function to match DataFrame columns to model features (case-insensitive)
-def match_features_case_insensitive(df: pd.DataFrame, model_features: list) -> pd.DataFrame:
-    """Match and rename DataFrame columns to match the model's features, case-insensitively."""
-    # Create a set of exact matches from the DataFrame columns
-    exact_match_set = set(df.columns)
-
-    # Create a case-insensitive map of DataFrame columns
-    column_map = {col.lower(): col for col in df.columns}
-    rename_dict = {}
-
-    # Build a dictionary for renaming columns based on case-insensitive matching
-    for feature in model_features:
-        if feature in exact_match_set:
-            rename_dict[feature] = feature
-        elif feature.lower() in column_map:
-            rename_dict[column_map[feature.lower()]] = feature
-
-    # Rename columns in the DataFrame to match model features
-    return df.rename(columns=rename_dict)
-
-
-# TRAINING SECTION
-#
-# This section (__main__) is where SageMaker will execute the training job
-# and save the model artifacts to the model directory.
-#
-if __name__ == "__main__":
-    # Template Parameters
-    id_column = TEMPLATE_PARAMS["id_column"]
-    features = TEMPLATE_PARAMS["features"]
-    target = TEMPLATE_PARAMS["target"]  # Can be None for unsupervised models
-    track_columns = TEMPLATE_PARAMS["track_columns"]  # Can be None
-
-    # Script arguments for input/output directories
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--model-dir", type=str, default=os.environ.get("SM_MODEL_DIR", "/opt/ml/model"))
-    parser.add_argument("--train", type=str, default=os.environ.get("SM_CHANNEL_TRAIN", "/opt/ml/input/data/train"))
-    parser.add_argument(
-        "--output-data-dir", type=str, default=os.environ.get("SM_OUTPUT_DATA_DIR", "/opt/ml/output/data")
-    )
-    args = parser.parse_args()
-
-    # Load training data from the specified directory
-    training_files = [
-        os.path.join(args.train, file)
-        for file in os.listdir(args.train) if file.endswith(".csv")
-    ]
-    all_df = pd.concat([pd.read_csv(file, engine="python") for file in training_files])
-
-    # Check if the DataFrame is empty
-    check_dataframe(all_df, "training_df")
-
-    # Create the Proximity model
-    model = Proximity(all_df, id_column, features, target, track_columns=track_columns)
-
-    # Now serialize the model
-    model.serialize(args.model_dir)
-
-# Model loading and prediction functions
-def model_fn(model_dir):
-
-    # Deserialize the model
-    model = Proximity.deserialize(model_dir)
-    return model
-
-
-def input_fn(input_data, content_type):
-    """Parse input data and return a DataFrame."""
-    if not input_data:
-        raise ValueError("Empty input data is not supported!")
-
-    # Decode bytes to string if necessary
-    if isinstance(input_data, bytes):
-        input_data = input_data.decode("utf-8")
-
-    if "text/csv" in content_type:
-        return pd.read_csv(StringIO(input_data))
-    elif "application/json" in content_type:
-        return pd.DataFrame(json.loads(input_data))  # Assumes JSON array of records
-    else:
-        raise ValueError(f"{content_type} not supported!")
-
-
-def output_fn(output_df, accept_type):
-    """Supports both CSV and JSON output formats."""
-    use_explicit_na = False
-    if "text/csv" in accept_type:
-        if use_explicit_na:
-            csv_output = output_df.fillna("N/A").to_csv(index=False)  # CSV with N/A for missing values
-        else:
-            csv_output = output_df.to_csv(index=False)
-        return csv_output, "text/csv"
-    elif "application/json" in accept_type:
-        return output_df.to_json(orient="records"), "application/json"  # JSON array of records (NaNs -> null)
-    else:
-        raise RuntimeError(f"{accept_type} accept type is not supported by this script.")
-
-
-# Prediction function
-def predict_fn(df, model):
-    # Match column names before prediction if needed
-    df = match_features_case_insensitive(df, model.features + [model.id_column])
-
-    # Compute Nearest neighbors
-    df = model.neighbors(df)
-    return df
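
Note: the deleted proximity script above delegates the nearest-neighbor work to the Proximity class (workbench/algorithms/dataframe/proximity.py, heavily rewritten in this release). As a rough, illustrative sketch only (not Workbench's actual Proximity API), a neighbors(df) lookup of this kind can be implemented with scikit-learn's NearestNeighbors; the n_neighbors default, the scaling step, and the output column names below are assumptions.

# Illustrative sketch under the assumptions above; not the workbench Proximity class.
import pandas as pd
from sklearn.neighbors import NearestNeighbors
from sklearn.preprocessing import StandardScaler


class SimpleProximity:
    def __init__(self, df: pd.DataFrame, id_column: str, features: list, n_neighbors: int = 10):
        self.id_column = id_column
        self.features = features
        self.ref_df = df.reset_index(drop=True)
        # Standardize features so no single descriptor dominates the distance metric
        self.scaler = StandardScaler().fit(self.ref_df[features])
        self.nn = NearestNeighbors(n_neighbors=n_neighbors).fit(self.scaler.transform(self.ref_df[features]))

    def neighbors(self, query_df: pd.DataFrame) -> pd.DataFrame:
        """Return one row per (query id, neighbor) pair with the neighbor distance."""
        distances, indices = self.nn.kneighbors(self.scaler.transform(query_df[self.features]))
        rows = []
        for q_idx, (dists, idxs) in enumerate(zip(distances, indices)):
            for dist, ref_idx in zip(dists, idxs):
                rows.append(
                    {
                        self.id_column: query_df.iloc[q_idx][self.id_column],
                        "neighbor_id": self.ref_df.iloc[ref_idx][self.id_column],
                        "distance": float(dist),
                    }
                )
        return pd.DataFrame(rows)
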
@@ -1,393 +0,0 @@
-# Model: NGBoost Regressor with Distribution output
-from ngboost import NGBRegressor
-from xgboost import XGBRegressor  # Base Estimator
-from sklearn.model_selection import train_test_split
-import numpy as np
-
-# Model Performance Scores
-from sklearn.metrics import (
-    mean_absolute_error,
-    r2_score,
-    root_mean_squared_error
-)
-
-from io import StringIO
-import json
-import argparse
-import joblib
-import os
-import pandas as pd
-
-# Local Imports
-from proximity import Proximity
-
-
-
-# Template Placeholders
-TEMPLATE_PARAMS = {
-    "id_column": "id",
-    "features": ['molwt', 'mollogp', 'molmr', 'heavyatomcount', 'numhacceptors', 'numhdonors', 'numheteroatoms', 'numrotatablebonds', 'numvalenceelectrons', 'numaromaticrings', 'numsaturatedrings', 'numaliphaticrings', 'ringcount', 'tpsa', 'labuteasa', 'balabanj', 'bertzct'],
-    "target": "solubility",
-    "train_all_data": True,
-    "track_columns": ['solubility']
-}
-
-
-# Function to check if dataframe is empty
-def check_dataframe(df: pd.DataFrame, df_name: str) -> None:
-    """
-    Check if the provided dataframe is empty and raise an exception if it is.
-
-    Args:
-        df (pd.DataFrame): DataFrame to check
-        df_name (str): Name of the DataFrame
-    """
-    if df.empty:
-        msg = f"*** The training data {df_name} has 0 rows! ***STOPPING***"
-        print(msg)
-        raise ValueError(msg)
-
-
-def match_features_case_insensitive(df: pd.DataFrame, model_features: list) -> pd.DataFrame:
-    """
-    Matches and renames DataFrame columns to match model feature names (case-insensitive).
-    Prioritizes exact matches, then case-insensitive matches.
-
-    Raises ValueError if any model features cannot be matched.
-    """
-    df_columns_lower = {col.lower(): col for col in df.columns}
-    rename_dict = {}
-    missing = []
-    for feature in model_features:
-        if feature in df.columns:
-            continue  # Exact match
-        elif feature.lower() in df_columns_lower:
-            rename_dict[df_columns_lower[feature.lower()]] = feature
-        else:
-            missing.append(feature)
-
-    if missing:
-        raise ValueError(f"Features not found: {missing}")
-
-    # Rename the DataFrame columns to match the model features
-    return df.rename(columns=rename_dict)
-
-
-def distance_weighted_calibrated_intervals(
-    df_pred: pd.DataFrame,
-    prox_df: pd.DataFrame,
-    calibration_strength: float = 0.7,
-    distance_decay: float = 3.0,
-) -> pd.DataFrame:
-    """
-    Calibrate intervals using distance-weighted neighbor quantiles.
-    Uses all 10 neighbors with distance-based weighting.
-    """
-    id_column = TEMPLATE_PARAMS["id_column"]
-    target_column = TEMPLATE_PARAMS["target"]
-
-    # Distance-weighted neighbor statistics
-    def weighted_quantile(values, weights, q):
-        """Calculate weighted quantile"""
-        if len(values) == 0:
-            return np.nan
-        sorted_indices = np.argsort(values)
-        sorted_values = values[sorted_indices]
-        sorted_weights = weights[sorted_indices]
-        cumsum = np.cumsum(sorted_weights)
-        cutoff = q * cumsum[-1]
-        return np.interp(cutoff, cumsum, sorted_values)
-
-    # Calculate distance weights (closer neighbors get more weight)
-    prox_df = prox_df.copy()
-    prox_df['weight'] = 1 / (1 + prox_df['distance'] ** distance_decay)
-
-    # Get weighted quantiles and statistics for each ID
-    neighbor_stats = []
-    for id_val, group in prox_df.groupby(id_column):
-        values = group[target_column].values
-        weights = group['weight'].values
-
-        # Normalize weights
-        weights = weights / weights.sum()
-
-        stats = {
-            id_column: id_val,
-            'local_q025': weighted_quantile(values, weights, 0.025),
-            'local_q25': weighted_quantile(values, weights, 0.25),
-            'local_q75': weighted_quantile(values, weights, 0.75),
-            'local_q975': weighted_quantile(values, weights, 0.975),
-            'local_median': weighted_quantile(values, weights, 0.5),
-            'local_std': np.sqrt(np.average((values - np.average(values, weights=weights)) ** 2, weights=weights)),
-            'avg_distance': group['distance'].mean(),
-            'min_distance': group['distance'].min(),
-            'max_distance': group['distance'].max(),
-        }
-        neighbor_stats.append(stats)
-
-    neighbor_df = pd.DataFrame(neighbor_stats)
-    out = df_pred.merge(neighbor_df, on=id_column, how='left')
-
-    # Model disagreement score (normalized by prediction std)
-    model_disagreement = (out["prediction"] - out["prediction_uq"]).abs()
-    disagreement_score = (model_disagreement / out["prediction_std"]).clip(0, 2)
-
-    # Local confidence based on:
-    # 1. How close the neighbors are (closer = more confident)
-    # 2. How much local variance there is (less variance = more confident)
-    max_reasonable_distance = out['max_distance'].quantile(0.8)  # 80th percentile as reference
-    distance_confidence = (1 - (out['avg_distance'] / max_reasonable_distance)).clip(0.1, 1.0)
-
-    variance_confidence = (out["prediction_std"] / out["local_std"]).clip(0.5, 2.0)
-    local_confidence = distance_confidence * variance_confidence.clip(0.5, 1.5)
-
-    # Calibration weight: higher when models disagree and we have good local data
-    calibration_weight = (
-        calibration_strength *
-        local_confidence *  # Weight by local data quality
-        disagreement_score.clip(0.3, 1.0)  # More calibration when models disagree
-    )
-
-    # Consensus prediction (slight preference for NGBoost since it provides intervals)
-    consensus_pred = 0.65 * out["prediction_uq"] + 0.35 * out["prediction"]
-
-    # Re-center local intervals around consensus prediction
-    local_center_offset = consensus_pred - out["local_median"]
-
-    # Apply calibration to each quantile
-    quantile_pairs = [
-        ("q_025", "local_q025"),
-        ("q_25", "local_q25"),
-        ("q_75", "local_q75"),
-        ("q_975", "local_q975")
-    ]
-
-    for model_q, local_q in quantile_pairs:
-        # Adjust local quantiles to be centered around consensus
-        adjusted_local_q = out[local_q] + local_center_offset
-
-        # Blend model and local intervals
-        out[model_q] = (
-            (1 - calibration_weight) * out[model_q] +
-            calibration_weight * adjusted_local_q
-        )
-
-    # Ensure proper interval ordering and bounds using pandas
-    out["q_025"] = pd.concat([out["q_025"], consensus_pred], axis=1).min(axis=1)
-    out["q_975"] = pd.concat([out["q_975"], consensus_pred], axis=1).max(axis=1)
-    out["q_25"] = pd.concat([out["q_25"], out["q_75"]], axis=1).min(axis=1)
-
-    # Optional: Add some interval expansion when neighbors are very far
-    # (indicates we're in a sparse region of feature space)
-    sparse_region_mask = out['min_distance'] > out['min_distance'].quantile(0.9)
-    expansion_factor = 1 + 0.2 * sparse_region_mask  # 20% expansion in sparse regions
-
-    for q in ["q_025", "q_25", "q_75", "q_975"]:
-        interval_width = out[q] - consensus_pred
-        out[q] = consensus_pred + interval_width * expansion_factor
-
-    # Clean up temporary columns
-    cleanup_cols = [col for col in out.columns if col.startswith("local_")] + \
-                   ['avg_distance', 'min_distance', 'max_distance']
-
-    return out.drop(columns=cleanup_cols)
-
-
-# TRAINING SECTION
-#
-# This section (__main__) is where SageMaker will execute the training job
-# and save the model artifacts to the model directory.
-#
-if __name__ == "__main__":
-    # Template Parameters
-    id_column = TEMPLATE_PARAMS["id_column"]
-    features = TEMPLATE_PARAMS["features"]
-    target = TEMPLATE_PARAMS["target"]
-    train_all_data = TEMPLATE_PARAMS["train_all_data"]
-    track_columns = TEMPLATE_PARAMS["track_columns"]  # Can be None
-    validation_split = 0.2
-
-    # Script arguments for input/output directories
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--model-dir", type=str, default=os.environ.get("SM_MODEL_DIR", "/opt/ml/model"))
-    parser.add_argument("--train", type=str, default=os.environ.get("SM_CHANNEL_TRAIN", "/opt/ml/input/data/train"))
-    parser.add_argument(
-        "--output-data-dir", type=str, default=os.environ.get("SM_OUTPUT_DATA_DIR", "/opt/ml/output/data")
-    )
-    args = parser.parse_args()
-
-    # Load training data from the specified directory
-    training_files = [
-        os.path.join(args.train, file)
-        for file in os.listdir(args.train) if file.endswith(".csv")
-    ]
-    print(f"Training Files: {training_files}")
-
-    # Combine files and read them all into a single pandas dataframe
-    df = pd.concat([pd.read_csv(file, engine="python") for file in training_files])
-
-    # Check if the DataFrame is empty
-    check_dataframe(df, "training_df")
-
-    # Training data split logic
-    if train_all_data:
-        # Use all data for both training and validation
-        print("Training on all data...")
-        df_train = df.copy()
-        df_val = df.copy()
-    elif "training" in df.columns:
-        # Split data based on a 'training' column if it exists
-        print("Splitting data based on 'training' column...")
-        df_train = df[df["training"]].copy()
-        df_val = df[~df["training"]].copy()
-    else:
-        # Perform a random split if no 'training' column is found
-        print("Splitting data randomly...")
-        df_train, df_val = train_test_split(df, test_size=validation_split, random_state=42)
-
-    # We're using XGBoost for point predictions and NGBoost for uncertainty quantification
-    xgb_model = XGBRegressor()
-    ngb_model = NGBRegressor()
-
-    # Prepare features and targets for training
-    X_train = df_train[features]
-    X_val = df_val[features]
-    y_train = df_train[target]
-    y_val = df_val[target]
-
-    # Train both models using the training data
-    xgb_model.fit(X_train, y_train)
-    ngb_model.fit(X_train, y_train, X_val=X_val, Y_val=y_val)
-
-    # Make Predictions on the Validation Set
-    print(f"Making Predictions on Validation Set...")
-    y_validate = df_val[target]
-    X_validate = df_val[features]
-    preds = xgb_model.predict(X_validate)
-
-    # Calculate various model performance metrics (regression)
-    rmse = root_mean_squared_error(y_validate, preds)
-    mae = mean_absolute_error(y_validate, preds)
-    r2 = r2_score(y_validate, preds)
-    print(f"RMSE: {rmse:.3f}")
-    print(f"MAE: {mae:.3f}")
-    print(f"R2: {r2:.3f}")
-    print(f"NumRows: {len(df_val)}")
-
-    # Save the trained XGBoost model
-    xgb_model.save_model(os.path.join(args.model_dir, "xgb_model.json"))
-
-    # Save the trained NGBoost model
-    joblib.dump(ngb_model, os.path.join(args.model_dir, "ngb_model.joblib"))
-
-    # Save the feature list to validate input during predictions
-    with open(os.path.join(args.model_dir, "feature_columns.json"), "w") as fp:
-        json.dump(features, fp)
-
-    # Now the Proximity model
-    model = Proximity(df_train, id_column, features, target, track_columns=track_columns)
-
-    # Now serialize the model
-    model.serialize(args.model_dir)
-
-
-#
-# Inference Section
-#
-def model_fn(model_dir) -> dict:
-    """Load and return XGBoost and NGBoost regressors from model directory."""
-
-    # Load XGBoost regressor
-    xgb_path = os.path.join(model_dir, "xgb_model.json")
-    xgb_model = XGBRegressor(enable_categorical=True)
-    xgb_model.load_model(xgb_path)
-
-    # Load NGBoost regressor
-    ngb_model = joblib.load(os.path.join(model_dir, "ngb_model.joblib"))
-
-    # Deserialize the proximity model
-    prox_model = Proximity.deserialize(model_dir)
-
-    return {
-        "xgboost": xgb_model,
-        "ngboost": ngb_model,
-        "proximity": prox_model
-    }
-
-
-def input_fn(input_data, content_type):
-    """Parse input data and return a DataFrame."""
-    if not input_data:
-        raise ValueError("Empty input data is not supported!")
-
-    # Decode bytes to string if necessary
-    if isinstance(input_data, bytes):
-        input_data = input_data.decode("utf-8")
-
-    if "text/csv" in content_type:
-        return pd.read_csv(StringIO(input_data))
-    elif "application/json" in content_type:
-        return pd.DataFrame(json.loads(input_data))  # Assumes JSON array of records
-    else:
-        raise ValueError(f"{content_type} not supported!")
-
-
-def output_fn(output_df, accept_type):
-    """Supports both CSV and JSON output formats."""
-    if "text/csv" in accept_type:
-        csv_output = output_df.fillna("N/A").to_csv(index=False)  # CSV with N/A for missing values
-        return csv_output, "text/csv"
-    elif "application/json" in accept_type:
-        return output_df.to_json(orient="records"), "application/json"  # JSON array of records (NaNs -> null)
-    else:
-        raise RuntimeError(f"{accept_type} accept type is not supported by this script.")
-
-
-def predict_fn(df, models) -> pd.DataFrame:
-    """Make Predictions with our XGB Quantile Regression Model
-
-    Args:
-        df (pd.DataFrame): The input DataFrame
-        models (dict): The dictionary of models to use for predictions
-
-    Returns:
-        pd.DataFrame: The DataFrame with the predictions added
-    """
-
-    # Grab our feature columns (from training)
-    model_dir = os.environ.get("SM_MODEL_DIR", "/opt/ml/model")
-    with open(os.path.join(model_dir, "feature_columns.json")) as fp:
-        model_features = json.load(fp)
-
-    # Match features in a case-insensitive manner
-    matched_df = match_features_case_insensitive(df, model_features)
-
-    # Use XGBoost for point predictions
-    df["prediction"] = models["xgboost"].predict(matched_df[model_features])
-
-    # NGBoost predict returns distribution objects
-    y_dists = models["ngboost"].pred_dist(matched_df[model_features])
-
-    # Extract parameters from distribution
-    dist_params = y_dists.params
-
-    # Extract mean and std from distribution parameters
-    df["prediction_uq"] = dist_params['loc']  # mean
-    df["prediction_std"] = dist_params['scale']  # standard deviation
-
-    # Add 95% prediction intervals using ppf (percent point function)
-    df["q_025"] = y_dists.ppf(0.025)  # 2.5th percentile
-    df["q_975"] = y_dists.ppf(0.975)  # 97.5th percentile
-
-    # Add 50% prediction intervals
-    df["q_25"] = y_dists.ppf(0.25)  # 25th percentile
-    df["q_75"] = y_dists.ppf(0.75)  # 75th percentile
-
-    # Compute Nearest neighbors with Proximity model
-    prox_df = models["proximity"].neighbors(df)
-
-    # Shrink prediction intervals based on KNN variance
-    df = distance_weighted_calibrated_intervals(df, prox_df)
-
-    # Return the modified DataFrame
-    return df
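
Note: the interval blending inside distance_weighted_calibrated_intervals above follows q_calibrated = (1 - w) * q_model + w * (q_local + offset), where offset re-centers the neighbor quantile around the consensus prediction. A toy numeric check with made-up values:

# Toy check of the interval-blending rule used above (illustrative numbers only)
q_model = 2.0    # e.g. an NGBoost 97.5th-percentile bound
q_local = 1.2    # distance-weighted neighbor quantile
offset = 0.3     # consensus_pred - local_median
w = 0.5          # calibration_weight
q_calibrated = (1 - w) * q_model + w * (q_local + offset)
print(q_calibrated)  # 1.75: the bound is pulled toward the local, neighbor-based estimate
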