workbench-0.8.168-py3-none-any.whl → workbench-0.8.192-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88)
  1. workbench/algorithms/dataframe/proximity.py +143 -102
  2. workbench/algorithms/graph/light/proximity_graph.py +2 -1
  3. workbench/api/compound.py +1 -1
  4. workbench/api/endpoint.py +3 -2
  5. workbench/api/feature_set.py +4 -4
  6. workbench/api/model.py +16 -12
  7. workbench/api/monitor.py +1 -16
  8. workbench/core/artifacts/artifact.py +11 -3
  9. workbench/core/artifacts/data_capture_core.py +355 -0
  10. workbench/core/artifacts/endpoint_core.py +113 -27
  11. workbench/core/artifacts/feature_set_core.py +72 -13
  12. workbench/core/artifacts/model_core.py +50 -15
  13. workbench/core/artifacts/monitor_core.py +33 -249
  14. workbench/core/cloud_platform/aws/aws_account_clamp.py +50 -1
  15. workbench/core/cloud_platform/aws/aws_meta.py +11 -4
  16. workbench/core/transforms/data_to_features/light/molecular_descriptors.py +4 -4
  17. workbench/core/transforms/features_to_model/features_to_model.py +9 -4
  18. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +36 -6
  19. workbench/core/transforms/pandas_transforms/pandas_to_features.py +27 -0
  20. workbench/core/views/training_view.py +49 -53
  21. workbench/core/views/view.py +51 -1
  22. workbench/core/views/view_utils.py +4 -4
  23. workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +483 -0
  24. workbench/model_scripts/custom_models/chem_info/mol_standardize.py +450 -0
  25. workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +7 -9
  26. workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +3 -5
  27. workbench/model_scripts/custom_models/proximity/proximity.py +143 -102
  28. workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
  29. workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +10 -17
  30. workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
  31. workbench/model_scripts/custom_models/uq_models/meta_uq.template +156 -58
  32. workbench/model_scripts/custom_models/uq_models/ngboost.template +20 -14
  33. workbench/model_scripts/custom_models/uq_models/proximity.py +143 -102
  34. workbench/model_scripts/custom_models/uq_models/requirements.txt +1 -3
  35. workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +5 -13
  36. workbench/model_scripts/pytorch_model/pytorch.template +9 -18
  37. workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
  38. workbench/model_scripts/script_generation.py +7 -2
  39. workbench/model_scripts/uq_models/mapie.template +492 -0
  40. workbench/model_scripts/uq_models/requirements.txt +1 -0
  41. workbench/model_scripts/xgb_model/xgb_model.template +31 -40
  42. workbench/repl/workbench_shell.py +4 -4
  43. workbench/scripts/lambda_launcher.py +63 -0
  44. workbench/scripts/{ml_pipeline_launcher.py → ml_pipeline_batch.py} +49 -51
  45. workbench/scripts/ml_pipeline_sqs.py +186 -0
  46. workbench/utils/chem_utils/__init__.py +0 -0
  47. workbench/utils/chem_utils/fingerprints.py +134 -0
  48. workbench/utils/chem_utils/misc.py +194 -0
  49. workbench/utils/chem_utils/mol_descriptors.py +483 -0
  50. workbench/utils/chem_utils/mol_standardize.py +450 -0
  51. workbench/utils/chem_utils/mol_tagging.py +348 -0
  52. workbench/utils/chem_utils/projections.py +209 -0
  53. workbench/utils/chem_utils/salts.py +256 -0
  54. workbench/utils/chem_utils/sdf.py +292 -0
  55. workbench/utils/chem_utils/toxicity.py +250 -0
  56. workbench/utils/chem_utils/vis.py +253 -0
  57. workbench/utils/config_manager.py +2 -6
  58. workbench/utils/endpoint_utils.py +5 -7
  59. workbench/utils/license_manager.py +2 -6
  60. workbench/utils/model_utils.py +76 -30
  61. workbench/utils/monitor_utils.py +44 -62
  62. workbench/utils/pandas_utils.py +3 -3
  63. workbench/utils/shap_utils.py +10 -2
  64. workbench/utils/workbench_sqs.py +1 -1
  65. workbench/utils/xgboost_model_utils.py +283 -145
  66. workbench/web_interface/components/plugins/dashboard_status.py +3 -1
  67. workbench/web_interface/components/plugins/generated_compounds.py +1 -1
  68. workbench/web_interface/components/plugins/scatter_plot.py +3 -3
  69. {workbench-0.8.168.dist-info → workbench-0.8.192.dist-info}/METADATA +2 -1
  70. {workbench-0.8.168.dist-info → workbench-0.8.192.dist-info}/RECORD +74 -70
  71. {workbench-0.8.168.dist-info → workbench-0.8.192.dist-info}/entry_points.txt +3 -1
  72. workbench/model_scripts/custom_models/chem_info/local_utils.py +0 -769
  73. workbench/model_scripts/custom_models/chem_info/tautomerize.py +0 -83
  74. workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
  75. workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -393
  76. workbench/model_scripts/custom_models/uq_models/mapie_xgb.template +0 -203
  77. workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
  78. workbench/model_scripts/pytorch_model/generated_model_script.py +0 -576
  79. workbench/model_scripts/quant_regression/quant_regression.template +0 -279
  80. workbench/model_scripts/quant_regression/requirements.txt +0 -1
  81. workbench/model_scripts/scikit_learn/generated_model_script.py +0 -307
  82. workbench/model_scripts/xgb_model/generated_model_script.py +0 -477
  83. workbench/utils/chem_utils.py +0 -1556
  84. workbench/utils/fast_inference.py +0 -167
  85. workbench/utils/resource_utils.py +0 -39
  86. {workbench-0.8.168.dist-info → workbench-0.8.192.dist-info}/WHEEL +0 -0
  87. {workbench-0.8.168.dist-info → workbench-0.8.192.dist-info}/licenses/LICENSE +0 -0
  88. {workbench-0.8.168.dist-info → workbench-0.8.192.dist-info}/top_level.txt +0 -0
