workbench 0.8.162-py3-none-any.whl → 0.8.202-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


Files changed (113)
  1. workbench/algorithms/dataframe/__init__.py +1 -2
  2. workbench/algorithms/dataframe/fingerprint_proximity.py +2 -2
  3. workbench/algorithms/dataframe/proximity.py +261 -235
  4. workbench/algorithms/graph/light/proximity_graph.py +10 -8
  5. workbench/api/__init__.py +2 -1
  6. workbench/api/compound.py +1 -1
  7. workbench/api/endpoint.py +11 -0
  8. workbench/api/feature_set.py +11 -8
  9. workbench/api/meta.py +5 -2
  10. workbench/api/model.py +16 -15
  11. workbench/api/monitor.py +1 -16
  12. workbench/core/artifacts/__init__.py +11 -2
  13. workbench/core/artifacts/artifact.py +11 -3
  14. workbench/core/artifacts/data_capture_core.py +355 -0
  15. workbench/core/artifacts/endpoint_core.py +256 -118
  16. workbench/core/artifacts/feature_set_core.py +265 -16
  17. workbench/core/artifacts/model_core.py +107 -60
  18. workbench/core/artifacts/monitor_core.py +33 -248
  19. workbench/core/cloud_platform/aws/aws_account_clamp.py +50 -1
  20. workbench/core/cloud_platform/aws/aws_meta.py +12 -5
  21. workbench/core/cloud_platform/aws/aws_parameter_store.py +18 -2
  22. workbench/core/cloud_platform/aws/aws_session.py +4 -4
  23. workbench/core/transforms/data_to_features/light/molecular_descriptors.py +4 -4
  24. workbench/core/transforms/features_to_model/features_to_model.py +42 -32
  25. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +36 -6
  26. workbench/core/transforms/pandas_transforms/pandas_to_features.py +27 -0
  27. workbench/core/views/training_view.py +113 -42
  28. workbench/core/views/view.py +53 -3
  29. workbench/core/views/view_utils.py +4 -4
  30. workbench/model_scripts/chemprop/chemprop.template +852 -0
  31. workbench/model_scripts/chemprop/generated_model_script.py +852 -0
  32. workbench/model_scripts/chemprop/requirements.txt +11 -0
  33. workbench/model_scripts/custom_models/chem_info/fingerprints.py +134 -0
  34. workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +483 -0
  35. workbench/model_scripts/custom_models/chem_info/mol_standardize.py +450 -0
  36. workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +7 -9
  37. workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -1
  38. workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +3 -5
  39. workbench/model_scripts/custom_models/proximity/proximity.py +261 -235
  40. workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
  41. workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +20 -21
  42. workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
  43. workbench/model_scripts/custom_models/uq_models/meta_uq.template +166 -62
  44. workbench/model_scripts/custom_models/uq_models/ngboost.template +30 -18
  45. workbench/model_scripts/custom_models/uq_models/proximity.py +261 -235
  46. workbench/model_scripts/custom_models/uq_models/requirements.txt +1 -3
  47. workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +15 -17
  48. workbench/model_scripts/pytorch_model/generated_model_script.py +373 -190
  49. workbench/model_scripts/pytorch_model/pytorch.template +370 -187
  50. workbench/model_scripts/scikit_learn/generated_model_script.py +7 -12
  51. workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
  52. workbench/model_scripts/script_generation.py +17 -9
  53. workbench/model_scripts/uq_models/generated_model_script.py +605 -0
  54. workbench/model_scripts/uq_models/mapie.template +605 -0
  55. workbench/model_scripts/uq_models/requirements.txt +1 -0
  56. workbench/model_scripts/xgb_model/generated_model_script.py +37 -46
  57. workbench/model_scripts/xgb_model/xgb_model.template +44 -46
  58. workbench/repl/workbench_shell.py +28 -14
  59. workbench/scripts/endpoint_test.py +162 -0
  60. workbench/scripts/lambda_test.py +73 -0
  61. workbench/scripts/ml_pipeline_batch.py +137 -0
  62. workbench/scripts/ml_pipeline_sqs.py +186 -0
  63. workbench/scripts/monitor_cloud_watch.py +20 -100
  64. workbench/utils/aws_utils.py +4 -3
  65. workbench/utils/chem_utils/__init__.py +0 -0
  66. workbench/utils/chem_utils/fingerprints.py +134 -0
  67. workbench/utils/chem_utils/misc.py +194 -0
  68. workbench/utils/chem_utils/mol_descriptors.py +483 -0
  69. workbench/utils/chem_utils/mol_standardize.py +450 -0
  70. workbench/utils/chem_utils/mol_tagging.py +348 -0
  71. workbench/utils/chem_utils/projections.py +209 -0
  72. workbench/utils/chem_utils/salts.py +256 -0
  73. workbench/utils/chem_utils/sdf.py +292 -0
  74. workbench/utils/chem_utils/toxicity.py +250 -0
  75. workbench/utils/chem_utils/vis.py +253 -0
  76. workbench/utils/chemprop_utils.py +760 -0
  77. workbench/utils/cloudwatch_handler.py +1 -1
  78. workbench/utils/cloudwatch_utils.py +137 -0
  79. workbench/utils/config_manager.py +3 -7
  80. workbench/utils/endpoint_utils.py +5 -7
  81. workbench/utils/license_manager.py +2 -6
  82. workbench/utils/model_utils.py +95 -34
  83. workbench/utils/monitor_utils.py +44 -62
  84. workbench/utils/pandas_utils.py +3 -3
  85. workbench/utils/pytorch_utils.py +526 -0
  86. workbench/utils/shap_utils.py +10 -2
  87. workbench/utils/workbench_logging.py +0 -3
  88. workbench/utils/workbench_sqs.py +1 -1
  89. workbench/utils/xgboost_model_utils.py +371 -156
  90. workbench/web_interface/components/model_plot.py +7 -1
  91. workbench/web_interface/components/plugin_unit_test.py +5 -2
  92. workbench/web_interface/components/plugins/dashboard_status.py +3 -1
  93. workbench/web_interface/components/plugins/generated_compounds.py +1 -1
  94. workbench/web_interface/components/plugins/model_details.py +9 -7
  95. workbench/web_interface/components/plugins/scatter_plot.py +3 -3
  96. {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/METADATA +27 -6
  97. {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/RECORD +101 -85
  98. {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/entry_points.txt +4 -0
  99. {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/licenses/LICENSE +1 -1
  100. workbench/model_scripts/custom_models/chem_info/local_utils.py +0 -769
  101. workbench/model_scripts/custom_models/chem_info/tautomerize.py +0 -83
  102. workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
  103. workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -393
  104. workbench/model_scripts/custom_models/uq_models/mapie_xgb.template +0 -203
  105. workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
  106. workbench/model_scripts/quant_regression/quant_regression.template +0 -279
  107. workbench/model_scripts/quant_regression/requirements.txt +0 -1
  108. workbench/utils/chem_utils.py +0 -1556
  109. workbench/utils/execution_environment.py +0 -211
  110. workbench/utils/fast_inference.py +0 -167
  111. workbench/utils/resource_utils.py +0 -39
  112. {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/WHEEL +0 -0
  113. {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/top_level.txt +0 -0
@@ -1,203 +0,0 @@
- # Model: HistGradientBoosting with MAPIE Conformalized Quantile Regression
- from mapie.regression import MapieQuantileRegressor
- from sklearn.ensemble import HistGradientBoostingRegressor
- from sklearn.model_selection import train_test_split
- import numpy as np
-
- # Template Placeholders
- TEMPLATE_PARAMS = {
-     "features": "{{feature_list}}",
-     "target": "{{target_column}}",
-     "train_all_data": "{{train_all_data}}"
- }
-
- from io import StringIO
- import json
- import argparse
- import joblib
- import os
- import pandas as pd
-
-
- # Function to check if dataframe is empty
- def check_dataframe(df: pd.DataFrame, df_name: str) -> None:
-     """Check if the DataFrame is empty and raise an error if so."""
-     if df.empty:
-         msg = f"*** The training data {df_name} has 0 rows! ***STOPPING***"
-         print(msg)
-         raise ValueError(msg)
-
-
- def match_features_case_insensitive(df: pd.DataFrame, model_features: list) -> pd.DataFrame:
-     """
-     Matches and renames DataFrame columns to match model feature names (case-insensitive).
-     Prioritizes exact matches, then case-insensitive matches.
-
-     Raises ValueError if any model features cannot be matched.
-     """
-     df_columns_lower = {col.lower(): col for col in df.columns}
-     rename_dict = {}
-     missing = []
-     for feature in model_features:
-         if feature in df.columns:
-             continue  # Exact match
-         elif feature.lower() in df_columns_lower:
-             rename_dict[df_columns_lower[feature.lower()]] = feature
-         else:
-             missing.append(feature)
-
-     if missing:
-         raise ValueError(f"Features not found: {missing}")
-
-     # Rename the DataFrame columns to match the model features
-     return df.rename(columns=rename_dict)
-
-
- # TRAINING SECTION
- #
- # This section (__main__) is where SageMaker will execute the training job
- # and save the model artifacts to the model directory.
- #
- if __name__ == "__main__":
-     # Template Parameters
-     features = TEMPLATE_PARAMS["features"]
-     target = TEMPLATE_PARAMS["target"]
-     train_all_data = TEMPLATE_PARAMS["train_all_data"]
-     validation_split = 0.2
-
-     # Script arguments for input/output directories
-     parser = argparse.ArgumentParser()
-     parser.add_argument("--model-dir", type=str, default=os.environ.get("SM_MODEL_DIR", "/opt/ml/model"))
-     parser.add_argument("--train", type=str, default=os.environ.get("SM_CHANNEL_TRAIN", "/opt/ml/input/data/train"))
-     parser.add_argument(
-         "--output-data-dir", type=str, default=os.environ.get("SM_OUTPUT_DATA_DIR", "/opt/ml/output/data")
-     )
-     args = parser.parse_args()
-
-     # Load training data from the specified directory
-     training_files = [
-         os.path.join(args.train, file)
-         for file in os.listdir(args.train) if file.endswith(".csv")
-     ]
-     df = pd.concat([pd.read_csv(file, engine="python") for file in training_files])
-
-     # Check if the DataFrame is empty
-     check_dataframe(df, "training_df")
-
-     # Training data split logic
-     if train_all_data:
-         # Use all data for both training and validation
-         print("Training on all data...")
-         df_train = df.copy()
-         df_val = df.copy()
-     elif "training" in df.columns:
-         # Split data based on a 'training' column if it exists
-         print("Splitting data based on 'training' column...")
-         df_train = df[df["training"]].copy()
-         df_val = df[~df["training"]].copy()
-     else:
-         # Perform a random split if no 'training' column is found
-         print("Splitting data randomly...")
-         df_train, df_val = train_test_split(df, test_size=validation_split, random_state=42)
-
-     # Create HistGradientBoosting base model configured for quantile regression
-     base_estimator = HistGradientBoostingRegressor(
-         loss='quantile',  # Required for MAPIE CQR
-         quantile=0.5,  # Will be overridden by MAPIE for different quantiles
-         max_iter=1000,
-         max_depth=6,
-         learning_rate=0.01,
-         random_state=42
-     )
-
-     # Create MAPIE CQR predictor - it will create quantile versions internally
-     model = MapieQuantileRegressor(
-         estimator=base_estimator,
-         method="quantile",
-         cv="split",
-         alpha=0.05  # For 95% coverage
-     )
-
-     # Prepare features and targets for training
-     X_train = df_train[features]
-     X_val = df_val[features]
-     y_train = df_train[target]
-     y_val = df_val[target]
-
-     # Fit the MAPIE CQR model (train/calibration is handled internally)
-     model.fit(X_train, y_train)
-
-     # Save the trained model and any necessary assets
-     joblib.dump(model, os.path.join(args.model_dir, "model.joblib"))
-
-     # Save the feature list to validate input during predictions
-     with open(os.path.join(args.model_dir, "feature_columns.json"), "w") as fp:
-         json.dump(features, fp)
-
-
- #
- # Inference Section
- #
- def model_fn(model_dir):
-     """Load and return the model from the specified directory."""
-     return joblib.load(os.path.join(model_dir, "model.joblib"))
-
-
- def input_fn(input_data, content_type):
-     """Parse input data and return a DataFrame."""
-     if not input_data:
-         raise ValueError("Empty input data is not supported!")
-
-     # Decode bytes to string if necessary
-     if isinstance(input_data, bytes):
-         input_data = input_data.decode("utf-8")
-
-     if "text/csv" in content_type:
-         return pd.read_csv(StringIO(input_data))
-     elif "application/json" in content_type:
-         return pd.DataFrame(json.loads(input_data))  # Assumes JSON array of records
-     else:
-         raise ValueError(f"{content_type} not supported!")
-
-
- def output_fn(output_df, accept_type):
-     """Supports both CSV and JSON output formats."""
-     if "text/csv" in accept_type:
-         csv_output = output_df.fillna("N/A").to_csv(index=False)  # CSV with N/A for missing values
-         return csv_output, "text/csv"
-     elif "application/json" in accept_type:
-         return output_df.to_json(orient="records"), "application/json"  # JSON array of records (NaNs -> null)
-     else:
-         raise RuntimeError(f"{accept_type} accept type is not supported by this script.")
-
-
- def predict_fn(df, model):
-     """Make predictions using MAPIE CQR and return the DataFrame with results."""
-     model_dir = os.environ.get("SM_MODEL_DIR", "/opt/ml/model")
-
-     # Load feature columns from the saved file
-     with open(os.path.join(model_dir, "feature_columns.json")) as fp:
-         model_features = json.load(fp)
-
-     # Match features in a case-insensitive manner
-     matched_df = match_features_case_insensitive(df, model_features)
-
-     # Get CQR predictions - returns point prediction and intervals
-     X_pred = matched_df[model_features]
-     y_pred, y_pis = model.predict(X_pred)
-
-     # Add predictions to dataframe with 95% intervals
-     df["prediction"] = y_pred
-     df["q_025"] = y_pis[:, 0, 0]  # Lower bound (2.5th percentile)
-     df["q_975"] = y_pis[:, 1, 0]  # Upper bound (97.5th percentile)
-
-     # Calculate std estimate from 95% interval
-     interval_width_95 = df["q_975"] - df["q_025"]
-     df["prediction_std"] = interval_width_95 / 3.92  # 95% CI = ±1.96σ, so width = 3.92σ
-
-     # Calculate 50% intervals using normal approximation
-     df["q_25"] = df["prediction"] - 0.674 * df["prediction_std"]
-     df["q_75"] = df["prediction"] + 0.674 * df["prediction_std"]
-
-     # Return the modified DataFrame
-     return df
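
Note on the removed MAPIE CQR script above: its predict_fn turns the 95% conformal interval into a std-style spread and a 50% interval via a plain normal approximation. A minimal standalone sketch of that arithmetic (column names mirror the script; the values are made up for illustration):

import pandas as pd

# Toy predictions with 95% interval bounds (illustrative values only)
df = pd.DataFrame({
    "prediction": [1.0, 2.0],
    "q_025": [0.2, 1.1],
    "q_975": [1.8, 2.9],
})

# A 95% normal interval spans roughly ±1.96σ, so its width is ~3.92σ
df["prediction_std"] = (df["q_975"] - df["q_025"]) / 3.92

# 50% interval: roughly ±0.674σ around the point prediction
df["q_25"] = df["prediction"] - 0.674 * df["prediction_std"]
df["q_75"] = df["prediction"] + 0.674 * df["prediction_std"]

print(df)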
@@ -1,279 +0,0 @@
- # Template Placeholders
- TEMPLATE_PARAMS = {
-     "model_type": "ensemble_regressor",
-     "target_column": "solubility",
-     "feature_list": ['molwt', 'mollogp', 'molmr', 'heavyatomcount', 'numhacceptors', 'numhdonors', 'numheteroatoms', 'numrotatablebonds', 'numvalenceelectrons', 'numaromaticrings', 'numsaturatedrings', 'numaliphaticrings', 'ringcount', 'tpsa', 'labuteasa', 'balabanj', 'bertzct'],
-     "model_metrics_s3_path": "s3://sandbox-sageworks-artifacts/models/aqsol-ensemble/training"
- }
-
- # Imports for XGB Model
- import xgboost as xgb
- import awswrangler as wr
- import numpy as np
-
- # Model Performance Scores
- from sklearn.metrics import (
-     mean_absolute_error,
-     r2_score,
-     root_mean_squared_error
- )
-
- from io import StringIO
- import json
- import argparse
- import os
- import pandas as pd
-
-
- # Function to check if dataframe is empty
- def check_dataframe(df: pd.DataFrame, df_name: str) -> None:
-     """
-     Check if the provided dataframe is empty and raise an exception if it is.
-
-     Args:
-         df (pd.DataFrame): DataFrame to check
-         df_name (str): Name of the DataFrame
-     """
-     if df.empty:
-         msg = f"*** The training data {df_name} has 0 rows! ***STOPPING***"
-         print(msg)
-         raise ValueError(msg)
-
- def match_features_case_insensitive(df: pd.DataFrame, model_features: list) -> pd.DataFrame:
-     """
-     Matches and renames the DataFrame's column names to match the model's feature names (case-insensitive).
-     Prioritizes exact case matches first, then falls back to case-insensitive matching if no exact match exists.
-
-     Args:
-         df (pd.DataFrame): The DataFrame with the original columns.
-         model_features (list): The desired list of feature names (mixed case allowed).
-
-     Returns:
-         pd.DataFrame: The DataFrame with renamed columns to match the model's feature names.
-     """
-     # Create a mapping for exact and case-insensitive matching
-     exact_match_set = set(df.columns)
-     column_map = {}
-
-     # Build the case-insensitive map (if we have any duplicate columns, the first one wins)
-     for col in df.columns:
-         lower_col = col.lower()
-         if lower_col not in column_map:
-             column_map[lower_col] = col
-
-     # Create a dictionary for renaming
-     rename_dict = {}
-     for feature in model_features:
-         # Check for an exact match first
-         if feature in exact_match_set:
-             rename_dict[feature] = feature
-
-         # If not an exact match, fall back to case-insensitive matching
-         elif feature.lower() in column_map:
-             rename_dict[column_map[feature.lower()]] = feature
-
-     # Rename the columns in the DataFrame to match the model's feature names
-     return df.rename(columns=rename_dict)
-
-
- if __name__ == "__main__":
-     """The main function is for training the XGBoost Quantile Regression models"""
-
-     # Harness Template Parameters
-     target = TEMPLATE_PARAMS["target_column"]
-     feature_list = TEMPLATE_PARAMS["feature_list"]
-     model_metrics_s3_path = TEMPLATE_PARAMS["model_metrics_s3_path"]
-     models = {}
-
-     # Script arguments for input/output directories
-     parser = argparse.ArgumentParser()
-     parser.add_argument("--model-dir", type=str, default=os.environ.get("SM_MODEL_DIR", "/opt/ml/model"))
-     parser.add_argument("--train", type=str, default=os.environ.get("SM_CHANNEL_TRAIN", "/opt/ml/input/data/train"))
-     parser.add_argument(
-         "--output-data-dir", type=str, default=os.environ.get("SM_OUTPUT_DATA_DIR", "/opt/ml/output/data")
-     )
-     args = parser.parse_args()
-
-     # Read the training data into DataFrames
-     training_files = [
-         os.path.join(args.train, file)
-         for file in os.listdir(args.train)
-         if file.endswith(".csv")
-     ]
-     print(f"Training Files: {training_files}")
-
-     # Combine files and read them all into a single pandas dataframe
-     df = pd.concat([pd.read_csv(file, engine="python") for file in training_files])
-
-     # Check if the dataframe is empty
-     check_dataframe(df, "training_df")
-
-     # Features/Target output
-     print(f"Target: {target}")
-     print(f"Features: {str(feature_list)}")
-     print(f"Data Shape: {df.shape}")
-
-     # Grab our Features and Target with traditional X, y handles
-     y = df[target]
-     X = df[feature_list]
-
-     # Train 50 models with random 70/30 splits of the data
-     for model_id in range(50):
-         # Model Name
-         model_name = f"m_{model_id:02}"
-
-         # Bootstrap sample (50% with replacement)
-         sample_size = int(0.5 * len(X))
-         bootstrap_indices = np.random.choice(len(X), size=sample_size, replace=True)
-         X_train, y_train = X.iloc[bootstrap_indices], y.iloc[bootstrap_indices]
-         print(f"Training Model {model_name} with {len(X_train)} rows")
-         model = xgb.XGBRegressor(reg_alpha=0.1, reg_lambda=1.0)
-         model.fit(X_train, y_train)
-
-         # Store the model
-         models[model_name] = model
-
-     # Run predictions for each model
-     all_predictions = {model_name: model.predict(X) for model_name, model in models.items()}
-
-     # Create a copy of the provided DataFrame and add the new columns
-     result_df = df[[target]].copy()
-
-     # Add the model predictions to the DataFrame
-     for name, preds in all_predictions.items():
-         result_df[name] = preds
-
-     # Add the main prediction to the DataFrame (mean of all models)
-     result_df["prediction"] = result_df[[name for name in result_df.columns if name.startswith("m_")]].mean(axis=1)
-
-     # Now compute residuals on the rmse prediction
-     result_df["residual"] = result_df[target] - result_df["prediction"]
-     result_df["residual_abs"] = result_df["residual"].abs()
-
-
-     # Save the results dataframe to S3
-     wr.s3.to_csv(
-         result_df,
-         path=f"{model_metrics_s3_path}/validation_predictions.csv",
-         index=False,
-     )
-
-     # Report Performance Metrics
-     rmse = root_mean_squared_error(result_df[target], result_df["prediction"])
-     mae = mean_absolute_error(result_df[target], result_df["prediction"])
-     r2 = r2_score(result_df[target], result_df["prediction"])
-     print(f"RMSE: {rmse:.3f}")
-     print(f"MAE: {mae:.3f}")
-     print(f"R2: {r2:.3f}")
-     print(f"NumRows: {len(result_df)}")
-
-     # Now save the models
-     for name, model in models.items():
-         model_path = os.path.join(args.model_dir, f"{name}.json")
-         print(f"Saving model: {model_path}")
-         model.save_model(model_path)
-
-     # Also save the features (this will validate input during predictions)
-     with open(os.path.join(args.model_dir, "feature_columns.json"), "w") as fp:
-         json.dump(feature_list, fp)
-
-
- def model_fn(model_dir) -> dict:
-     """Deserialize and return all the fitted models from the model directory.
-
-     Args:
-         model_dir (str): The directory where the models are stored.
-
-     Returns:
-         dict: A dictionary of the models.
-     """
-
-     # Load ALL the models from the model directory
-     models = {}
-     for file in os.listdir(model_dir):
-         if file.startswith("m_") and file.endswith(".json"):  # The Quantile models
-             # Load the model
-             model_path = os.path.join(model_dir, file)
-             print(f"Loading model: {model_path}")
-             model = xgb.XGBRegressor()
-             model.load_model(model_path)
-
-             # Store the model
-             m_name = os.path.splitext(file)[0]
-             models[m_name] = model
-
-     # Return all the models
-     return models
-
-
- def input_fn(input_data, content_type):
-     """Parse input data and return a DataFrame."""
-     if not input_data:
-         raise ValueError("Empty input data is not supported!")
-
-     # Decode bytes to string if necessary
-     if isinstance(input_data, bytes):
-         input_data = input_data.decode("utf-8")
-
-     if "text/csv" in content_type:
-         return pd.read_csv(StringIO(input_data))
-     elif "application/json" in content_type:
-         return pd.DataFrame(json.loads(input_data))  # Assumes JSON array of records
-     else:
-         raise ValueError(f"{content_type} not supported!")
-
-
- def output_fn(output_df, accept_type):
-     """Supports both CSV and JSON output formats."""
-     if "text/csv" in accept_type:
-         csv_output = output_df.fillna("N/A").to_csv(index=False)  # CSV with N/A for missing values
-         return csv_output, "text/csv"
-     elif "application/json" in accept_type:
-         return output_df.to_json(orient="records"), "application/json"  # JSON array of records (NaNs -> null)
-     else:
-         raise RuntimeError(f"{accept_type} accept type is not supported by this script.")
-
-
- def predict_fn(df, models) -> pd.DataFrame:
-     """Make Predictions with our XGB Quantile Regression Model
-
-     Args:
-         df (pd.DataFrame): The input DataFrame
-         models (dict): The dictionary of models to use for predictions
-
-     Returns:
-         pd.DataFrame: The DataFrame with the predictions added
-     """
-
-     # Grab our feature columns (from training)
-     model_dir = os.environ.get("SM_MODEL_DIR", "/opt/ml/model")
-     with open(os.path.join(model_dir, "feature_columns.json")) as fp:
-         model_features = json.load(fp)
-     print(f"Model Features: {model_features}")
-
-     # We're going to match features in a case-insensitive manner, accounting for all the permutations
-     # - Model has a feature list that's any case ("Id", "taCos", "cOunT", "likes_tacos")
-     # - Incoming data has columns that are mixed case ("ID", "Tacos", "Count", "Likes_Tacos")
-     matched_df = match_features_case_insensitive(df, model_features)
-
-     # Predict the features against all the models
-     for name, model in models.items():
-         df[name] = model.predict(matched_df[model_features])
-
-     # Add quantiles for consistency with other UQ models
-     df["q_025"] = df[[name for name in df.columns if name.startswith("m_")]].quantile(0.025, axis=1)
-     df["q_975"] = df[[name for name in df.columns if name.startswith("m_")]].quantile(0.975, axis=1)
-     df["q_25"] = df[[name for name in df.columns if name.startswith("m_")]].quantile(0.25, axis=1)
-     df["q_75"] = df[[name for name in df.columns if name.startswith("m_")]].quantile(0.75, axis=1)
-
-     # Compute the mean, min, max and stddev of the predictions
-     df["prediction"] = df[[name for name in df.columns if name.startswith("m_")]].mean(axis=1)
-     df["p_min"] = df[[name for name in df.columns if name.startswith("m_")]].min(axis=1)
-     df["p_max"] = df[[name for name in df.columns if name.startswith("m_")]].max(axis=1)
-     df["prediction_std"] = df[[name for name in df.columns if name.startswith("m_")]].std(axis=1)
-
-     # Reorganize the columns so they are in alphabetical order
-     df = df.reindex(sorted(df.columns), axis=1)
-
-     # All done, return the DataFrame
-     return df
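
Note on the removed ensemble script above: at inference time each bootstrap member writes its own m_NN column, and the uncertainty columns are simple row-wise statistics across those members. A small self-contained sketch of that aggregation (toy data and a 5-member ensemble; only the column naming follows the script):

import numpy as np
import pandas as pd

# Toy output: 3 rows scored by 5 bootstrap members (random values for illustration)
rng = np.random.default_rng(0)
df = pd.DataFrame(rng.normal(size=(3, 5)), columns=[f"m_{i:02}" for i in range(5)])
member_cols = [c for c in df.columns if c.startswith("m_")]

# Point estimate and spread across the ensemble members
df["prediction"] = df[member_cols].mean(axis=1)
df["prediction_std"] = df[member_cols].std(axis=1)

# Empirical quantiles across members, matching the q_* columns in the removed script
for q, col in [(0.025, "q_025"), (0.25, "q_25"), (0.75, "q_75"), (0.975, "q_975")]:
    df[col] = df[member_cols].quantile(q, axis=1)

print(df.round(3))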