workbench-0.8.162-py3-none-any.whl → workbench-0.8.202-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of workbench might be problematic.

Files changed (113)
  1. workbench/algorithms/dataframe/__init__.py +1 -2
  2. workbench/algorithms/dataframe/fingerprint_proximity.py +2 -2
  3. workbench/algorithms/dataframe/proximity.py +261 -235
  4. workbench/algorithms/graph/light/proximity_graph.py +10 -8
  5. workbench/api/__init__.py +2 -1
  6. workbench/api/compound.py +1 -1
  7. workbench/api/endpoint.py +11 -0
  8. workbench/api/feature_set.py +11 -8
  9. workbench/api/meta.py +5 -2
  10. workbench/api/model.py +16 -15
  11. workbench/api/monitor.py +1 -16
  12. workbench/core/artifacts/__init__.py +11 -2
  13. workbench/core/artifacts/artifact.py +11 -3
  14. workbench/core/artifacts/data_capture_core.py +355 -0
  15. workbench/core/artifacts/endpoint_core.py +256 -118
  16. workbench/core/artifacts/feature_set_core.py +265 -16
  17. workbench/core/artifacts/model_core.py +107 -60
  18. workbench/core/artifacts/monitor_core.py +33 -248
  19. workbench/core/cloud_platform/aws/aws_account_clamp.py +50 -1
  20. workbench/core/cloud_platform/aws/aws_meta.py +12 -5
  21. workbench/core/cloud_platform/aws/aws_parameter_store.py +18 -2
  22. workbench/core/cloud_platform/aws/aws_session.py +4 -4
  23. workbench/core/transforms/data_to_features/light/molecular_descriptors.py +4 -4
  24. workbench/core/transforms/features_to_model/features_to_model.py +42 -32
  25. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +36 -6
  26. workbench/core/transforms/pandas_transforms/pandas_to_features.py +27 -0
  27. workbench/core/views/training_view.py +113 -42
  28. workbench/core/views/view.py +53 -3
  29. workbench/core/views/view_utils.py +4 -4
  30. workbench/model_scripts/chemprop/chemprop.template +852 -0
  31. workbench/model_scripts/chemprop/generated_model_script.py +852 -0
  32. workbench/model_scripts/chemprop/requirements.txt +11 -0
  33. workbench/model_scripts/custom_models/chem_info/fingerprints.py +134 -0
  34. workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +483 -0
  35. workbench/model_scripts/custom_models/chem_info/mol_standardize.py +450 -0
  36. workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +7 -9
  37. workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -1
  38. workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +3 -5
  39. workbench/model_scripts/custom_models/proximity/proximity.py +261 -235
  40. workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
  41. workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +20 -21
  42. workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
  43. workbench/model_scripts/custom_models/uq_models/meta_uq.template +166 -62
  44. workbench/model_scripts/custom_models/uq_models/ngboost.template +30 -18
  45. workbench/model_scripts/custom_models/uq_models/proximity.py +261 -235
  46. workbench/model_scripts/custom_models/uq_models/requirements.txt +1 -3
  47. workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +15 -17
  48. workbench/model_scripts/pytorch_model/generated_model_script.py +373 -190
  49. workbench/model_scripts/pytorch_model/pytorch.template +370 -187
  50. workbench/model_scripts/scikit_learn/generated_model_script.py +7 -12
  51. workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
  52. workbench/model_scripts/script_generation.py +17 -9
  53. workbench/model_scripts/uq_models/generated_model_script.py +605 -0
  54. workbench/model_scripts/uq_models/mapie.template +605 -0
  55. workbench/model_scripts/uq_models/requirements.txt +1 -0
  56. workbench/model_scripts/xgb_model/generated_model_script.py +37 -46
  57. workbench/model_scripts/xgb_model/xgb_model.template +44 -46
  58. workbench/repl/workbench_shell.py +28 -14
  59. workbench/scripts/endpoint_test.py +162 -0
  60. workbench/scripts/lambda_test.py +73 -0
  61. workbench/scripts/ml_pipeline_batch.py +137 -0
  62. workbench/scripts/ml_pipeline_sqs.py +186 -0
  63. workbench/scripts/monitor_cloud_watch.py +20 -100
  64. workbench/utils/aws_utils.py +4 -3
  65. workbench/utils/chem_utils/__init__.py +0 -0
  66. workbench/utils/chem_utils/fingerprints.py +134 -0
  67. workbench/utils/chem_utils/misc.py +194 -0
  68. workbench/utils/chem_utils/mol_descriptors.py +483 -0
  69. workbench/utils/chem_utils/mol_standardize.py +450 -0
  70. workbench/utils/chem_utils/mol_tagging.py +348 -0
  71. workbench/utils/chem_utils/projections.py +209 -0
  72. workbench/utils/chem_utils/salts.py +256 -0
  73. workbench/utils/chem_utils/sdf.py +292 -0
  74. workbench/utils/chem_utils/toxicity.py +250 -0
  75. workbench/utils/chem_utils/vis.py +253 -0
  76. workbench/utils/chemprop_utils.py +760 -0
  77. workbench/utils/cloudwatch_handler.py +1 -1
  78. workbench/utils/cloudwatch_utils.py +137 -0
  79. workbench/utils/config_manager.py +3 -7
  80. workbench/utils/endpoint_utils.py +5 -7
  81. workbench/utils/license_manager.py +2 -6
  82. workbench/utils/model_utils.py +95 -34
  83. workbench/utils/monitor_utils.py +44 -62
  84. workbench/utils/pandas_utils.py +3 -3
  85. workbench/utils/pytorch_utils.py +526 -0
  86. workbench/utils/shap_utils.py +10 -2
  87. workbench/utils/workbench_logging.py +0 -3
  88. workbench/utils/workbench_sqs.py +1 -1
  89. workbench/utils/xgboost_model_utils.py +371 -156
  90. workbench/web_interface/components/model_plot.py +7 -1
  91. workbench/web_interface/components/plugin_unit_test.py +5 -2
  92. workbench/web_interface/components/plugins/dashboard_status.py +3 -1
  93. workbench/web_interface/components/plugins/generated_compounds.py +1 -1
  94. workbench/web_interface/components/plugins/model_details.py +9 -7
  95. workbench/web_interface/components/plugins/scatter_plot.py +3 -3
  96. {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/METADATA +27 -6
  97. {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/RECORD +101 -85
  98. {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/entry_points.txt +4 -0
  99. {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/licenses/LICENSE +1 -1
  100. workbench/model_scripts/custom_models/chem_info/local_utils.py +0 -769
  101. workbench/model_scripts/custom_models/chem_info/tautomerize.py +0 -83
  102. workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
  103. workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -393
  104. workbench/model_scripts/custom_models/uq_models/mapie_xgb.template +0 -203
  105. workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
  106. workbench/model_scripts/quant_regression/quant_regression.template +0 -279
  107. workbench/model_scripts/quant_regression/requirements.txt +0 -1
  108. workbench/utils/chem_utils.py +0 -1556
  109. workbench/utils/execution_environment.py +0 -211
  110. workbench/utils/fast_inference.py +0 -167
  111. workbench/utils/resource_utils.py +0 -39
  112. {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/WHEEL +0 -0
  113. {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/top_level.txt +0 -0
@@ -1,83 +0,0 @@
- # Model: tautomerization_processor
- #
- # Description: The tautomerization_processor model uses RDKit to perform tautomer enumeration
- # and canonicalization of chemical compounds. Tautomerization is the chemical process where
- # compounds can interconvert between structurally distinct forms, often affecting their
- # chemical properties and reactivity. This model provides a robust approach to identifying
- # and processing tautomers, crucial for improving molecular modeling and cheminformatics tasks
- # like virtual screening, QSAR modeling, and property prediction.
- #
- import argparse
- import os
- import joblib
- from io import StringIO
- import pandas as pd
- import json
-
- # Local imports
- from local_utils import tautomerize_smiles
-
-
- # TRAINING SECTION
- #
- # This section (__main__) is where SageMaker will execute the job and save the model artifacts.
- #
- if __name__ == "__main__":
-     # Script arguments for input/output directories
-     parser = argparse.ArgumentParser()
-     parser.add_argument("--model-dir", type=str, default=os.environ.get("SM_MODEL_DIR", "/opt/ml/model"))
-     parser.add_argument("--train", type=str, default=os.environ.get("SM_CHANNEL_TRAIN", "/opt/ml/input/data/train"))
-     parser.add_argument(
-         "--output-data-dir", type=str, default=os.environ.get("SM_OUTPUT_DATA_DIR", "/opt/ml/output/data")
-     )
-     args = parser.parse_args()
-
-     # This model doesn't get trained; it's a feature processing 'model'
-
-     # Sagemaker expects a model artifact, so we'll save a placeholder
-     placeholder_model = {}
-     joblib.dump(placeholder_model, os.path.join(args.model_dir, "model.joblib"))
-
-
- # Model loading and prediction functions
- def model_fn(model_dir):
-     return joblib.load(os.path.join(model_dir, "model.joblib"))
-
-
- def input_fn(input_data, content_type):
-     """Parse input data and return a DataFrame."""
-     if not input_data:
-         raise ValueError("Empty input data is not supported!")
-
-     # Decode bytes to string if necessary
-     if isinstance(input_data, bytes):
-         input_data = input_data.decode("utf-8")
-
-     if "text/csv" in content_type:
-         return pd.read_csv(StringIO(input_data))
-     elif "application/json" in content_type:
-         return pd.DataFrame(json.loads(input_data)) # Assumes JSON array of records
-     else:
-         raise ValueError(f"{content_type} not supported!")
-
-
- def output_fn(output_df, accept_type):
-     """Supports both CSV and JSON output formats."""
-     use_explicit_na = False
-     if "text/csv" in accept_type:
-         if use_explicit_na:
-             csv_output = output_df.fillna("N/A").to_csv(index=False) # CSV with N/A for missing values
-         else:
-             csv_output = output_df.to_csv(index=False)
-         return csv_output, "text/csv"
-     elif "application/json" in accept_type:
-         return output_df.to_json(orient="records"), "application/json" # JSON array of records (NaNs -> null)
-     else:
-         raise RuntimeError(f"{accept_type} accept type is not supported by this script.")
-
-
- # Prediction function
- def predict_fn(df, model):
-     # Perform Tautomerization
-     df = tautomerize_smiles(df)
-     return df
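Note: the removed tautomerize.py above is a thin SageMaker serving wrapper; the actual chemistry lives in tautomerize_smiles from local_utils.py (item 100 in the file list, also removed) and is not shown in this diff. For orientation only, here is a minimal standalone sketch (not part of the diff) of what such a helper could look like using RDKit's rdMolStandardize. The function name, the smiles/canonical_smiles column names, and the pass-through handling of unparseable SMILES are assumptions, not the removed implementation.

import pandas as pd
from rdkit import Chem
from rdkit.Chem.MolStandardize import rdMolStandardize

def tautomerize_smiles_sketch(df: pd.DataFrame, smiles_column: str = "smiles") -> pd.DataFrame:
    """Hypothetical stand-in for local_utils.tautomerize_smiles: canonicalize tautomers with RDKit."""
    enumerator = rdMolStandardize.TautomerEnumerator()
    canonical = []
    for smiles in df[smiles_column]:
        mol = Chem.MolFromSmiles(smiles)
        # Unparseable SMILES are passed through unchanged (an assumption, not the removed behavior)
        canonical.append(Chem.MolToSmiles(enumerator.Canonicalize(mol)) if mol else smiles)
    out = df.copy()
    out["canonical_smiles"] = canonical  # column name is illustrative
    return out

# Toy input: aspirin (already canonical) and prop-1-en-2-ol (canonicalizes to acetone)
print(tautomerize_smiles_sketch(pd.DataFrame({"smiles": ["CC(=O)Oc1ccccc1C(=O)O", "CC(O)=C"]})))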
@@ -1,138 +0,0 @@
- # Model: feature_space_proximity
- #
- # Description: The feature_space_proximity model computes nearest neighbors for the given feature space
- #
-
- # Template Placeholders
- TEMPLATE_PARAMS = {
- "id_column": "udm_mol_bat_id",
- "features": ['bcut2d_logplow', 'numradicalelectrons', 'smr_vsa5', 'fr_lactam', 'fr_morpholine', 'fr_aldehyde', 'slogp_vsa1', 'fr_amidine', 'bpol', 'fr_ester', 'fr_azo', 'kappa3', 'peoe_vsa5', 'fr_ketone_topliss', 'vsa_estate9', 'estate_vsa9', 'bcut2d_mrhi', 'fr_ndealkylation1', 'numrotatablebonds', 'minestateindex', 'fr_quatn', 'peoe_vsa3', 'fr_epoxide', 'fr_aniline', 'minpartialcharge', 'fr_nitroso', 'fpdensitymorgan2', 'fr_oxime', 'fr_sulfone', 'smr_vsa1', 'kappa1', 'fr_pyridine', 'numaromaticrings', 'vsa_estate6', 'molmr', 'estate_vsa1', 'fr_dihydropyridine', 'vsa_estate10', 'fr_alkyl_halide', 'chi2n', 'fr_thiocyan', 'fpdensitymorgan1', 'fr_unbrch_alkane', 'slogp_vsa9', 'chi4n', 'fr_nitro_arom', 'fr_al_oh', 'fr_furan', 'fr_c_s', 'peoe_vsa8', 'peoe_vsa14', 'numheteroatoms', 'fr_ndealkylation2', 'maxabspartialcharge', 'vsa_estate2', 'peoe_vsa7', 'apol', 'numhacceptors', 'fr_tetrazole', 'vsa_estate1', 'peoe_vsa9', 'naromatom', 'bcut2d_chghi', 'fr_sh', 'fr_halogen', 'slogp_vsa4', 'fr_benzodiazepine', 'molwt', 'fr_isocyan', 'fr_prisulfonamd', 'maxabsestateindex', 'minabsestateindex', 'peoe_vsa11', 'slogp_vsa12', 'estate_vsa5', 'numaliphaticcarbocycles', 'bcut2d_mwlow', 'slogp_vsa7', 'fr_allylic_oxid', 'fr_methoxy', 'fr_nh0', 'fr_coo2', 'fr_phenol', 'nacid', 'nbase', 'chi3v', 'fr_ar_nh', 'fr_nitrile', 'fr_imidazole', 'fr_urea', 'bcut2d_mrlow', 'chi1', 'smr_vsa6', 'fr_aryl_methyl', 'narombond', 'fr_alkyl_carbamate', 'fr_piperzine', 'exactmolwt', 'qed', 'chi0n', 'fr_sulfonamd', 'fr_thiazole', 'numvalenceelectrons', 'fr_phos_acid', 'peoe_vsa12', 'fr_nh1', 'fr_hdrzine', 'fr_c_o_nocoo', 'fr_lactone', 'estate_vsa6', 'bcut2d_logphi', 'vsa_estate7', 'peoe_vsa13', 'numsaturatedcarbocycles', 'fr_nitro', 'fr_phenol_noorthohbond', 'rotratio', 'fr_barbitur', 'fr_isothiocyan', 'balabanj', 'fr_arn', 'fr_imine', 'maxpartialcharge', 'fr_sulfide', 'slogp_vsa11', 'fr_hoccn', 'fr_n_o', 'peoe_vsa1', 'slogp_vsa6', 'heavyatommolwt', 'fractioncsp3', 'estate_vsa8', 'peoe_vsa10', 'numaliphaticrings', 'fr_thiophene', 'maxestateindex', 'smr_vsa10', 'labuteasa', 'smr_vsa2', 'fpdensitymorgan3', 'smr_vsa9', 'slogp_vsa10', 'numaromaticheterocycles', 'fr_nh2', 'fr_diazo', 'chi3n', 'fr_ar_coo', 'slogp_vsa5', 'fr_bicyclic', 'fr_amide', 'estate_vsa10', 'fr_guanido', 'chi1n', 'numsaturatedrings', 'fr_piperdine', 'fr_term_acetylene', 'estate_vsa4', 'slogp_vsa3', 'fr_coo', 'fr_ether', 'estate_vsa7', 'bcut2d_chglo', 'fr_oxazole', 'peoe_vsa6', 'hallkieralpha', 'peoe_vsa2', 'chi2v', 'nocount', 'vsa_estate5', 'fr_nhpyrrole', 'fr_al_coo', 'bertzct', 'estate_vsa11', 'minabspartialcharge', 'slogp_vsa8', 'fr_imide', 'kappa2', 'numaliphaticheterocycles', 'numsaturatedheterocycles', 'fr_hdrzone', 'smr_vsa4', 'fr_ar_n', 'nrot', 'smr_vsa8', 'slogp_vsa2', 'chi4v', 'fr_phos_ester', 'fr_para_hydroxylation', 'smr_vsa3', 'nhohcount', 'estate_vsa2', 'mollogp', 'tpsa', 'fr_azide', 'peoe_vsa4', 'numhdonors', 'fr_al_oh_notert', 'fr_c_o', 'chi0', 'fr_nitro_arom_nonortho', 'vsa_estate3', 'fr_benzene', 'fr_ketone', 'vsa_estate8', 'smr_vsa7', 'fr_ar_oh', 'fr_priamide', 'ringcount', 'estate_vsa3', 'numaromaticcarbocycles', 'bcut2d_mwhi', 'chi1v', 'heavyatomcount', 'vsa_estate4', 'chi0v'],
- "target": "udm_asy_res_value",
- "track_columns": None
- }
-
- from io import StringIO
- import json
- import argparse
- import os
- import pandas as pd
-
- # Local Imports
- from proximity import Proximity
-
-
- # Function to check if dataframe is empty
- def check_dataframe(df: pd.DataFrame, df_name: str) -> None:
-     """Check if the DataFrame is empty and raise an error if so."""
-     if df.empty:
-         msg = f"*** The training data {df_name} has 0 rows! ***STOPPING***"
-         print(msg)
-         raise ValueError(msg)
-
-
- # Function to match DataFrame columns to model features (case-insensitive)
- def match_features_case_insensitive(df: pd.DataFrame, model_features: list) -> pd.DataFrame:
-     """Match and rename DataFrame columns to match the model's features, case-insensitively."""
-     # Create a set of exact matches from the DataFrame columns
-     exact_match_set = set(df.columns)
-
-     # Create a case-insensitive map of DataFrame columns
-     column_map = {col.lower(): col for col in df.columns}
-     rename_dict = {}
-
-     # Build a dictionary for renaming columns based on case-insensitive matching
-     for feature in model_features:
-         if feature in exact_match_set:
-             rename_dict[feature] = feature
-         elif feature.lower() in column_map:
-             rename_dict[column_map[feature.lower()]] = feature
-
-     # Rename columns in the DataFrame to match model features
-     return df.rename(columns=rename_dict)
-
-
- # TRAINING SECTION
- #
- # This section (__main__) is where SageMaker will execute the training job
- # and save the model artifacts to the model directory.
- #
- if __name__ == "__main__":
-     # Template Parameters
-     id_column = TEMPLATE_PARAMS["id_column"]
-     features = TEMPLATE_PARAMS["features"]
-     target = TEMPLATE_PARAMS["target"] # Can be None for unsupervised models
-     track_columns = TEMPLATE_PARAMS["track_columns"] # Can be None
-
-     # Script arguments for input/output directories
-     parser = argparse.ArgumentParser()
-     parser.add_argument("--model-dir", type=str, default=os.environ.get("SM_MODEL_DIR", "/opt/ml/model"))
-     parser.add_argument("--train", type=str, default=os.environ.get("SM_CHANNEL_TRAIN", "/opt/ml/input/data/train"))
-     parser.add_argument(
-         "--output-data-dir", type=str, default=os.environ.get("SM_OUTPUT_DATA_DIR", "/opt/ml/output/data")
-     )
-     args = parser.parse_args()
-
-     # Load training data from the specified directory
-     training_files = [
-         os.path.join(args.train, file)
-         for file in os.listdir(args.train) if file.endswith(".csv")
-     ]
-     all_df = pd.concat([pd.read_csv(file, engine="python") for file in training_files])
-
-     # Check if the DataFrame is empty
-     check_dataframe(all_df, "training_df")
-
-     # Create the Proximity model
-     model = Proximity(all_df, id_column, features, target, track_columns=track_columns)
-
-     # Now serialize the model
-     model.serialize(args.model_dir)
-
- # Model loading and prediction functions
- def model_fn(model_dir):
-
-     # Deserialize the model
-     model = Proximity.deserialize(model_dir)
-     return model
-
-
- def input_fn(input_data, content_type):
-     """Parse input data and return a DataFrame."""
-     if not input_data:
-         raise ValueError("Empty input data is not supported!")
-
-     # Decode bytes to string if necessary
-     if isinstance(input_data, bytes):
-         input_data = input_data.decode("utf-8")
-
-     if "text/csv" in content_type:
-         return pd.read_csv(StringIO(input_data))
-     elif "application/json" in content_type:
-         return pd.DataFrame(json.loads(input_data)) # Assumes JSON array of records
-     else:
-         raise ValueError(f"{content_type} not supported!")
-
-
- def output_fn(output_df, accept_type):
-     """Supports both CSV and JSON output formats."""
-     use_explicit_na = False
-     if "text/csv" in accept_type:
-         if use_explicit_na:
-             csv_output = output_df.fillna("N/A").to_csv(index=False) # CSV with N/A for missing values
-         else:
-             csv_output = output_df.to_csv(index=False)
-         return csv_output, "text/csv"
-     elif "application/json" in accept_type:
-         return output_df.to_json(orient="records"), "application/json" # JSON array of records (NaNs -> null)
-     else:
-         raise RuntimeError(f"{accept_type} accept type is not supported by this script.")
-
-
- # Prediction function
- def predict_fn(df, model):
-     # Match column names before prediction if needed
-     df = match_features_case_insensitive(df, model.features + [model.id_column])
-
-     # Compute Nearest neighbors
-     df = model.neighbors(df)
-     return df
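Note: the case-insensitive feature matching used by the removed script above is easy to exercise in isolation. The following standalone snippet (not part of the diff) repeats the same renaming logic on a toy DataFrame; the column names and values are made up for illustration.

import pandas as pd

def match_features_case_insensitive(df: pd.DataFrame, model_features: list) -> pd.DataFrame:
    """Same renaming logic as the removed script: map incoming columns to the model's casing."""
    exact_match_set = set(df.columns)
    column_map = {col.lower(): col for col in df.columns}
    rename_dict = {}
    for feature in model_features:
        if feature in exact_match_set:
            rename_dict[feature] = feature
        elif feature.lower() in column_map:
            rename_dict[column_map[feature.lower()]] = feature
    return df.rename(columns=rename_dict)

# Columns arrive as "MolWt"/"TPSA", but the model was trained on "molwt"/"tpsa"
df = pd.DataFrame({"MolWt": [180.2], "TPSA": [63.6], "udm_mol_bat_id": ["mol-001"]})
matched = match_features_case_insensitive(df, ["molwt", "tpsa", "udm_mol_bat_id"])
print(list(matched.columns))  # ['molwt', 'tpsa', 'udm_mol_bat_id']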
@@ -1,393 +0,0 @@
- # Model: NGBoost Regressor with Distribution output
- from ngboost import NGBRegressor
- from xgboost import XGBRegressor # Base Estimator
- from sklearn.model_selection import train_test_split
- import numpy as np
-
- # Model Performance Scores
- from sklearn.metrics import (
-     mean_absolute_error,
-     r2_score,
-     root_mean_squared_error
- )
-
- from io import StringIO
- import json
- import argparse
- import joblib
- import os
- import pandas as pd
-
- # Local Imports
- from proximity import Proximity
-
-
-
- # Template Placeholders
- TEMPLATE_PARAMS = {
-     "id_column": "id",
-     "features": ['molwt', 'mollogp', 'molmr', 'heavyatomcount', 'numhacceptors', 'numhdonors', 'numheteroatoms', 'numrotatablebonds', 'numvalenceelectrons', 'numaromaticrings', 'numsaturatedrings', 'numaliphaticrings', 'ringcount', 'tpsa', 'labuteasa', 'balabanj', 'bertzct'],
-     "target": "solubility",
-     "train_all_data": True,
-     "track_columns": ['solubility']
- }
-
-
- # Function to check if dataframe is empty
- def check_dataframe(df: pd.DataFrame, df_name: str) -> None:
-     """
-     Check if the provided dataframe is empty and raise an exception if it is.
-
-     Args:
-         df (pd.DataFrame): DataFrame to check
-         df_name (str): Name of the DataFrame
-     """
-     if df.empty:
-         msg = f"*** The training data {df_name} has 0 rows! ***STOPPING***"
-         print(msg)
-         raise ValueError(msg)
-
-
- def match_features_case_insensitive(df: pd.DataFrame, model_features: list) -> pd.DataFrame:
-     """
-     Matches and renames DataFrame columns to match model feature names (case-insensitive).
-     Prioritizes exact matches, then case-insensitive matches.
-
-     Raises ValueError if any model features cannot be matched.
-     """
-     df_columns_lower = {col.lower(): col for col in df.columns}
-     rename_dict = {}
-     missing = []
-     for feature in model_features:
-         if feature in df.columns:
-             continue # Exact match
-         elif feature.lower() in df_columns_lower:
-             rename_dict[df_columns_lower[feature.lower()]] = feature
-         else:
-             missing.append(feature)
-
-     if missing:
-         raise ValueError(f"Features not found: {missing}")
-
-     # Rename the DataFrame columns to match the model features
-     return df.rename(columns=rename_dict)
-
-
- def distance_weighted_calibrated_intervals(
-     df_pred: pd.DataFrame,
-     prox_df: pd.DataFrame,
-     calibration_strength: float = 0.7,
-     distance_decay: float = 3.0,
- ) -> pd.DataFrame:
-     """
-     Calibrate intervals using distance-weighted neighbor quantiles.
-     Uses all 10 neighbors with distance-based weighting.
-     """
-     id_column = TEMPLATE_PARAMS["id_column"]
-     target_column = TEMPLATE_PARAMS["target"]
-
-     # Distance-weighted neighbor statistics
-     def weighted_quantile(values, weights, q):
-         """Calculate weighted quantile"""
-         if len(values) == 0:
-             return np.nan
-         sorted_indices = np.argsort(values)
-         sorted_values = values[sorted_indices]
-         sorted_weights = weights[sorted_indices]
-         cumsum = np.cumsum(sorted_weights)
-         cutoff = q * cumsum[-1]
-         return np.interp(cutoff, cumsum, sorted_values)
-
-     # Calculate distance weights (closer neighbors get more weight)
-     prox_df = prox_df.copy()
-     prox_df['weight'] = 1 / (1 + prox_df['distance'] ** distance_decay)
-
-     # Get weighted quantiles and statistics for each ID
-     neighbor_stats = []
-     for id_val, group in prox_df.groupby(id_column):
-         values = group[target_column].values
-         weights = group['weight'].values
-
-         # Normalize weights
-         weights = weights / weights.sum()
-
-         stats = {
-             id_column: id_val,
-             'local_q025': weighted_quantile(values, weights, 0.025),
-             'local_q25': weighted_quantile(values, weights, 0.25),
-             'local_q75': weighted_quantile(values, weights, 0.75),
-             'local_q975': weighted_quantile(values, weights, 0.975),
-             'local_median': weighted_quantile(values, weights, 0.5),
-             'local_std': np.sqrt(np.average((values - np.average(values, weights=weights)) ** 2, weights=weights)),
-             'avg_distance': group['distance'].mean(),
-             'min_distance': group['distance'].min(),
-             'max_distance': group['distance'].max(),
-         }
-         neighbor_stats.append(stats)
-
-     neighbor_df = pd.DataFrame(neighbor_stats)
-     out = df_pred.merge(neighbor_df, on=id_column, how='left')
-
-     # Model disagreement score (normalized by prediction std)
-     model_disagreement = (out["prediction"] - out["prediction_uq"]).abs()
-     disagreement_score = (model_disagreement / out["prediction_std"]).clip(0, 2)
-
-     # Local confidence based on:
-     # 1. How close the neighbors are (closer = more confident)
-     # 2. How much local variance there is (less variance = more confident)
-     max_reasonable_distance = out['max_distance'].quantile(0.8) # 80th percentile as reference
-     distance_confidence = (1 - (out['avg_distance'] / max_reasonable_distance)).clip(0.1, 1.0)
-
-     variance_confidence = (out["prediction_std"] / out["local_std"]).clip(0.5, 2.0)
-     local_confidence = distance_confidence * variance_confidence.clip(0.5, 1.5)
-
-     # Calibration weight: higher when models disagree and we have good local data
-     calibration_weight = (
-         calibration_strength *
-         local_confidence * # Weight by local data quality
-         disagreement_score.clip(0.3, 1.0) # More calibration when models disagree
-     )
-
-     # Consensus prediction (slight preference for NGBoost since it provides intervals)
-     consensus_pred = 0.65 * out["prediction_uq"] + 0.35 * out["prediction"]
-
-     # Re-center local intervals around consensus prediction
-     local_center_offset = consensus_pred - out["local_median"]
-
-     # Apply calibration to each quantile
-     quantile_pairs = [
-         ("q_025", "local_q025"),
-         ("q_25", "local_q25"),
-         ("q_75", "local_q75"),
-         ("q_975", "local_q975")
-     ]
-
-     for model_q, local_q in quantile_pairs:
-         # Adjust local quantiles to be centered around consensus
-         adjusted_local_q = out[local_q] + local_center_offset
-
-         # Blend model and local intervals
-         out[model_q] = (
-             (1 - calibration_weight) * out[model_q] +
-             calibration_weight * adjusted_local_q
-         )
-
-     # Ensure proper interval ordering and bounds using pandas
-     out["q_025"] = pd.concat([out["q_025"], consensus_pred], axis=1).min(axis=1)
-     out["q_975"] = pd.concat([out["q_975"], consensus_pred], axis=1).max(axis=1)
-     out["q_25"] = pd.concat([out["q_25"], out["q_75"]], axis=1).min(axis=1)
-
-     # Optional: Add some interval expansion when neighbors are very far
-     # (indicates we're in a sparse region of feature space)
-     sparse_region_mask = out['min_distance'] > out['min_distance'].quantile(0.9)
-     expansion_factor = 1 + 0.2 * sparse_region_mask # 20% expansion in sparse regions
-
-     for q in ["q_025", "q_25", "q_75", "q_975"]:
-         interval_width = out[q] - consensus_pred
-         out[q] = consensus_pred + interval_width * expansion_factor
-
-     # Clean up temporary columns
-     cleanup_cols = [col for col in out.columns if col.startswith("local_")] + \
-         ['avg_distance', 'min_distance', 'max_distance']
-
-     return out.drop(columns=cleanup_cols)
-
-
- # TRAINING SECTION
- #
- # This section (__main__) is where SageMaker will execute the training job
- # and save the model artifacts to the model directory.
- #
- if __name__ == "__main__":
-     # Template Parameters
-     id_column = TEMPLATE_PARAMS["id_column"]
-     features = TEMPLATE_PARAMS["features"]
-     target = TEMPLATE_PARAMS["target"]
-     train_all_data = TEMPLATE_PARAMS["train_all_data"]
-     track_columns = TEMPLATE_PARAMS["track_columns"] # Can be None
-     validation_split = 0.2
-
-     # Script arguments for input/output directories
-     parser = argparse.ArgumentParser()
-     parser.add_argument("--model-dir", type=str, default=os.environ.get("SM_MODEL_DIR", "/opt/ml/model"))
-     parser.add_argument("--train", type=str, default=os.environ.get("SM_CHANNEL_TRAIN", "/opt/ml/input/data/train"))
-     parser.add_argument(
-         "--output-data-dir", type=str, default=os.environ.get("SM_OUTPUT_DATA_DIR", "/opt/ml/output/data")
-     )
-     args = parser.parse_args()
-
-     # Load training data from the specified directory
-     training_files = [
-         os.path.join(args.train, file)
-         for file in os.listdir(args.train) if file.endswith(".csv")
-     ]
-     print(f"Training Files: {training_files}")
-
-     # Combine files and read them all into a single pandas dataframe
-     df = pd.concat([pd.read_csv(file, engine="python") for file in training_files])
-
-     # Check if the DataFrame is empty
-     check_dataframe(df, "training_df")
-
-     # Training data split logic
-     if train_all_data:
-         # Use all data for both training and validation
-         print("Training on all data...")
-         df_train = df.copy()
-         df_val = df.copy()
-     elif "training" in df.columns:
-         # Split data based on a 'training' column if it exists
-         print("Splitting data based on 'training' column...")
-         df_train = df[df["training"]].copy()
-         df_val = df[~df["training"]].copy()
-     else:
-         # Perform a random split if no 'training' column is found
-         print("Splitting data randomly...")
-         df_train, df_val = train_test_split(df, test_size=validation_split, random_state=42)
-
-     # We're using XGBoost for point predictions and NGBoost for uncertainty quantification
-     xgb_model = XGBRegressor()
-     ngb_model = NGBRegressor()
-
-     # Prepare features and targets for training
-     X_train = df_train[features]
-     X_val = df_val[features]
-     y_train = df_train[target]
-     y_val = df_val[target]
-
-     # Train both models using the training data
-     xgb_model.fit(X_train, y_train)
-     ngb_model.fit(X_train, y_train, X_val=X_val, Y_val=y_val)
-
-     # Make Predictions on the Validation Set
-     print(f"Making Predictions on Validation Set...")
-     y_validate = df_val[target]
-     X_validate = df_val[features]
-     preds = xgb_model.predict(X_validate)
-
-     # Calculate various model performance metrics (regression)
-     rmse = root_mean_squared_error(y_validate, preds)
-     mae = mean_absolute_error(y_validate, preds)
-     r2 = r2_score(y_validate, preds)
-     print(f"RMSE: {rmse:.3f}")
-     print(f"MAE: {mae:.3f}")
-     print(f"R2: {r2:.3f}")
-     print(f"NumRows: {len(df_val)}")
-
-     # Save the trained XGBoost model
-     xgb_model.save_model(os.path.join(args.model_dir, "xgb_model.json"))
-
-     # Save the trained NGBoost model
-     joblib.dump(ngb_model, os.path.join(args.model_dir, "ngb_model.joblib"))
-
-     # Save the feature list to validate input during predictions
-     with open(os.path.join(args.model_dir, "feature_columns.json"), "w") as fp:
-         json.dump(features, fp)
-
-     # Now the Proximity model
-     model = Proximity(df_train, id_column, features, target, track_columns=track_columns)
-
-     # Now serialize the model
-     model.serialize(args.model_dir)
-
-
- #
- # Inference Section
- #
- def model_fn(model_dir) -> dict:
-     """Load and return XGBoost and NGBoost regressors from model directory."""
-
-     # Load XGBoost regressor
-     xgb_path = os.path.join(model_dir, "xgb_model.json")
-     xgb_model = XGBRegressor(enable_categorical=True)
-     xgb_model.load_model(xgb_path)
-
-     # Load NGBoost regressor
-     ngb_model = joblib.load(os.path.join(model_dir, "ngb_model.joblib"))
-
-     # Deserialize the proximity model
-     prox_model = Proximity.deserialize(model_dir)
-
-     return {
-         "xgboost": xgb_model,
-         "ngboost": ngb_model,
-         "proximity": prox_model
-     }
-
-
- def input_fn(input_data, content_type):
-     """Parse input data and return a DataFrame."""
-     if not input_data:
-         raise ValueError("Empty input data is not supported!")
-
-     # Decode bytes to string if necessary
-     if isinstance(input_data, bytes):
-         input_data = input_data.decode("utf-8")
-
-     if "text/csv" in content_type:
-         return pd.read_csv(StringIO(input_data))
-     elif "application/json" in content_type:
-         return pd.DataFrame(json.loads(input_data)) # Assumes JSON array of records
-     else:
-         raise ValueError(f"{content_type} not supported!")
-
-
- def output_fn(output_df, accept_type):
-     """Supports both CSV and JSON output formats."""
-     if "text/csv" in accept_type:
-         csv_output = output_df.fillna("N/A").to_csv(index=False) # CSV with N/A for missing values
-         return csv_output, "text/csv"
-     elif "application/json" in accept_type:
-         return output_df.to_json(orient="records"), "application/json" # JSON array of records (NaNs -> null)
-     else:
-         raise RuntimeError(f"{accept_type} accept type is not supported by this script.")
-
-
- def predict_fn(df, models) -> pd.DataFrame:
-     """Make Predictions with our XGB Quantile Regression Model
-
-     Args:
-         df (pd.DataFrame): The input DataFrame
-         models (dict): The dictionary of models to use for predictions
-
-     Returns:
-         pd.DataFrame: The DataFrame with the predictions added
-     """
-
-     # Grab our feature columns (from training)
-     model_dir = os.environ.get("SM_MODEL_DIR", "/opt/ml/model")
-     with open(os.path.join(model_dir, "feature_columns.json")) as fp:
-         model_features = json.load(fp)
-
-     # Match features in a case-insensitive manner
-     matched_df = match_features_case_insensitive(df, model_features)
-
-     # Use XGBoost for point predictions
-     df["prediction"] = models["xgboost"].predict(matched_df[model_features])
-
-     # NGBoost predict returns distribution objects
-     y_dists = models["ngboost"].pred_dist(matched_df[model_features])
-
-     # Extract parameters from distribution
-     dist_params = y_dists.params
-
-     # Extract mean and std from distribution parameters
-     df["prediction_uq"] = dist_params['loc'] # mean
-     df["prediction_std"] = dist_params['scale'] # standard deviation
-
-     # Add 95% prediction intervals using ppf (percent point function)
-     df["q_025"] = y_dists.ppf(0.025) # 2.5th percentile
-     df["q_975"] = y_dists.ppf(0.975) # 97.5th percentile
-
-     # Add 50% prediction intervals
-     df["q_25"] = y_dists.ppf(0.25) # 25th percentile
-     df["q_75"] = y_dists.ppf(0.75) # 75th percentile
-
-     # Compute Nearest neighbors with Proximity model
-     prox_df = models["proximity"].neighbors(df)
-
-     # Shrink prediction intervals based on KNN variance
-     df = distance_weighted_calibrated_intervals(df, prox_df)
-
-     # Return the modified DataFrame
-     return df
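Note: the interval calibration in the removed script above rests on two small pieces: an inverse-distance weight, weight = 1 / (1 + distance ** distance_decay), and a weighted quantile computed by interpolating against the cumulative weights. The standalone sketch below (not part of the diff) reproduces just that calculation on made-up neighbor data; the values and distances are purely illustrative, and the empty-input guard from the original is omitted for brevity.

import numpy as np

def weighted_quantile(values, weights, q):
    """Weighted quantile via interpolation of the cumulative weight, as in the removed script."""
    sorted_indices = np.argsort(values)
    sorted_values = values[sorted_indices]
    sorted_weights = weights[sorted_indices]
    cumsum = np.cumsum(sorted_weights)
    return np.interp(q * cumsum[-1], cumsum, sorted_values)

# Toy neighbor set: target values and distances for a single query compound
neighbor_values = np.array([2.1, 2.4, 2.6, 3.0, 4.5])
neighbor_distances = np.array([0.1, 0.2, 0.3, 0.8, 1.5])

# Closer neighbors get more weight (distance_decay = 3.0, the template default)
weights = 1 / (1 + neighbor_distances ** 3.0)
weights = weights / weights.sum()

print("weighted median:", weighted_quantile(neighbor_values, weights, 0.5))
print("weighted 2.5% / 97.5%:",
      weighted_quantile(neighbor_values, weights, 0.025),
      weighted_quantile(neighbor_values, weights, 0.975))

With distance_decay = 3.0, the neighbor at distance 1.5 ends up with well under 10% of the total weight, so the weighted quantiles are driven almost entirely by the close-by neighbors.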