workbench-0.8.202-py3-none-any.whl → workbench-0.8.220-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of workbench has been flagged as potentially problematic.
- workbench/algorithms/dataframe/compound_dataset_overlap.py +321 -0
- workbench/algorithms/dataframe/feature_space_proximity.py +168 -75
- workbench/algorithms/dataframe/fingerprint_proximity.py +421 -85
- workbench/algorithms/dataframe/projection_2d.py +44 -21
- workbench/algorithms/dataframe/proximity.py +78 -150
- workbench/algorithms/graph/light/proximity_graph.py +5 -5
- workbench/algorithms/models/cleanlab_model.py +382 -0
- workbench/algorithms/models/noise_model.py +388 -0
- workbench/algorithms/sql/outliers.py +3 -3
- workbench/api/__init__.py +3 -0
- workbench/api/df_store.py +17 -108
- workbench/api/endpoint.py +13 -11
- workbench/api/feature_set.py +111 -8
- workbench/api/meta_model.py +289 -0
- workbench/api/model.py +45 -12
- workbench/api/parameter_store.py +3 -52
- workbench/cached/cached_model.py +4 -4
- workbench/core/artifacts/artifact.py +5 -5
- workbench/core/artifacts/df_store_core.py +114 -0
- workbench/core/artifacts/endpoint_core.py +228 -237
- workbench/core/artifacts/feature_set_core.py +185 -230
- workbench/core/artifacts/model_core.py +34 -26
- workbench/core/artifacts/parameter_store_core.py +98 -0
- workbench/core/pipelines/pipeline_executor.py +1 -1
- workbench/core/transforms/features_to_model/features_to_model.py +22 -10
- workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +41 -10
- workbench/core/transforms/pandas_transforms/pandas_to_features.py +11 -2
- workbench/model_script_utils/model_script_utils.py +339 -0
- workbench/model_script_utils/pytorch_utils.py +405 -0
- workbench/model_script_utils/uq_harness.py +278 -0
- workbench/model_scripts/chemprop/chemprop.template +428 -631
- workbench/model_scripts/chemprop/generated_model_script.py +432 -635
- workbench/model_scripts/chemprop/model_script_utils.py +339 -0
- workbench/model_scripts/chemprop/requirements.txt +2 -10
- workbench/model_scripts/custom_models/chem_info/fingerprints.py +87 -46
- workbench/model_scripts/custom_models/proximity/feature_space_proximity.py +194 -0
- workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +6 -6
- workbench/model_scripts/custom_models/uq_models/feature_space_proximity.py +194 -0
- workbench/model_scripts/meta_model/generated_model_script.py +209 -0
- workbench/model_scripts/meta_model/meta_model.template +209 -0
- workbench/model_scripts/pytorch_model/generated_model_script.py +374 -613
- workbench/model_scripts/pytorch_model/model_script_utils.py +339 -0
- workbench/model_scripts/pytorch_model/pytorch.template +370 -609
- workbench/model_scripts/pytorch_model/pytorch_utils.py +405 -0
- workbench/model_scripts/pytorch_model/requirements.txt +1 -1
- workbench/model_scripts/pytorch_model/uq_harness.py +278 -0
- workbench/model_scripts/script_generation.py +6 -5
- workbench/model_scripts/uq_models/generated_model_script.py +65 -422
- workbench/model_scripts/xgb_model/generated_model_script.py +372 -395
- workbench/model_scripts/xgb_model/model_script_utils.py +339 -0
- workbench/model_scripts/xgb_model/uq_harness.py +278 -0
- workbench/model_scripts/xgb_model/xgb_model.template +366 -396
- workbench/repl/workbench_shell.py +0 -5
- workbench/resources/open_source_api.key +1 -1
- workbench/scripts/endpoint_test.py +2 -2
- workbench/scripts/meta_model_sim.py +35 -0
- workbench/scripts/training_test.py +85 -0
- workbench/utils/chem_utils/fingerprints.py +87 -46
- workbench/utils/chem_utils/projections.py +16 -6
- workbench/utils/chemprop_utils.py +36 -655
- workbench/utils/meta_model_simulator.py +499 -0
- workbench/utils/metrics_utils.py +256 -0
- workbench/utils/model_utils.py +192 -54
- workbench/utils/pytorch_utils.py +33 -472
- workbench/utils/shap_utils.py +1 -55
- workbench/utils/xgboost_local_crossfold.py +267 -0
- workbench/utils/xgboost_model_utils.py +49 -356
- workbench/web_interface/components/model_plot.py +7 -1
- workbench/web_interface/components/plugins/model_details.py +30 -68
- workbench/web_interface/components/plugins/scatter_plot.py +4 -8
- {workbench-0.8.202.dist-info → workbench-0.8.220.dist-info}/METADATA +6 -5
- {workbench-0.8.202.dist-info → workbench-0.8.220.dist-info}/RECORD +76 -60
- {workbench-0.8.202.dist-info → workbench-0.8.220.dist-info}/entry_points.txt +2 -0
- workbench/core/cloud_platform/aws/aws_df_store.py +0 -404
- workbench/core/cloud_platform/aws/aws_parameter_store.py +0 -296
- workbench/model_scripts/custom_models/meta_endpoints/example.py +0 -53
- workbench/model_scripts/custom_models/proximity/proximity.py +0 -410
- workbench/model_scripts/custom_models/uq_models/meta_uq.template +0 -377
- workbench/model_scripts/custom_models/uq_models/proximity.py +0 -410
- workbench/model_scripts/uq_models/mapie.template +0 -605
- workbench/model_scripts/uq_models/requirements.txt +0 -1
- {workbench-0.8.202.dist-info → workbench-0.8.220.dist-info}/WHEEL +0 -0
- {workbench-0.8.202.dist-info → workbench-0.8.220.dist-info}/licenses/LICENSE +0 -0
- {workbench-0.8.202.dist-info → workbench-0.8.220.dist-info}/top_level.txt +0 -0
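
The listing above is produced by comparing the files packaged inside the two published wheels. As a minimal sketch (not part of Workbench), the same per-file comparison can be reproduced locally with only the Python standard library once both `.whl` files have been downloaded, e.g. via `pip download workbench==0.8.202 --no-deps`; the helper name `diff_wheel_member` and the exact wheel filenames below are illustrative assumptions.

```python
# Sketch: reproduce a single-file diff between two downloaded wheels (a wheel is a zip archive).
import difflib
import zipfile


def diff_wheel_member(old_whl: str, new_whl: str, member: str) -> str:
    """Return a unified diff for one file that exists in both wheels."""

    def read(whl: str) -> list[str]:
        with zipfile.ZipFile(whl) as zf:
            return zf.read(member).decode("utf-8").splitlines(keepends=True)

    return "".join(
        difflib.unified_diff(
            read(old_whl), read(new_whl),
            fromfile=f"{old_whl}:{member}", tofile=f"{new_whl}:{member}",
        )
    )


if __name__ == "__main__":
    print(
        diff_wheel_member(
            "workbench-0.8.202-py3-none-any.whl",
            "workbench-0.8.220-py3-none-any.whl",
            "workbench/model_scripts/pytorch_model/pytorch.template",
        )
    )
```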
@@ -1,417 +1,265 @@
-#
-
-
-
-
-#
-#
-
-
-
-from pytorch_tabular.models import CategoryEmbeddingModelConfig
-
-# Model Performance Scores
-from sklearn.metrics import (
-    mean_absolute_error,
-    median_absolute_error,
-    r2_score,
-    root_mean_squared_error,
-    precision_recall_fscore_support,
-    confusion_matrix,
-)
-from scipy.stats import spearmanr
+# PyTorch Tabular Model Template for Workbench
+#
+# This template handles both classification and regression models with:
+# - K-fold cross-validation ensemble training (or single train/val split)
+# - Out-of-fold predictions for validation metrics
+# - Categorical feature embedding via TabularMLP
+# - Compressed feature decompression
+#
+# NOTE: Imports are structured to minimize serverless endpoint startup time.
+# Heavy imports (sklearn, awswrangler) are deferred to training time.

-# Classification Encoder
-from sklearn.preprocessing import LabelEncoder
-
-# Scikit Learn Imports
-from sklearn.model_selection import train_test_split, KFold, StratifiedKFold
-
-from io import StringIO
 import json
-import
+import os
+
 import joblib
+import numpy as np
 import pandas as pd
+import torch
+
+from model_script_utils import (
+    convert_categorical_types,
+    decompress_features,
+    expand_proba_column,
+    input_fn,
+    match_features_case_insensitive,
+    output_fn,
+)
+from pytorch_utils import (
+    FeatureScaler,
+    load_model,
+    predict,
+    prepare_data,
+)
+from uq_harness import (
+    compute_confidence,
+    load_uq_models,
+    predict_intervals,
+)

-#
+# =============================================================================
+# Default Hyperparameters
+# =============================================================================
+DEFAULT_HYPERPARAMETERS = {
+    # Training parameters
+    "n_folds": 5,
+    "max_epochs": 200,
+    "early_stopping_patience": 30,
+    "batch_size": 128,
+    # Model architecture (larger capacity - ensemble provides regularization)
+    "layers": "512-256-128",
+    "learning_rate": 1e-3,
+    "dropout": 0.05,
+    "use_batch_norm": True,
+    # Loss function for regression (L1Loss=MAE, MSELoss=MSE, HuberLoss, SmoothL1Loss)
+    "loss": "L1Loss",
+    # Random seed
+    "seed": 42,
+}
+
+# Template parameters (filled in by Workbench)
 TEMPLATE_PARAMS = {
     "model_type": "uq_regressor",
     "target": "udm_asy_res_efflux_ratio",
-    "features": ['
"features": ['chi2v', 'fr_sulfone', 'chi1v', 'bcut2d_logplow', 'fr_piperzine', 'kappa3', 'smr_vsa1', 'slogp_vsa5', 'fr_ketone_topliss', 'fr_sulfonamd', 'fr_imine', 'fr_benzene', 'fr_ester', 'chi2n', 'labuteasa', 'peoe_vsa2', 'smr_vsa6', 'bcut2d_chglo', 'fr_sh', 'peoe_vsa1', 'fr_allylic_oxid', 'chi4n', 'fr_ar_oh', 'fr_nh0', 'fr_term_acetylene', 'slogp_vsa7', 'slogp_vsa4', 'estate_vsa1', 'vsa_estate4', 'numbridgeheadatoms', 'numheterocycles', 'fr_ketone', 'fr_morpholine', 'fr_guanido', 'estate_vsa2', 'numheteroatoms', 'fr_nitro_arom_nonortho', 'fr_piperdine', 'nocount', 'numspiroatoms', 'fr_aniline', 'fr_thiophene', 'slogp_vsa10', 'fr_amide', 'slogp_vsa2', 'fr_epoxide', 'vsa_estate7', 'fr_ar_coo', 'fr_imidazole', 'fr_nitrile', 'fr_oxazole', 'numsaturatedrings', 'fr_pyridine', 'fr_hoccn', 'fr_ndealkylation1', 'numaliphaticheterocycles', 'fr_phenol', 'maxpartialcharge', 'vsa_estate5', 'peoe_vsa13', 'minpartialcharge', 'qed', 'fr_al_oh', 'slogp_vsa11', 'chi0n', 'fr_bicyclic', 'peoe_vsa12', 'fpdensitymorgan1', 'fr_oxime', 'molwt', 'fr_dihydropyridine', 'smr_vsa5', 'peoe_vsa5', 'fr_nitro', 'hallkieralpha', 'heavyatommolwt', 'fr_alkyl_halide', 'peoe_vsa8', 'fr_nhpyrrole', 'fr_isocyan', 'bcut2d_chghi', 'fr_lactam', 'peoe_vsa11', 'smr_vsa9', 'tpsa', 'chi4v', 'slogp_vsa1', 'phi', 'bcut2d_logphi', 'avgipc', 'estate_vsa11', 'fr_coo', 'bcut2d_mwhi', 'numunspecifiedatomstereocenters', 'vsa_estate10', 'estate_vsa8', 'numvalenceelectrons', 'fr_nh2', 'fr_lactone', 'vsa_estate1', 'estate_vsa4', 'numatomstereocenters', 'vsa_estate8', 'fr_para_hydroxylation', 'peoe_vsa3', 'fr_thiazole', 'peoe_vsa10', 'fr_ndealkylation2', 'slogp_vsa12', 'peoe_vsa9', 'maxestateindex', 'fr_quatn', 'smr_vsa7', 'minestateindex', 'numaromaticheterocycles', 'numrotatablebonds', 'fr_ar_nh', 'fr_ether', 'exactmolwt', 'fr_phenol_noorthohbond', 'slogp_vsa3', 'fr_ar_n', 'sps', 'fr_c_o_nocoo', 'bertzct', 'peoe_vsa7', 'slogp_vsa8', 'numradicalelectrons', 'molmr', 'fr_tetrazole', 'numsaturatedcarbocycles', 'bcut2d_mrhi', 'kappa1', 'numamidebonds', 'fpdensitymorgan2', 'smr_vsa8', 'chi1n', 'estate_vsa6', 'fr_barbitur', 'fr_diazo', 'kappa2', 'chi0', 'bcut2d_mrlow', 'balabanj', 'peoe_vsa4', 'numhacceptors', 'fr_sulfide', 'chi3n', 'smr_vsa2', 'fr_al_oh_notert', 'fr_benzodiazepine', 'fr_phos_ester', 'fr_aldehyde', 'fr_coo2', 'estate_vsa5', 'fr_prisulfonamd', 'numaromaticcarbocycles', 'fr_unbrch_alkane', 'fr_urea', 'fr_nitroso', 'smr_vsa10', 'fr_c_s', 'smr_vsa3', 'fr_methoxy', 'maxabspartialcharge', 'slogp_vsa9', 'heavyatomcount', 'fr_azide', 'chi3v', 'smr_vsa4', 'mollogp', 'chi0v', 'fr_aryl_methyl', 'fr_nh1', 'fpdensitymorgan3', 'fr_furan', 'fr_hdrzine', 'fr_arn', 'numaromaticrings', 'vsa_estate3', 'fr_azo', 'fr_halogen', 'estate_vsa9', 'fr_hdrzone', 'numhdonors', 'fr_alkyl_carbamate', 'fr_isothiocyan', 'minabspartialcharge', 'fr_al_coo', 'ringcount', 'chi1', 'estate_vsa7', 'fr_nitro_arom', 'vsa_estate9', 'minabsestateindex', 'maxabsestateindex', 'vsa_estate6', 'estate_vsa10', 'estate_vsa3', 'fr_n_o', 'fr_amidine', 'fr_thiocyan', 'fr_phos_acid', 'fr_c_o', 'fr_imide', 'numaliphaticrings', 'peoe_vsa6', 'vsa_estate2', 'nhohcount', 'numsaturatedheterocycles', 'slogp_vsa6', 'peoe_vsa14', 'fractioncsp3', 'bcut2d_mwlow', 'numaliphaticcarbocycles', 'fr_priamide', 'nacid', 'nbase', 'naromatom', 'narombond', 'sz', 'sm', 'sv', 'sse', 'spe', 'sare', 'sp', 'si', 'mz', 'mm', 'mv', 'mse', 'mpe', 'mare', 'mp', 'mi', 'xch_3d', 'xch_4d', 'xch_5d', 'xch_6d', 'xch_7d', 'xch_3dv', 'xch_4dv', 'xch_5dv', 'xch_6dv', 'xch_7dv', 'xc_3d', 'xc_4d', 'xc_5d', 'xc_6d', 
'xc_3dv', 'xc_4dv', 'xc_5dv', 'xc_6dv', 'xpc_4d', 'xpc_5d', 'xpc_6d', 'xpc_4dv', 'xpc_5dv', 'xpc_6dv', 'xp_0d', 'xp_1d', 'xp_2d', 'xp_3d', 'xp_4d', 'xp_5d', 'xp_6d', 'xp_7d', 'axp_0d', 'axp_1d', 'axp_2d', 'axp_3d', 'axp_4d', 'axp_5d', 'axp_6d', 'axp_7d', 'xp_0dv', 'xp_1dv', 'xp_2dv', 'xp_3dv', 'xp_4dv', 'xp_5dv', 'xp_6dv', 'xp_7dv', 'axp_0dv', 'axp_1dv', 'axp_2dv', 'axp_3dv', 'axp_4dv', 'axp_5dv', 'axp_6dv', 'axp_7dv', 'c1sp1', 'c2sp1', 'c1sp2', 'c2sp2', 'c3sp2', 'c1sp3', 'c2sp3', 'c3sp3', 'c4sp3', 'hybratio', 'fcsp3', 'num_stereocenters', 'num_unspecified_stereocenters', 'num_defined_stereocenters', 'num_r_centers', 'num_s_centers', 'num_stereobonds', 'num_e_bonds', 'num_z_bonds', 'stereo_complexity', 'frac_defined_stereo'],
     "id_column": "udm_mol_bat_id",
-    "compressed_features": [],
-    "model_metrics_s3_path": "s3://ideaya-sageworks-bucket/models/caco2-er-reg-
-    "hyperparameters": {'n_folds':
+    "compressed_features": ['fingerprint'],
+    "model_metrics_s3_path": "s3://ideaya-sageworks-bucket/models/caco2-er-reg-temporal/training",
+    "hyperparameters": {'n_folds': 1},
 }


-
-
-
-
-    Args:
-        df (pd.DataFrame): DataFrame to check
-        df_name (str): Name of the DataFrame
-    """
-    if df.empty:
-        msg = f"*** The training data {df_name} has 0 rows! ***STOPPING***"
-        print(msg)
-        raise ValueError(msg)
-
-
-def expand_proba_column(df: pd.DataFrame, class_labels: list[str]) -> pd.DataFrame:
-    """
-    Expands a column in a DataFrame containing a list of probabilities into separate columns.
-
-    Args:
-        df (pd.DataFrame): DataFrame containing a "pred_proba" column
-        class_labels (list[str]): List of class labels
-
-    Returns:
-        pd.DataFrame: DataFrame with the "pred_proba" expanded into separate columns
-    """
-    proba_column = "pred_proba"
-    if proba_column not in df.columns:
-        raise ValueError('DataFrame does not contain a "pred_proba" column')
-
-    # Construct new column names with '_proba' suffix
-    proba_splits = [f"{label}_proba" for label in class_labels]
-
-    # Expand the proba_column into separate columns for each probability
-    proba_df = pd.DataFrame(df[proba_column].tolist(), columns=proba_splits)
-
-    # Drop any proba columns and reset the index in prep for the concat
-    df = df.drop(columns=[proba_column] + proba_splits, errors="ignore")
-    df = df.reset_index(drop=True)
-
-    # Concatenate the new columns with the original DataFrame
-    df = pd.concat([df, proba_df], axis=1)
-    return df
-
-
-def match_features_case_insensitive(df: pd.DataFrame, model_features: list[str]) -> pd.DataFrame:
-    """
-    Matches and renames DataFrame columns to match model feature names (case-insensitive).
-    Prioritizes exact matches, then case-insensitive matches.
-
-    Raises ValueError if any model features cannot be matched.
-    """
-    df_columns_lower = {col.lower(): col for col in df.columns}
-    rename_dict = {}
-    missing = []
-    for feature in model_features:
-        if feature in df.columns:
-            continue  # Exact match
-        elif feature.lower() in df_columns_lower:
-            rename_dict[df_columns_lower[feature.lower()]] = feature
-        else:
-            missing.append(feature)
-
-    if missing:
-        raise ValueError(f"Features not found: {missing}")
-
-    # Rename the DataFrame columns to match the model features
-    return df.rename(columns=rename_dict)
-
-
-def convert_categorical_types(
-    df: pd.DataFrame, features: list[str], category_mappings: dict[str, list[str]] | None = None
-) -> tuple[pd.DataFrame, dict[str, list[str]]]:
-    """
-    Converts appropriate columns to categorical type with consistent mappings.
-
-    Args:
-        df (pd.DataFrame): The DataFrame to process.
-        features (list): List of feature names to consider for conversion.
-        category_mappings (dict, optional): Existing category mappings. If None or empty,
-                                            we're in training mode. If populated, we're in
-                                            inference mode.
-
-    Returns:
-        tuple: (processed DataFrame, category mappings dictionary)
-    """
-    if category_mappings is None:
-        category_mappings = {}
-
-    # Training mode
-    if not category_mappings:
-        for col in df.select_dtypes(include=["object", "string"]):
-            if col in features and df[col].nunique() < 20:
-                print(f"Training mode: Converting {col} to category")
-                df[col] = df[col].astype("category")
-                category_mappings[col] = df[col].cat.categories.tolist()
-
-    # Inference mode
-    else:
-        for col, categories in category_mappings.items():
-            if col in df.columns:
-                print(f"Inference mode: Applying categorical mapping for {col}")
-                df[col] = pd.Categorical(df[col], categories=categories)
-
-    return df, category_mappings
-
-
-def decompress_features(
-    df: pd.DataFrame, features: list[str], compressed_features: list[str]
-) -> tuple[pd.DataFrame, list[str]]:
-    """Prepare features for the model
-
-    Args:
-        df (pd.DataFrame): The features DataFrame
-        features (list[str]): Full list of feature names
-        compressed_features (list[str]): List of feature names to decompress (bitstrings)
-
-    Returns:
-        pd.DataFrame: DataFrame with the decompressed features
-        list[str]: Updated list of feature names after decompression
-
-    Raises:
-        ValueError: If any missing values are found in the specified features
-    """
-    # Check for any missing values in the required features
-    missing_counts = df[features].isna().sum()
-    if missing_counts.any():
-        missing_features = missing_counts[missing_counts > 0]
-        print(
-            f"WARNING: Found missing values in features: {missing_features.to_dict()}. "
-            "WARNING: You might want to remove/replace all NaN values before processing."
-        )
-
-    # Make a copy to avoid mutating the original list
-    decompressed_features = features.copy()
-
-    for feature in compressed_features:
-        if (feature not in df.columns) or (feature not in decompressed_features):
-            print(f"Feature '{feature}' not in the features list, skipping decompression.")
-            continue
-
-        # Remove the feature from the list of features to avoid duplication
-        decompressed_features.remove(feature)
-
-        # Handle all compressed features as bitstrings
-        bit_matrix = np.array([list(bitstring) for bitstring in df[feature]], dtype=np.uint8)
-        prefix = feature[:3]
-
-        # Create all new columns at once - avoids fragmentation
-        new_col_names = [f"{prefix}_{i}" for i in range(bit_matrix.shape[1])]
-        new_df = pd.DataFrame(bit_matrix, columns=new_col_names, index=df.index)
-
-        # Add to features list
-        decompressed_features.extend(new_col_names)
-
-        # Drop original column and concatenate new ones
-        df = df.drop(columns=[feature])
-        df = pd.concat([df, new_df], axis=1)
-
-    return df, decompressed_features
-
-
+# =============================================================================
+# Model Loading (for SageMaker inference)
+# =============================================================================
 def model_fn(model_dir: str) -> dict:
-    """Load
-
-
-
-
-
-        Dictionary with ensemble models and metadata
-    """
-    import torch
-    from functools import partial
-
-    # Load ensemble metadata if present
-    ensemble_metadata_path = os.path.join(model_dir, "ensemble_metadata.joblib")
-    if os.path.exists(ensemble_metadata_path):
-        ensemble_metadata = joblib.load(ensemble_metadata_path)
-        n_ensemble = ensemble_metadata["n_ensemble"]
+    """Load PyTorch TabularMLP ensemble from the specified directory."""
+    # Load ensemble metadata
+    metadata_path = os.path.join(model_dir, "ensemble_metadata.joblib")
+    if os.path.exists(metadata_path):
+        metadata = joblib.load(metadata_path)
+        n_ensemble = metadata["n_ensemble"]
     else:
         n_ensemble = 1

-    # Determine
-
-
-    # Patch torch.load globally to use map_location (needed for joblib-loaded callbacks)
-    # This handles the case where pytorch-tabular loads callbacks.sav via joblib,
-    # which internally calls torch.load without map_location
-    original_torch_load = torch.load
-    torch.load = partial(original_torch_load, map_location=map_location)
+    # Determine device
+    device = "cuda" if torch.cuda.is_available() else "cpu"

-    #
-    original_cwd = os.getcwd()
+    # Load ensemble models
     ensemble_models = []
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-    if not input_data:
-        raise ValueError("Empty input data is not supported!")
-
-    # Decode bytes to string if necessary
-    if isinstance(input_data, bytes):
-        input_data = input_data.decode("utf-8")
-
-    if "text/csv" in content_type:
-        return pd.read_csv(StringIO(input_data))
-    elif "application/json" in content_type:
-        return pd.DataFrame(json.loads(input_data))  # Assumes JSON array of records
-    else:
-        raise ValueError(f"{content_type} not supported!")
-
-
-def output_fn(output_df: pd.DataFrame, accept_type: str) -> tuple[str, str]:
-    """Supports both CSV and JSON output formats."""
-    if "text/csv" in accept_type:
-        csv_output = output_df.fillna("N/A").to_csv(index=False)
-        return csv_output, "text/csv"
-    elif "application/json" in accept_type:
-        return output_df.to_json(orient="records"), "application/json"
-    else:
-        raise RuntimeError(f"{accept_type} accept type is not supported by this script.")
+    for i in range(n_ensemble):
+        model_path = os.path.join(model_dir, f"model_{i}")
+        model = load_model(model_path, device=device)
+        ensemble_models.append(model)
+
+    print(f"Loaded {len(ensemble_models)} model(s)")
+
+    # Load feature scaler
+    scaler = FeatureScaler.load(os.path.join(model_dir, "scaler.joblib"))
+
+    # Load UQ models (regression only)
+    uq_models, uq_metadata = None, None
+    uq_path = os.path.join(model_dir, "uq_metadata.json")
+    if os.path.exists(uq_path):
+        uq_models, uq_metadata = load_uq_models(model_dir)
+
+    return {
+        "ensemble_models": ensemble_models,
+        "n_ensemble": n_ensemble,
+        "scaler": scaler,
+        "uq_models": uq_models,
+        "uq_metadata": uq_metadata,
+    }


+# =============================================================================
+# Inference (for SageMaker inference)
+# =============================================================================
 def predict_fn(df: pd.DataFrame, model_dict: dict) -> pd.DataFrame:
-    """Make
-
-    Args:
-        df (pd.DataFrame): The input DataFrame
-        model_dict: Dictionary containing ensemble models and metadata
-
-    Returns:
-        pd.DataFrame: The DataFrame with predictions (and prediction_std for ensembles)
-    """
+    """Make predictions with PyTorch TabularMLP ensemble."""
     model_type = TEMPLATE_PARAMS["model_type"]
     compressed_features = TEMPLATE_PARAMS["compressed_features"]
+    model_dir = os.environ.get("SM_MODEL_DIR", "/opt/ml/model")

-    #
+    # Load artifacts
     ensemble_models = model_dict["ensemble_models"]
-
+    scaler = model_dict["scaler"]
+    uq_models = model_dict.get("uq_models")
+    uq_metadata = model_dict.get("uq_metadata")

-
-
-    with open(os.path.join(model_dir, "
-
-
+    with open(os.path.join(model_dir, "feature_columns.json")) as f:
+        features = json.load(f)
+    with open(os.path.join(model_dir, "category_mappings.json")) as f:
+        category_mappings = json.load(f)
+    with open(os.path.join(model_dir, "feature_metadata.json")) as f:
+        feature_metadata = json.load(f)

-
-
-        category_mappings = json.load(fp)
+    continuous_cols = feature_metadata["continuous_cols"]
+    categorical_cols = feature_metadata["categorical_cols"]

-    # Load our Label Encoder if we have one
     label_encoder = None
-
-    if os.path.exists(
-        label_encoder = joblib.load(
+    encoder_path = os.path.join(model_dir, "label_encoder.joblib")
+    if os.path.exists(encoder_path):
+        label_encoder = joblib.load(encoder_path)

-
-    matched_df = match_features_case_insensitive(df, features)
+    print(f"Model Features: {features}")

-    #
+    # Prepare features
+    matched_df = match_features_case_insensitive(df, features)
     matched_df, _ = convert_categorical_types(matched_df, features, category_mappings)

-    # If we have compressed features, decompress them
     if compressed_features:
         print("Decompressing features for prediction...")
         matched_df, features = decompress_features(matched_df, features, compressed_features)

-    # Track
+    # Track missing features
     missing_mask = matched_df[features].isna().any(axis=1)
     if missing_mask.any():
-        print(f"Warning: {missing_mask.sum()} rows have missing features
+        print(f"Warning: {missing_mask.sum()} rows have missing features")

-    # Initialize
+    # Initialize output columns
     df["prediction"] = np.nan
     if model_type in ["regressor", "uq_regressor"]:
         df["prediction_std"] = np.nan

-
-    complete_df = matched_df[~missing_mask]
+    complete_df = matched_df[~missing_mask].copy()
     if len(complete_df) == 0:
         print("Warning: No complete rows to predict on")
         return df

-    #
-
-
-
-    # Collect predictions from all ensemble members
-    all_ensemble_preds = []
-    all_ensemble_probs = []
-
-    for ens_idx, ens_model in enumerate(ensemble_models):
-        result = ens_model.predict(complete_df[features])
-
-        if prediction_column in result.columns:
-            ens_preds = result[prediction_column].values
-        else:
-            raise ValueError(f"Cannot find prediction column in: {result.columns.tolist()}")
-
-        all_ensemble_preds.append(ens_preds)
+    # Prepare data for inference (with standardization)
+    x_cont, x_cat, _, _, _ = prepare_data(
+        complete_df, continuous_cols, categorical_cols, category_mappings=category_mappings, scaler=scaler
+    )

-
-
-
-
-
+    # Collect ensemble predictions
+    all_preds = []
+    for model in ensemble_models:
+        preds = predict(model, x_cont, x_cat)
+        all_preds.append(preds)

-    #
-    ensemble_preds = np.stack(
+    # Aggregate predictions
+    ensemble_preds = np.stack(all_preds, axis=0)
     preds = np.mean(ensemble_preds, axis=0)
-    preds_std = np.std(ensemble_preds, axis=0)
+    preds_std = np.std(ensemble_preds, axis=0)

-    print(f"Inference
+    print(f"Inference complete: {len(preds)} predictions, {len(ensemble_models)} ensemble members")

-    # Handle classification vs regression
     if label_encoder is not None:
-        #
-
-
-
-
-
-
-
-        all_proba.loc[~missing_mask] = [p.tolist() for p in avg_probs]
-        df["pred_proba"] = all_proba
-
-        # Expand the pred_proba column into separate columns for each class
-        df = expand_proba_column(df, label_encoder.classes_)
-        else:
-            # No probabilities, use averaged predictions
-            predictions = label_encoder.inverse_transform(preds.astype(int))
+        # Classification: average probabilities, then argmax
+        avg_probs = preds  # Already softmax output
+        class_preds = np.argmax(avg_probs, axis=1)
+        predictions = label_encoder.inverse_transform(class_preds)
+
+        all_proba = pd.Series([None] * len(df), index=df.index, dtype=object)
+        all_proba.loc[~missing_mask] = [p.tolist() for p in avg_probs]
+        df["pred_proba"] = all_proba
+        df = expand_proba_column(df, label_encoder.classes_)
     else:
-        # Regression
-        predictions = preds
-        df.loc[~missing_mask, "prediction_std"] = preds_std
+        # Regression
+        predictions = preds.flatten()
+        df.loc[~missing_mask, "prediction_std"] = preds_std.flatten()
+
+        # Add UQ intervals if available
+        if uq_models and uq_metadata:
+            X_complete = complete_df[features]
+            df_complete = df.loc[~missing_mask].copy()
+            df_complete["prediction"] = predictions  # Set prediction before compute_confidence
+            df_complete = predict_intervals(df_complete, X_complete, uq_models, uq_metadata)
+            df_complete = compute_confidence(df_complete, uq_metadata["median_interval_width"], "q_10", "q_90")
+            # Copy UQ columns back to main dataframe
+            for col in df_complete.columns:
+                if col.startswith("q_") or col == "confidence":
+                    df.loc[~missing_mask, col] = df_complete[col].values

-    # Set predictions only for complete rows
     df.loc[~missing_mask, "prediction"] = predictions
-
     return df


+# =============================================================================
+# Training
+# =============================================================================
 if __name__ == "__main__":
-
+    # -------------------------------------------------------------------------
+    # Training-only imports (deferred to reduce serverless startup time)
+    # -------------------------------------------------------------------------
+    import argparse
+
+    import awswrangler as wr
+    from sklearn.model_selection import KFold, StratifiedKFold, train_test_split
+    from sklearn.preprocessing import LabelEncoder
+
+    # Enable Tensor Core optimization for GPUs that support it
+    torch.set_float32_matmul_precision("medium")
+
+    from model_script_utils import (
+        check_dataframe,
+        compute_classification_metrics,
+        compute_regression_metrics,
+        print_classification_metrics,
+        print_confusion_matrix,
+        print_regression_metrics,
+    )
+    from pytorch_utils import (
+        create_model,
+        save_model,
+        train_model,
+    )
+    from uq_harness import (
+        save_uq_models,
+        train_uq_models,
+    )
+
+    # -------------------------------------------------------------------------
+    # Setup: Parse arguments and load data
+    # -------------------------------------------------------------------------
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model-dir", type=str, default=os.environ.get("SM_MODEL_DIR", "/opt/ml/model"))
+    parser.add_argument("--train", type=str, default=os.environ.get("SM_CHANNEL_TRAIN", "/opt/ml/input/data/train"))
+    parser.add_argument("--output-data-dir", type=str, default=os.environ.get("SM_OUTPUT_DATA_DIR", "/opt/ml/output/data"))
+    args = parser.parse_args()

-    #
+    # Extract template parameters
     target = TEMPLATE_PARAMS["target"]
     features = TEMPLATE_PARAMS["features"]
     orig_features = features.copy()
@@ -419,341 +267,254 @@ if __name__ == "__main__":
     compressed_features = TEMPLATE_PARAMS["compressed_features"]
     model_type = TEMPLATE_PARAMS["model_type"]
     model_metrics_s3_path = TEMPLATE_PARAMS["model_metrics_s3_path"]
-    hyperparameters = TEMPLATE_PARAMS["hyperparameters"]
+    hyperparameters = {**DEFAULT_HYPERPARAMETERS, **(TEMPLATE_PARAMS["hyperparameters"] or {})}

-    #
-
-    parser.add_argument("--model-dir", type=str, default=os.environ.get("SM_MODEL_DIR", "/opt/ml/model"))
-    parser.add_argument("--train", type=str, default=os.environ.get("SM_CHANNEL_TRAIN", "/opt/ml/input/data/train"))
-    parser.add_argument(
-        "--output-data-dir", type=str, default=os.environ.get("SM_OUTPUT_DATA_DIR", "/opt/ml/output/data")
-    )
-    args = parser.parse_args()
-
-    # Read the training data into DataFrames
-    training_files = [os.path.join(args.train, file) for file in os.listdir(args.train) if file.endswith(".csv")]
+    # Load training data
+    training_files = [os.path.join(args.train, f) for f in os.listdir(args.train) if f.endswith(".csv")]
     print(f"Training Files: {training_files}")
-
-    # Combine files and read them all into a single pandas dataframe
-    all_df = pd.concat([pd.read_csv(file, engine="python") for file in training_files])
-
-    # Print out some info about the dataframe
-    print(f"All Data Shape: {all_df.shape}")
-    print(f"Feature dtypes:\n{all_df[features].dtypes.value_counts()}")
-    print(f"Int64 columns: {all_df[features].select_dtypes(include=['int64']).columns.tolist()}")
-
-    # Check if the dataframe is empty
+    all_df = pd.concat([pd.read_csv(f, engine="python") for f in training_files])
     check_dataframe(all_df, "training_df")

-    # Drop
-
+    # Drop rows with missing features
+    initial_count = len(all_df)
     all_df = all_df.dropna(subset=features)
-
-
-        print(f"Dropped {dropped_rows} rows due to missing feature values.")
+    if len(all_df) < initial_count:
+        print(f"Dropped {initial_count - len(all_df)} rows with missing features")

-    # Features/Target output
     print(f"Target: {target}")
-    print(f"Features: {
+    print(f"Features: {features}")
+    print(f"Hyperparameters: {hyperparameters}")

-    #
+    # -------------------------------------------------------------------------
+    # Preprocessing
+    # -------------------------------------------------------------------------
     all_df, category_mappings = convert_categorical_types(all_df, features)

-    # Print out some info about the dataframe
-    print(f"All Data Shape: {all_df.shape}")
-    print(f"Feature dtypes:\n{all_df[features].dtypes.value_counts()}")
-    print(f"Int64 columns: {all_df[features].select_dtypes(include=['int64']).columns.tolist()}")
-
-    # If we have compressed features, decompress them
     if compressed_features:
-        print(f"Decompressing features {compressed_features}
+        print(f"Decompressing features: {compressed_features}")
         all_df, features = decompress_features(all_df, features, compressed_features)

-    # Determine categorical
-    categorical_cols = [
-    continuous_cols = [
-    print(f"Categorical columns: {categorical_cols}")
-    print(f"Continuous columns: {continuous_cols}")
-
-    # Cast continuous columns to float
+    # Determine categorical vs continuous columns
+    categorical_cols = [c for c in features if all_df[c].dtype.name == "category"]
+    continuous_cols = [c for c in features if c not in categorical_cols]
     all_df[continuous_cols] = all_df[continuous_cols].astype("float64")
+    print(f"Categorical: {categorical_cols}")
+    print(f"Continuous: {len(continuous_cols)} columns")

-    #
+    # -------------------------------------------------------------------------
+    # Classification setup
+    # -------------------------------------------------------------------------
+    label_encoder = None
+    n_outputs = 1
     if model_type == "classifier":
-        task = "classification"
-        # Encode the target column on full dataset for consistent encoding
         label_encoder = LabelEncoder()
         all_df[target] = label_encoder.fit_transform(all_df[target])
-
-
-        task = "regression"
-        label_encoder = None
-        num_classes = None
+        n_outputs = len(label_encoder.classes_)
+        print(f"Class labels: {label_encoder.classes_.tolist()}")

-    #
-
-
+    # -------------------------------------------------------------------------
+    # Cross-validation setup
+    # -------------------------------------------------------------------------
+    n_folds = hyperparameters["n_folds"]
+    task = "classification" if model_type == "classifier" else "regression"
+    hidden_layers = [int(x) for x in hyperparameters["layers"].split("-")]

-    #
-
-    # =========================================================================
-    print(f"Training {'single model' if n_folds == 1 else f'{n_folds}-fold cross-validation ensemble'}...")
+    # Get categorical cardinalities
+    categorical_cardinalities = [len(category_mappings.get(col, {})) for col in categorical_cols]

-    # Create fold splits
     if n_folds == 1:
-        # Single fold: use train/val split from "training" column or random split
         if "training" in all_df.columns:
-            print("
+            print("Using 'training' column for train/val split")
             train_idx = np.where(all_df["training"])[0]
             val_idx = np.where(~all_df["training"])[0]
         else:
-            print("WARNING: No training column found,
-
-            train_idx, val_idx = train_test_split(indices, test_size=0.2, random_state=42)
+            print("WARNING: No 'training' column found, using random 80/20 split")
+            train_idx, val_idx = train_test_split(np.arange(len(all_df)), test_size=0.2, random_state=42)
         folds = [(train_idx, val_idx)]
     else:
-        # K-Fold CV
         if model_type == "classifier":
             kfold = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=42)
-
+            folds = list(kfold.split(all_df, all_df[target]))
         else:
             kfold = KFold(n_splits=n_folds, shuffle=True, random_state=42)
-
-            folds = list(kfold.split(all_df, split_target))
-
-    # Initialize storage for out-of-fold predictions
-    oof_predictions = np.full(len(all_df), np.nan, dtype=np.float64)
-    if model_type == "classifier" and num_classes and num_classes > 1:
-        oof_proba = np.full((len(all_df), num_classes), np.nan, dtype=np.float64)
-    else:
-        oof_proba = None
+            folds = list(kfold.split(all_df))

-
+    print(f"Training {'single model' if n_folds == 1 else f'{n_folds}-fold ensemble'}...")

-    #
-
-
-
-        categorical_cols=categorical_cols,
-    )
+    # Fit scaler on all training data (used across all folds)
+    scaler = FeatureScaler()
+    scaler.fit(all_df, continuous_cols)
+    print(f"Fitted scaler on {len(continuous_cols)} continuous features")

-    #
-
-
-        "activation": "LeakyReLU",
-        "learning_rate": 1e-3,
-        "dropout": 0.1,
-        "use_batch_norm": True,
-        "initialization": "kaiming",
-    }
-    # Override defaults with model_config if present
-    model_overrides = {k: v for k, v in hyperparameters.get("model_config", {}).items() if k in model_defaults}
-    for key, value in model_overrides.items():
-        print(f"MODEL CONFIG Override: {key}: {model_defaults[key]} → {value}")
-    model_params = {**model_defaults, **model_overrides}
+    # Determine device
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    print(f"Using device: {device}")

-
-
+    # -------------------------------------------------------------------------
+    # Training loop
+    # -------------------------------------------------------------------------
+    oof_predictions = np.full((len(all_df), n_outputs), np.nan, dtype=np.float64)

+    ensemble_models = []
     for fold_idx, (train_idx, val_idx) in enumerate(folds):
         print(f"\n{'='*50}")
-        print(f"
+        print(f"Fold {fold_idx + 1}/{len(folds)} - Train: {len(train_idx)}, Val: {len(val_idx)}")
         print(f"{'='*50}")

-        # Split data for this fold
         df_train = all_df.iloc[train_idx].reset_index(drop=True)
         df_val = all_df.iloc[val_idx].reset_index(drop=True)

-
-
-
-
-
-
-            batch_size += 1  # Adjust to avoid last batch of size 1
-        trainer_defaults = {
-            "auto_lr_find": False,
-            "batch_size": batch_size,
-            "max_epochs": 200,
-            "min_epochs": 10,
-            "early_stopping": "valid_loss",
-            "early_stopping_patience": 20,
-            "checkpoints": "valid_loss",
-            "accelerator": "auto",
-            "progress_bar": "none",
-            "gradient_clip_val": 1.0,
-            "seed": 42 + fold_idx,
-        }
-
-        # Override defaults with training_config if present
-        training_overrides = {k: v for k, v in hyperparameters.get("training_config", {}).items() if k in trainer_defaults}
-        if fold_idx == 0:  # Only print overrides once
-            for key, value in training_overrides.items():
-                print(f"TRAINING CONFIG Override: {key}: {trainer_defaults[key]} → {value}")
-        trainer_params = {**trainer_defaults, **training_overrides}
-        trainer_config = TrainerConfig(**trainer_params)
-
-        # Create and train the TabularModel for this fold
-        tabular_model = TabularModel(
-            data_config=data_config,
-            model_config=model_config,
-            optimizer_config=optimizer_config,
-            trainer_config=trainer_config,
+        # Prepare data (using pre-fitted scaler)
+        train_x_cont, train_x_cat, train_y, _, _ = prepare_data(
+            df_train, continuous_cols, categorical_cols, target, category_mappings, scaler=scaler
+        )
+        val_x_cont, val_x_cat, val_y, _, _ = prepare_data(
+            df_val, continuous_cols, categorical_cols, target, category_mappings, scaler=scaler
         )
-        tabular_model.fit(train=df_train, validation=df_val)
-        ensemble_models.append(tabular_model)

-        #
-
-
+        # Create model
+        torch.manual_seed(hyperparameters["seed"] + fold_idx)
+        model = create_model(
+            n_continuous=len(continuous_cols),
+            categorical_cardinalities=categorical_cardinalities,
+            hidden_layers=hidden_layers,
+            n_outputs=n_outputs,
+            task=task,
+            dropout=hyperparameters["dropout"],
+            use_batch_norm=hyperparameters["use_batch_norm"],
+        )

-        #
-
-
-
-
-
-
-
+        # Train
+        model, history = train_model(
+            model,
+            train_x_cont, train_x_cat, train_y,
+            val_x_cont, val_x_cat, val_y,
+            task=task,
+            max_epochs=hyperparameters["max_epochs"],
+            patience=hyperparameters["early_stopping_patience"],
+            batch_size=hyperparameters["batch_size"],
+            learning_rate=hyperparameters["learning_rate"],
+            loss=hyperparameters.get("loss", "L1Loss"),
+            device=device,
+        )
+        ensemble_models.append(model)

-
+        # Out-of-fold predictions
+        fold_preds = predict(model, val_x_cont, val_x_cat)
+        oof_predictions[val_idx] = fold_preds

     print(f"\nTraining complete! Trained {len(ensemble_models)} model(s).")

-    #
-    #
+    # -------------------------------------------------------------------------
+    # Prepare validation results
+    # -------------------------------------------------------------------------
     if n_folds == 1:
-        val_mask = ~np.isnan(oof_predictions)
-        preds = oof_predictions[val_mask]
+        val_mask = ~np.isnan(oof_predictions[:, 0])
         df_val = all_df[val_mask].copy()
-
-            oof_proba = oof_proba[val_mask]
+        predictions = oof_predictions[val_mask]
     else:
-        preds = oof_predictions
         df_val = all_df.copy()
+        predictions = oof_predictions
+
+    # Decode labels for classification
+    if model_type == "classifier":
+        class_preds = np.argmax(predictions, axis=1)
+        df_val[target] = label_encoder.inverse_transform(df_val[target].astype(int))
+        df_val["prediction"] = label_encoder.inverse_transform(class_preds)
+        df_val["pred_proba"] = [p.tolist() for p in predictions]
+        df_val = expand_proba_column(df_val, label_encoder.classes_)
+    else:
+        df_val["prediction"] = predictions.flatten()

-    #
-    #
-
-
-
-        all_ensemble_preds_for_std = []
-        for ens_model in ensemble_models:
-            result = ens_model.predict(df_val[features], include_input_features=False)
-            ens_preds = result[f"{target}_prediction"].values.flatten()
-            all_ensemble_preds_for_std.append(ens_preds)
-
-        ensemble_preds_stacked = np.stack(all_ensemble_preds_for_std, axis=0)
-        preds_std = np.std(ensemble_preds_stacked, axis=0)
-        print(f"Ensemble prediction_std - mean: {np.mean(preds_std):.4f}, max: {np.max(preds_std):.4f}")
+    # -------------------------------------------------------------------------
+    # Compute and print metrics
+    # -------------------------------------------------------------------------
+    y_true = df_val[target].values
+    y_pred = df_val["prediction"].values

     if model_type == "classifier":
-
-
-
-        df_val["pred_proba"] = [p.tolist() for p in oof_proba]
-        df_val = expand_proba_column(df_val, label_encoder.classes_)
-
-        # Decode the target and prediction labels
-        y_validate = label_encoder.inverse_transform(df_val[target])
-        preds_decoded = label_encoder.inverse_transform(preds.astype(int))
+        score_df = compute_classification_metrics(y_true, y_pred, label_encoder.classes_, target)
+        print_classification_metrics(score_df, target, label_encoder.classes_)
+        print_confusion_matrix(y_true, y_pred, label_encoder.classes_)
     else:
-
-
+        metrics = compute_regression_metrics(y_true, y_pred)
+        print_regression_metrics(metrics)
+
+        # Compute ensemble prediction_std
+        if n_folds > 1:
+            # Re-run inference with all models to get std
+            x_cont, x_cat, _, _, _ = prepare_data(
+                df_val, continuous_cols, categorical_cols, category_mappings=category_mappings, scaler=scaler
+            )
+            all_preds = [predict(m, x_cont, x_cat).flatten() for m in ensemble_models]
+            df_val["prediction_std"] = np.std(np.stack(all_preds), axis=0)
+            print(f"Ensemble std - mean: {df_val['prediction_std'].mean():.4f}, max: {df_val['prediction_std'].max():.4f}")
+        else:
+            df_val["prediction_std"] = 0.0

-
-
-
+        # Train UQ models for uncertainty quantification
+        print("\n" + "=" * 50)
+        print("Training UQ Models")
+        print("=" * 50)
+        uq_models, uq_metadata = train_uq_models(
+            all_df[features], all_df[target], df_val[features], y_true
+        )
+        df_val = predict_intervals(df_val, df_val[features], uq_models, uq_metadata)
+        df_val = compute_confidence(df_val, uq_metadata["median_interval_width"])

-    #
+    # -------------------------------------------------------------------------
+    # Save validation predictions to S3
+    # -------------------------------------------------------------------------
     output_columns = []
     if id_column in df_val.columns:
         output_columns.append(id_column)
     output_columns += [target, "prediction"]

-
-    if model_type in ["regressor", "uq_regressor"]:
-        if preds_std is not None:
-            df_val["prediction_std"] = preds_std
-        else:
-            df_val["prediction_std"] = 0.0
+    if model_type != "classifier":
         output_columns.append("prediction_std")
-
+        output_columns += [c for c in df_val.columns if c.startswith("q_") or c == "confidence"]
+
+    output_columns += [c for c in df_val.columns if c.endswith("_proba")]
+
+    wr.s3.to_csv(df_val[output_columns], f"{model_metrics_s3_path}/validation_predictions.csv", index=False)
+
+    # -------------------------------------------------------------------------
+    # Save model artifacts
+    # -------------------------------------------------------------------------
+    model_config = {
+        "n_continuous": len(continuous_cols),
+        "categorical_cardinalities": categorical_cardinalities,
+        "hidden_layers": hidden_layers,
+        "n_outputs": n_outputs,
+        "task": task,
+        "dropout": hyperparameters["dropout"],
+        "use_batch_norm": hyperparameters["use_batch_norm"],
+    }

-
-
-
-        path=f"{model_metrics_s3_path}/validation_predictions.csv",
-        index=False,
-    )
+    for idx, m in enumerate(ensemble_models):
+        save_model(m, os.path.join(args.model_dir, f"model_{idx}"), model_config)
+    print(f"Saved {len(ensemble_models)} model(s)")

-
-    if model_type == "classifier":
-        # Get the label names and their integer mapping
-        label_names = label_encoder.classes_
-
-        # Calculate various model performance metrics
-        scores = precision_recall_fscore_support(y_validate, preds_decoded, average=None, labels=label_names)
-
-        # Put the scores into a dataframe
-        score_df = pd.DataFrame(
-            {
-                target: label_names,
-                "precision": scores[0],
-                "recall": scores[1],
-                "f1": scores[2],
-                "support": scores[3],
-            }
-        )
+    joblib.dump({"n_ensemble": len(ensemble_models), "n_folds": n_folds}, os.path.join(args.model_dir, "ensemble_metadata.joblib"))

-
-
-        for t in label_names:
-            for m in metrics:
-                value = score_df.loc[score_df[target] == t, m].iloc[0]
-                print(f"Metrics:{t}:{m} {value}")
+    with open(os.path.join(args.model_dir, "feature_columns.json"), "w") as f:
+        json.dump(orig_features, f)

-
-
-        for i, row_name in enumerate(label_names):
-            for j, col_name in enumerate(label_names):
-                value = conf_mtx[i, j]
-                print(f"ConfusionMatrix:{row_name}:{col_name} {value}")
+    with open(os.path.join(args.model_dir, "category_mappings.json"), "w") as f:
+        json.dump(category_mappings, f)

-
-
-
-
-
-
-
-        support = len(df_val)
-        print(f"rmse: {rmse:.3f}")
-        print(f"mae: {mae:.3f}")
-        print(f"medae: {medae:.3f}")
-        print(f"r2: {r2:.3f}")
-        print(f"spearmanr: {spearman_corr:.3f}")
-        print(f"support: {support}")
-
-    # Save ensemble models
-    for model_idx, ens_model in enumerate(ensemble_models):
-        model_path = os.path.join(args.model_dir, f"tabular_model_{model_idx}")
-        ens_model.save_model(model_path)
-        print(f"Saved model {model_idx + 1} to {model_path}")
-
-    # Save ensemble metadata
-    n_ensemble = len(ensemble_models)
-    ensemble_metadata = {"n_ensemble": n_ensemble, "n_folds": n_folds}
-    joblib.dump(ensemble_metadata, os.path.join(args.model_dir, "ensemble_metadata.joblib"))
-    print(f"Saved ensemble metadata (n_ensemble={n_ensemble}, n_folds={n_folds})")
+    with open(os.path.join(args.model_dir, "feature_metadata.json"), "w") as f:
+        json.dump({"continuous_cols": continuous_cols, "categorical_cols": categorical_cols}, f)
+
+    with open(os.path.join(args.model_dir, "hyperparameters.json"), "w") as f:
+        json.dump(hyperparameters, f, indent=2)
+
+    scaler.save(os.path.join(args.model_dir, "scaler.joblib"))

     if label_encoder:
         joblib.dump(label_encoder, os.path.join(args.model_dir, "label_encoder.joblib"))

-
-
-        json.dump(orig_features, fp)
+    if model_type != "classifier":
+        save_uq_models(uq_models, uq_metadata, args.model_dir)

-
-    with open(os.path.join(args.model_dir, "category_mappings.json"), "w") as fp:
-        json.dump(category_mappings, fp)
+    print(f"\nModel training complete! Artifacts saved to {args.model_dir}")