workbench/model_scripts/pytorch_model/generated_model_script.py (deleted)
@@ -1,576 +0,0 @@
- # Imports for PyTorch Tabular Model
- import os
- import awswrangler as wr
- import numpy as np
-
- # PyTorch compatibility: pytorch-tabular saves complex objects, not just tensors
- # Use legacy loading behavior for compatibility (recommended by PyTorch docs for this scenario)
- os.environ["TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD"] = "1"
- from pytorch_tabular import TabularModel
- from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig
- from pytorch_tabular.models import CategoryEmbeddingModelConfig
-
- # Model Performance Scores
- from sklearn.metrics import (
-     mean_absolute_error,
-     r2_score,
-     root_mean_squared_error,
-     precision_recall_fscore_support,
-     confusion_matrix,
- )
-
- # Classification Encoder
- from sklearn.preprocessing import LabelEncoder
-
- # Scikit Learn Imports
- from sklearn.model_selection import train_test_split
-
- from io import StringIO
- import json
- import argparse
- import joblib
- import os
- import pandas as pd
- from typing import List, Tuple
-
- # Template Parameters
- TEMPLATE_PARAMS = {
-     "model_type": "classifier",
-     "target_column": "solubility_class",
-     "features": ['molwt', 'mollogp', 'molmr', 'heavyatomcount', 'numhacceptors', 'numhdonors', 'numheteroatoms', 'numrotatablebonds', 'numvalenceelectrons', 'numaromaticrings', 'numsaturatedrings', 'numaliphaticrings', 'ringcount', 'tpsa', 'labuteasa', 'balabanj', 'bertzct'],
-     "compressed_features": [],
-     "model_metrics_s3_path": "s3://sandbox-sageworks-artifacts/models/aqsol-pytorch-class/training",
-     "train_all_data": False,
-     "hyperparameters": {'training_config': {'max_epochs': 150}, 'model_config': {'layers': '256-128-64'}}
- }
-
-
- # Function to check if dataframe is empty
- def check_dataframe(df: pd.DataFrame, df_name: str) -> None:
-     """
-     Check if the provided dataframe is empty and raise an exception if it is.
-
-     Args:
-         df (pd.DataFrame): DataFrame to check
-         df_name (str): Name of the DataFrame
-     """
-     if df.empty:
-         msg = f"*** The training data {df_name} has 0 rows! ***STOPPING***"
-         print(msg)
-         raise ValueError(msg)
-
-
- def expand_proba_column(df: pd.DataFrame, class_labels: List[str]) -> pd.DataFrame:
-     """
-     Expands a column in a DataFrame containing a list of probabilities into separate columns.
-
-     Args:
-         df (pd.DataFrame): DataFrame containing a "pred_proba" column
-         class_labels (List[str]): List of class labels
-
-     Returns:
-         pd.DataFrame: DataFrame with the "pred_proba" expanded into separate columns
-     """
-
-     # Sanity check
-     proba_column = "pred_proba"
-     if proba_column not in df.columns:
-         raise ValueError('DataFrame does not contain a "pred_proba" column')
-
-     # Construct new column names with '_proba' suffix
-     proba_splits = [f"{label}_proba" for label in class_labels]
-
-     # Expand the proba_column into separate columns for each probability
-     proba_df = pd.DataFrame(df[proba_column].tolist(), columns=proba_splits)
-
-     # Drop any proba columns and reset the index in prep for the concat
-     df = df.drop(columns=[proba_column] + proba_splits, errors="ignore")
-     df = df.reset_index(drop=True)
-
-     # Concatenate the new columns with the original DataFrame
-     df = pd.concat([df, proba_df], axis=1)
-     print(df)
-     return df
-
-
- def match_features_case_insensitive(df: pd.DataFrame, model_features: list) -> pd.DataFrame:
-     """
-     Matches and renames DataFrame columns to match model feature names (case-insensitive).
-     Prioritizes exact matches, then case-insensitive matches.
-
-     Raises ValueError if any model features cannot be matched.
-     """
-     df_columns_lower = {col.lower(): col for col in df.columns}
-     rename_dict = {}
-     missing = []
-
-     for feature in model_features:
-         if feature in df.columns:
-             continue  # Exact match
-         elif feature.lower() in df_columns_lower:
-             rename_dict[df_columns_lower[feature.lower()]] = feature
-         else:
-             missing.append(feature)
-
-     if missing:
-         raise ValueError(f"Features not found: {missing}")
-
-     return df.rename(columns=rename_dict)
-
-
- def convert_categorical_types(df: pd.DataFrame, features: list, category_mappings={}) -> tuple:
-     """
-     Converts appropriate columns to categorical type with consistent mappings.
-
-     Args:
-         df (pd.DataFrame): The DataFrame to process.
-         features (list): List of feature names to consider for conversion.
-         category_mappings (dict, optional): Existing category mappings. If empty dict, we're in
-             training mode. If populated, we're in inference mode.
-
-     Returns:
-         tuple: (processed DataFrame, category mappings dictionary)
-     """
-     # Training mode
-     if category_mappings == {}:
-         for col in df.select_dtypes(include=["object", "string"]):
-             if col in features and df[col].nunique() < 20:
-                 print(f"Training mode: Converting {col} to category")
-                 df[col] = df[col].astype("category")
-                 category_mappings[col] = df[col].cat.categories.tolist()  # Store category mappings
-
-     # Inference mode
-     else:
-         for col, categories in category_mappings.items():
-             if col in df.columns:
-                 print(f"Inference mode: Applying categorical mapping for {col}")
-                 df[col] = pd.Categorical(df[col], categories=categories)  # Apply consistent categorical mapping
-
-     return df, category_mappings
-
-
- def decompress_features(
-     df: pd.DataFrame, features: List[str], compressed_features: List[str]
- ) -> Tuple[pd.DataFrame, List[str]]:
-     """Prepare features for the model
-
-     Args:
-         df (pd.DataFrame): The features DataFrame
-         features (List[str]): Full list of feature names
-         compressed_features (List[str]): List of feature names to decompress (bitstrings)
-
-     Returns:
-         pd.DataFrame: DataFrame with the decompressed features
-         List[str]: Updated list of feature names after decompression
-
-     Raises:
-         ValueError: If any missing values are found in the specified features
-     """
-
-     # Check for any missing values in the required features
-     missing_counts = df[features].isna().sum()
-     if missing_counts.any():
-         missing_features = missing_counts[missing_counts > 0]
-         print(
-             f"WARNING: Found missing values in features: {missing_features.to_dict()}. "
-             "WARNING: You might want to remove/replace all NaN values before processing."
-         )
-
-     # Decompress the specified compressed features
-     decompressed_features = features
-     for feature in compressed_features:
-         if (feature not in df.columns) or (feature not in features):
-             print(f"Feature '{feature}' not in the features list, skipping decompression.")
-             continue
-
-         # Remove the feature from the list of features to avoid duplication
-         decompressed_features.remove(feature)
-
-         # Handle all compressed features as bitstrings
-         bit_matrix = np.array([list(bitstring) for bitstring in df[feature]], dtype=np.uint8)
-         prefix = feature[:3]
-
-         # Create all new columns at once - avoids fragmentation
-         new_col_names = [f"{prefix}_{i}" for i in range(bit_matrix.shape[1])]
-         new_df = pd.DataFrame(bit_matrix, columns=new_col_names, index=df.index)
-
-         # Add to features list
-         decompressed_features.extend(new_col_names)
-
-         # Drop original column and concatenate new ones
-         df = df.drop(columns=[feature])
-         df = pd.concat([df, new_df], axis=1)
-
-     return df, decompressed_features
-
-
- def model_fn(model_dir):
-
-     # Save current working directory
-     original_cwd = os.getcwd()
-     try:
-         # Change to /tmp because Pytorch Tabular needs write access (creates a .pt_tmp directory)
-         os.chdir('/tmp')
-
-         # Load the model
-         model_path = os.path.join(model_dir, "tabular_model")
-         model = TabularModel.load_model(model_path)
-
-     # Restore the original working directory
-     finally:
-         os.chdir(original_cwd)
-
-     return model
-
-
- def input_fn(input_data, content_type):
-     """Parse input data and return a DataFrame."""
-     if not input_data:
-         raise ValueError("Empty input data is not supported!")
-
-     # Decode bytes to string if necessary
-     if isinstance(input_data, bytes):
-         input_data = input_data.decode("utf-8")
-
-     if "text/csv" in content_type:
-         return pd.read_csv(StringIO(input_data))
-     elif "application/json" in content_type:
-         return pd.DataFrame(json.loads(input_data))  # Assumes JSON array of records
-     else:
-         raise ValueError(f"{content_type} not supported!")
-
-
- def output_fn(output_df, accept_type):
-     """Supports both CSV and JSON output formats."""
-     if "text/csv" in accept_type:
-         csv_output = output_df.fillna("N/A").to_csv(index=False)  # CSV with N/A for missing values
-         return csv_output, "text/csv"
-     elif "application/json" in accept_type:
-         return output_df.to_json(orient="records"), "application/json"  # JSON array of records (NaNs -> null)
-     else:
-         raise RuntimeError(f"{accept_type} accept type is not supported by this script.")
-
-
- def predict_fn(df, model) -> pd.DataFrame:
-     """Make Predictions with our PyTorch Tabular Model
-
-     Args:
-         df (pd.DataFrame): The input DataFrame
-         model: The TabularModel use for predictions
-
-     Returns:
-         pd.DataFrame: The DataFrame with the predictions added
-     """
-     compressed_features = TEMPLATE_PARAMS["compressed_features"]
-
-     # Grab our feature columns (from training)
-     model_dir = os.environ.get("SM_MODEL_DIR", "/opt/ml/model")
-     with open(os.path.join(model_dir, "feature_columns.json")) as fp:
-         features = json.load(fp)
-     print(f"Model Features: {features}")
-
-     # Load the category mappings (from training)
-     with open(os.path.join(model_dir, "category_mappings.json")) as fp:
-         category_mappings = json.load(fp)
-
-     # Load our Label Encoder if we have one
-     label_encoder = None
-     if os.path.exists(os.path.join(model_dir, "label_encoder.joblib")):
-         label_encoder = joblib.load(os.path.join(model_dir, "label_encoder.joblib"))
-
-     # We're going match features in a case-insensitive manner, accounting for all the permutations
-     # - Model has a feature list that's any case ("Id", "taCos", "cOunT", "likes_tacos")
-     # - Incoming data has columns that are mixed case ("ID", "Tacos", "Count", "Likes_Tacos")
-     matched_df = match_features_case_insensitive(df, features)
-
-     # Detect categorical types in the incoming DataFrame
-     matched_df, _ = convert_categorical_types(matched_df, features, category_mappings)
-
-     # If we have compressed features, decompress them
-     if compressed_features:
-         print("Decompressing features for prediction...")
-         matched_df, features = decompress_features(matched_df, features, compressed_features)
-
-     # Make predictions using the TabularModel
-     result = model.predict(matched_df[features])
-
-     # pytorch-tabular returns predictions using f"{target}_prediction" column
-     # and classification probabilities in columns ending with "_probability"
-     target = TEMPLATE_PARAMS["target_column"]
-     prediction_column = f"{target}_prediction"
-     if prediction_column in result.columns:
-         predictions = result[prediction_column].values
-     else:
-         raise ValueError(f"Cannot find prediction column in: {result.columns.tolist()}")
-
-     # If we have a label encoder, decode the predictions
-     if label_encoder:
-         predictions = label_encoder.inverse_transform(predictions.astype(int))
-
-     # Set the predictions on the DataFrame
-     df["prediction"] = predictions
-
-     # For classification, get probabilities
-     if label_encoder is not None:
-         prob_cols = [col for col in result.columns if col.endswith("_probability")]
-         if prob_cols:
-             probs = result[prob_cols].values
-             df["pred_proba"] = [p.tolist() for p in probs]
-
-             # Expand the pred_proba column into separate columns for each class
-             df = expand_proba_column(df, label_encoder.classes_)
-
-     # All done, return the DataFrame with new columns for the predictions
-     return df
-
-
- if __name__ == "__main__":
-     """The main function is for training the PyTorch Tabular model"""
-
-     # Harness Template Parameters
-     target = TEMPLATE_PARAMS["target_column"]
-     features = TEMPLATE_PARAMS["features"]
-     orig_features = features.copy()
-     compressed_features = TEMPLATE_PARAMS["compressed_features"]
-     model_type = TEMPLATE_PARAMS["model_type"]
-     model_metrics_s3_path = TEMPLATE_PARAMS["model_metrics_s3_path"]
-     train_all_data = TEMPLATE_PARAMS["train_all_data"]
-     hyperparameters = TEMPLATE_PARAMS["hyperparameters"]
-     validation_split = 0.2
-
-     # Script arguments for input/output directories
-     parser = argparse.ArgumentParser()
-     parser.add_argument("--model-dir", type=str, default=os.environ.get("SM_MODEL_DIR", "/opt/ml/model"))
-     parser.add_argument("--train", type=str, default=os.environ.get("SM_CHANNEL_TRAIN", "/opt/ml/input/data/train"))
-     parser.add_argument(
-         "--output-data-dir", type=str, default=os.environ.get("SM_OUTPUT_DATA_DIR", "/opt/ml/output/data")
-     )
-     args = parser.parse_args()
-
-     # Read the training data into DataFrames
-     training_files = [
-         os.path.join(args.train, file)
-         for file in os.listdir(args.train)
-         if file.endswith(".csv")
-     ]
-     print(f"Training Files: {training_files}")
-
-     # Combine files and read them all into a single pandas dataframe
-     all_df = pd.concat([pd.read_csv(file, engine="python") for file in training_files])
-
-     # Check if the dataframe is empty
-     check_dataframe(all_df, "training_df")
-
-     # Features/Target output
-     print(f"Target: {target}")
-     print(f"Features: {str(features)}")
-
-     # Convert any features that might be categorical to 'category' type
-     all_df, category_mappings = convert_categorical_types(all_df, features)
-
-     # If we have compressed features, decompress them
-     if compressed_features:
-         print(f"Decompressing features {compressed_features}...")
-         all_df, features = decompress_features(all_df, features, compressed_features)
-
-     # Do we want to train on all the data?
-     if train_all_data:
-         print("Training on ALL of the data")
-         df_train = all_df.copy()
-         df_val = all_df.copy()
-
-     # Does the dataframe have a training column?
-     elif "training" in all_df.columns:
-         print("Found training column, splitting data based on training column")
-         df_train = all_df[all_df["training"]]
-         df_val = all_df[~all_df["training"]]
-     else:
-         # Just do a random training Split
-         print("WARNING: No training column found, splitting data with random state=42")
-         df_train, df_val = train_test_split(all_df, test_size=validation_split, random_state=42)
-     print(f"FIT/TRAIN: {df_train.shape}")
-     print(f"VALIDATION: {df_val.shape}")
-
-     # Determine categorical and continuous columns
-     categorical_cols = [col for col in features if df_train[col].dtype.name == "category"]
-     continuous_cols = [col for col in features if col not in categorical_cols]
-
-     print(f"Categorical columns: {categorical_cols}")
-     print(f"Continuous columns: {continuous_cols}")
-
-     # Set up PyTorch Tabular configuration
-     data_config = DataConfig(
-         target=[target],
-         continuous_cols=continuous_cols,
-         categorical_cols=categorical_cols,
-     )
-
-     # Choose the 'task' based on model type also set up the label encoder if needed
-     if model_type == "classifier":
-         task = "classification"
-         # Encode the target column
-         label_encoder = LabelEncoder()
-         df_train[target] = label_encoder.fit_transform(df_train[target])
-         df_val[target] = label_encoder.transform(df_val[target])
-     else:
-         task = "regression"
-         label_encoder = None
-
-     # Use any hyperparameters to set up both the trainer and model configurations
-     print(f"Hyperparameters: {hyperparameters}")
-
-     # Set up PyTorch Tabular configuration with defaults
-     trainer_defaults = {
-         "auto_lr_find": True,
-         "batch_size": min(1024, max(32, len(df_train) // 4)),
-         "max_epochs": 100,
-         "early_stopping": "valid_loss",
-         "early_stopping_patience": 15,
-         "checkpoints": "valid_loss",
-         "accelerator": "auto",
-         "progress_bar": "none",
-         "gradient_clip_val": 1.0,
-     }
-
-     # Override defaults with training_config if present
-     training_overrides = {k: v for k, v in hyperparameters.get('training_config', {}).items()
-                           if k in trainer_defaults}
-     # Print overwrites
-     for key, value in training_overrides.items():
-         print(f"TRAINING CONFIG Override: {key}: {trainer_defaults[key]} → {value}")
-     trainer_params = {**trainer_defaults, **training_overrides}
-     trainer_config = TrainerConfig(**trainer_params)
-
-     # Model config defaults
-     model_defaults = {
-         "layers": "1024-512-512",
-         "activation": "ReLU",
-         "learning_rate": 1e-3,
-         "dropout": 0.1,
-         "use_batch_norm": True,
-         "initialization": "kaiming",
-     }
-     # Override defaults with model_config if present
-     model_overrides = {k: v for k, v in hyperparameters.get('model_config', {}).items()
-                        if k in model_defaults}
-     # Print overwrites
-     for key, value in model_overrides.items():
-         print(f"MODEL CONFIG Override: {key}: {model_defaults[key]} → {value}")
-     model_params = {**model_defaults, **model_overrides}
-
-     # Use CategoryEmbedding model configuration for general-purpose tabular modeling.
-     # Works effectively for both regression and classification as the foundational
-     # architecture in PyTorch Tabular
-     model_config = CategoryEmbeddingModelConfig(
-         task=task,
-         **model_params
-     )
-     optimizer_config = OptimizerConfig()
-
-     #####################################
-     # Create and train the TabularModel #
-     #####################################
-     tabular_model = TabularModel(
-         data_config=data_config,
-         model_config=model_config,
-         optimizer_config=optimizer_config,
-         trainer_config=trainer_config,
-     )
-     tabular_model.fit(train=df_train, validation=df_val)
-
-     # Make Predictions on the Validation Set
-     print("Making Predictions on Validation Set...")
-     result = tabular_model.predict(df_val, include_input_features=False)
-
-     # pytorch-tabular returns predictions using f"{target}_prediction" column
-     # and classification probabilities in columns ending with "_probability"
-     if model_type == "classifier":
-         preds = result[f"{target}_prediction"].values
-     else:
-         # Regression: use the target column name
-         preds = result[f"{target}_prediction"].values
-
-     if model_type == "classifier":
-         # Get probabilities for classification
-         print("Processing Probabilities...")
-         prob_cols = [col for col in result.columns if col.endswith("_probability")]
-         if prob_cols:
-             probs = result[prob_cols].values
-             df_val["pred_proba"] = [p.tolist() for p in probs]
-
-             # Expand the pred_proba column into separate columns for each class
-             print(df_val.columns)
-             df_val = expand_proba_column(df_val, label_encoder.classes_)
-             print(df_val.columns)
-
-         # Decode the target and prediction labels
-         y_validate = label_encoder.inverse_transform(df_val[target])
-         preds = label_encoder.inverse_transform(preds.astype(int))
-     else:
-         y_validate = df_val[target].values
-
-     # Save predictions to S3 (just the target, prediction, and '_probability' columns)
-     df_val["prediction"] = preds
-     output_columns = [target, "prediction"]
-     output_columns += [col for col in df_val.columns if col.endswith("_probability")]
-     wr.s3.to_csv(
-         df_val[output_columns],
-         path=f"{model_metrics_s3_path}/validation_predictions.csv",
-         index=False,
-     )
-
-     # Report Performance Metrics
-     if model_type == "classifier":
-         # Get the label names and their integer mapping
-         label_names = label_encoder.classes_
-
-         # Calculate various model performance metrics
-         scores = precision_recall_fscore_support(y_validate, preds, average=None, labels=label_names)
-
-         # Put the scores into a dataframe
-         score_df = pd.DataFrame(
-             {
-                 target: label_names,
-                 "precision": scores[0],
-                 "recall": scores[1],
-                 "fscore": scores[2],
-                 "support": scores[3],
-             }
-         )
-
-         # We need to get creative with the Classification Metrics
-         metrics = ["precision", "recall", "fscore", "support"]
-         for t in label_names:
-             for m in metrics:
-                 value = score_df.loc[score_df[target] == t, m].iloc[0]
-                 print(f"Metrics:{t}:{m} {value}")
-
-         # Compute and output the confusion matrix
-         conf_mtx = confusion_matrix(y_validate, preds, labels=label_names)
-         for i, row_name in enumerate(label_names):
-             for j, col_name in enumerate(label_names):
-                 value = conf_mtx[i, j]
-                 print(f"ConfusionMatrix:{row_name}:{col_name} {value}")
-
-     else:
-         # Calculate various model performance metrics (regression)
-         rmse = root_mean_squared_error(y_validate, preds)
-         mae = mean_absolute_error(y_validate, preds)
-         r2 = r2_score(y_validate, preds)
-         print(f"RMSE: {rmse:.3f}")
-         print(f"MAE: {mae:.3f}")
-         print(f"R2: {r2:.3f}")
-         print(f"NumRows: {len(df_val)}")
-
-     # Save the model to the standard place/name
-     tabular_model.save_model(os.path.join(args.model_dir, "tabular_model"))
-     if label_encoder:
-         joblib.dump(label_encoder, os.path.join(args.model_dir, "label_encoder.joblib"))
-
-     # Save the features (this will validate input during predictions)
-     with open(os.path.join(args.model_dir, "feature_columns.json"), "w") as fp:
-         json.dump(orig_features, fp)  # We save the original features, not the decompressed ones
-
-     # Save the category mappings
-     with open(os.path.join(args.model_dir, "category_mappings.json"), "w") as fp:
-         json.dump(category_mappings, fp)