dragon-ml-toolbox 1.1.2__py3-none-any.whl


@@ -0,0 +1,701 @@
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.colors import Colormap
from matplotlib import rcdefaults

import os
from typing import Literal, Union, Optional
import joblib

from imblearn.over_sampling import ADASYN, SMOTE, RandomOverSampler
from imblearn.under_sampling import RandomUnderSampler

from sklearn.ensemble import HistGradientBoostingClassifier, HistGradientBoostingRegressor
import xgboost as xgb
import lightgbm as lgb

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, MaxAbsScaler, MinMaxScaler
from sklearn.metrics import (accuracy_score, classification_report, ConfusionMatrixDisplay,
                             mean_absolute_error, mean_squared_error, r2_score,
                             roc_curve, roc_auc_score)
import shap

# Silence noisy third-party warnings
import warnings
warnings.filterwarnings('ignore', category=DeprecationWarning)
warnings.filterwarnings('ignore', category=FutureWarning)
warnings.filterwarnings('ignore', category=UserWarning)


###### 1. Dataset Loader ######
# Load imputed datasets as a generator
def yield_imputed_dataframe(datasets_dir: str):
    '''
    Yields a tuple `(dataframe, dataframe_name)` for every CSV file in `datasets_dir`.
    '''
    dataset_filenames = [dataset for dataset in os.listdir(datasets_dir) if dataset.endswith(".csv")]
    if not dataset_filenames:
        raise IOError(f"No imputed datasets have been found at {datasets_dir}")

    for dataset_filename in dataset_filenames:
        full_path = os.path.join(datasets_dir, dataset_filename)
        df = pd.read_csv(full_path)
        # Remove the file extension to get a clean dataset name
        filename = os.path.splitext(os.path.basename(dataset_filename))[0]
        print(f"Working on dataset: {filename}")
        yield (df, filename)

# Split a dataset into a features dataframe and one target series per target column
def dataset_yielder(df: pd.DataFrame, target_cols: list[str]):
    '''
    Yields one tuple at a time: `(df_features, df_target, feature_names, target_name)`
    '''
    df_features = df.drop(columns=target_cols)
    feature_names = df_features.columns.to_list()

    for target_col in target_cols:
        df_target = df[target_col]
        yield (df_features, df_target, feature_names, target_col)

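# --- Usage sketch (illustrative): iterate datasets and per-target splits. ---
# The directory name "data/imputed" and the target column names below are
# assumptions for the example, not part of this module.
#
# for df, name in yield_imputed_dataframe("data/imputed"):
#     for df_features, df_target, feature_names, target_name in dataset_yielder(
#             df, target_cols=["target_a", "target_b"]):
#         print(name, target_name, df_features.shape)
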
###### 2. Initialize Models ######
def get_models(task: Literal["classification", "regression"], random_state: int = 101, is_balanced: bool = True,
               L1_regularization: float = 1.0, L2_regularization: float = 1.0, learning_rate: float = 0.005) -> dict:
    '''
    Returns a dictionary `{Model_Name: Model}` with new instances of models.
    Valid tasks: "classification" or "regression".

    Classification Models:
    - "XGBoost" - XGBClassifier
    - "LightGBM" - LGBMClassifier
    - "HistGB" - HistGradientBoostingClassifier
    Regression Models:
    - "XGBoost" - XGBRegressor
    - "LightGBM" - LGBMRegressor
    - "HistGB" - HistGradientBoostingRegressor

    For classification only: set `is_balanced=False` for imbalanced datasets.

    Increase L1 and L2 regularization if the models are overfitting.
    '''
    # Validate task
    if task not in ["classification", "regression"]:
        raise ValueError(f"Invalid task: {task}. Must be 'classification' or 'regression'.")

    models = {}

    # Common parameters
    xgb_params = {
        'n_estimators': 200,
        'max_depth': 5,
        'learning_rate': learning_rate,
        'subsample': 0.8,
        'colsample_bytree': 0.8,
        'random_state': random_state,
        'reg_alpha': L1_regularization,
        'reg_lambda': L2_regularization,
    }

    lgbm_params = {
        'n_estimators': 200,
        'learning_rate': learning_rate,
        'max_depth': 5,
        'subsample': 0.8,
        'colsample_bytree': 0.8,
        'random_state': random_state,
        'verbose': -1,
        'reg_alpha': L1_regularization,
        'reg_lambda': L2_regularization,
    }

    hist_params = {
        'max_iter': 200,
        'learning_rate': learning_rate,
        'max_depth': 5,
        'min_samples_leaf': 30,
        'random_state': random_state,
        'l2_regularization': L2_regularization,
    }

    # XGB Model
    if task == "classification":
        xgb_params.update({
            # Fixed heuristic weight; ideally this would be the negative/positive class ratio
            'scale_pos_weight': 1 if is_balanced else 8,
            'eval_metric': 'aucpr'
        })
        models["XGBoost"] = xgb.XGBClassifier(**xgb_params)
    else:
        xgb_params.update({'eval_metric': 'rmse'})
        models["XGBoost"] = xgb.XGBRegressor(**xgb_params)

    # LGBM Model
    if task == "classification":
        lgbm_params.update({
            'class_weight': None if is_balanced else 'balanced',
            'boosting_type': 'goss' if is_balanced else 'dart',
        })
        models["LightGBM"] = lgb.LGBMClassifier(**lgbm_params)
    else:
        lgbm_params['boosting_type'] = 'dart'
        models["LightGBM"] = lgb.LGBMRegressor(**lgbm_params)

    # HistGB Model
    if task == "classification":
        hist_params.update({
            'class_weight': None if is_balanced else 'balanced',
            'scoring': 'loss' if is_balanced else 'balanced_accuracy',
        })
        models["HistGB"] = HistGradientBoostingClassifier(**hist_params)
    else:
        hist_params['scoring'] = 'neg_mean_squared_error'
        models["HistGB"] = HistGradientBoostingRegressor(**hist_params)

    return models

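# --- Usage sketch (illustrative): instantiate the model zoo for a task. ---
# models = get_models(task="classification", is_balanced=False, learning_rate=0.01)
# for model_name, model in models.items():
#     print(model_name, "->", type(model).__name__)
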
###### 3. Process Dataset ######
# Split data into train and test sets; stratify only for classification targets
def _split_data(features, target, test_size, random_state, task):
    stratify_by = target if task == "classification" else None
    X_train, X_test, y_train, y_test = train_test_split(features, target, test_size=test_size,
                                                        random_state=random_state, stratify=stratify_by)
    return X_train, X_test, y_train, y_test

# Standardize the features with the selected scaler
def _standardize_data(train_features, test_features, scaler_code):
    if scaler_code == "standard":
        scaler = StandardScaler()
    elif scaler_code == "minmax":
        scaler = MinMaxScaler()
    elif scaler_code == "maxabs":
        scaler = MaxAbsScaler()
    else:
        raise ValueError(f"Unrecognized scaler '{scaler_code}'")
    # Fit on the training set only, then apply the same transform to the test set
    train_scaled = scaler.fit_transform(train_features)
    test_scaled = scaler.transform(test_features)
    return train_scaled, test_scaled, scaler

# Over-sample the minority class or under-sample the majority class (classification)
def _resample(X_train_scaled: np.ndarray, y_train: pd.Series,
              strategy: Literal["ADASYN", "SMOTE", "RANDOM", "UNDERSAMPLE"], random_state):
    '''
    Oversample the minority class or undersample the majority class.

    Returns a tuple `(Features: 2D-array, Target: 1D-array)`
    '''
    if strategy == 'SMOTE':
        resample_algorithm = SMOTE(random_state=random_state, k_neighbors=3)
    elif strategy == 'RANDOM':
        resample_algorithm = RandomOverSampler(random_state=random_state)
    elif strategy == 'UNDERSAMPLE':
        resample_algorithm = RandomUnderSampler(random_state=random_state)
    elif strategy == 'ADASYN':
        resample_algorithm = ADASYN(random_state=random_state, n_neighbors=3)
    else:
        raise ValueError(f"Invalid resampling strategy: {strategy}")

    X_res, y_res = resample_algorithm.fit_resample(X_train_scaled, y_train)
    return X_res, y_res

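# --- Usage sketch (illustrative): balance a toy imbalanced training set. ---
# X = np.random.rand(100, 3)
# y = pd.Series([0] * 90 + [1] * 10)
# X_res, y_res = _resample(X, y, strategy="RANDOM", random_state=101)
# print(pd.Series(y_res).value_counts())  # both classes now have 90 samples
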
# DATASET PIPELINE
def dataset_pipeline(df_features: pd.DataFrame, df_target: pd.Series, task: Literal["classification", "regression"],
                     resample_strategy: Union[Literal["ADASYN", "SMOTE", "RANDOM", "UNDERSAMPLE"], None],
                     scaler: Literal["standard", "minmax", "maxabs"],
                     test_size: float = 0.2, debug: bool = False, random_state: int = 101):
    '''
    1. Make train/test splits
    2. Standardize train and test features
    3. Resample imbalanced classes (classification only)

    Returns a processed tuple: (X_train, y_train, X_test, y_test, Scaler)

    `(2D-array, 1D-array, 2D-array, Series, Scaler)`
    '''
    # DEBUG
    if debug:
        print(f"Split Dataframes Shapes - Features DF: {df_features.shape}, Target DF: {df_target.shape}")
        unique_values = df_target.unique()  # Unique values of the target column
        print(f"\tUnique values for '{df_target.name}': {unique_values}")

    # Train/test split
    X_train, X_test, y_train, y_test = _split_data(features=df_features, target=df_target,
                                                   test_size=test_size, random_state=random_state, task=task)

    # DEBUG
    if debug:
        print(f"Shapes after train test split - X_train: {X_train.shape}, y_train: {y_train.shape}, X_test: {X_test.shape}, y_test: {y_test.shape}")

    # Standardize
    X_train_scaled, X_test_scaled, scaler_object = _standardize_data(train_features=X_train, test_features=X_test, scaler_code=scaler)

    # DEBUG
    if debug:
        print(f"Shapes after scaling features - X_train: {X_train_scaled.shape}, y_train: {y_train.shape}, X_test: {X_test_scaled.shape}, y_test: {y_test.shape}")

    # Resample (classification only)
    if resample_strategy is None or task == "regression":
        X_train_resampled, y_train_resampled = X_train_scaled, y_train
    else:
        X_train_resampled, y_train_resampled = _resample(X_train_scaled=X_train_scaled, y_train=y_train,
                                                         strategy=resample_strategy, random_state=random_state)

    # DEBUG
    if debug:
        print(f"Shapes after resampling - X_train: {X_train_resampled.shape}, y_train: {y_train_resampled.shape}, X_test: {X_test_scaled.shape}, y_test: {y_test.shape}")

    return X_train_resampled, y_train_resampled, X_test_scaled, y_test, scaler_object

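# --- Usage sketch (illustrative): process one target of a toy dataframe. ---
# df = pd.DataFrame({"f1": np.random.rand(200), "f2": np.random.rand(200),
#                    "label": np.random.randint(0, 2, 200)})
# X_train, y_train, X_test, y_test, scaler = dataset_pipeline(
#     df.drop(columns=["label"]), df["label"], task="classification",
#     resample_strategy="SMOTE", scaler="minmax")
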
###### 4. Train and Evaluation ######
# Trainer function
def _train_model(model, train_features, train_target):
    model.fit(train_features, train_target)
    return model

# Handle local directories
def _local_directories(model_name: str, dataset_id: str, save_dir: str):
    dataset_dir = os.path.join(save_dir, dataset_id)
    if not os.path.isdir(dataset_dir):
        os.makedirs(dataset_dir)

    model_dir = os.path.join(dataset_dir, model_name)
    if not os.path.isdir(model_dir):
        os.makedirs(model_dir)

    return model_dir

# Serialize the trained model together with its scaler and metadata
def _save_model(trained_model, model_name: str, target_name: str, feature_names: list[str], save_directory: str,
                scaler_object: Union[StandardScaler, MinMaxScaler, MaxAbsScaler]):
    full_path = os.path.join(save_directory, f"{model_name}_{target_name}.joblib")
    joblib.dump({'model': trained_model, 'scaler': scaler_object,
                 'feature_names': feature_names, 'target_name': target_name}, full_path)

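# --- Usage sketch (illustrative): reload a saved artifact and predict. ---
# The file path below is hypothetical; _save_model names files "{model_name}_{target_name}.joblib".
# artifact = joblib.load("results/my_dataset/XGBoost/XGBoost_my_target.joblib")
# model, scaler = artifact['model'], artifact['scaler']
# X_new_scaled = scaler.transform(X_new)  # columns must match artifact['feature_names'] order
# predictions = model.predict(X_new_scaled)
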
# Evaluate a classification model and save metrics
def evaluate_model_classification(
    model,
    model_name: str,
    save_dir: str,
    x_test_scaled: np.ndarray,
    single_y_test: np.ndarray,
    target_id: str,
    figsize: tuple = (10, 8),
    title_fontsize: int = 24,
    label_fontsize: int = 24,
    cmap: Colormap = plt.cm.Blues  # type: ignore
) -> np.ndarray:
    """
    Evaluates a classification model, then saves the classification report and confusion matrix plot.

    Parameters:
        model: Trained classifier with a .predict() method
        model_name: Identifier for the model
        save_dir: Directory where results are saved
        x_test_scaled: Feature matrix for the test set
        single_y_test: True binary labels
        target_id: Suffix for naming output files
        figsize: Size of the confusion matrix figure (width, height)
        title_fontsize: Font size for the title
        label_fontsize: Font size for axis labels and ticks
        cmap: Color map for the confusion matrix. Examples include:
            - plt.cm.Blues (default)
            - plt.cm.Greens
            - plt.cm.Oranges
            - plt.cm.Purples
            - plt.cm.Reds
            - plt.cm.cividis
            - plt.cm.inferno

    Returns:
        y_pred: Predicted class labels
    """
    os.makedirs(save_dir, exist_ok=True)

    y_pred = model.predict(x_test_scaled)
    accuracy = accuracy_score(single_y_test, y_pred)

    report = classification_report(
        single_y_test,
        y_pred,
        target_names=["Negative", "Positive"],
        output_dict=False
    )

    # Save text report
    report_path = os.path.join(save_dir, f"Classification_Report_{target_id}.txt")
    with open(report_path, "w") as f:
        f.write(f"{model_name} - {target_id}\t\tAccuracy: {accuracy:.2f}\n")
        f.write("Classification Report:\n")
        f.write(report)  # type: ignore

    # Create confusion matrix
    fig, ax = plt.subplots(figsize=figsize)
    ConfusionMatrixDisplay.from_predictions(
        y_true=single_y_test,
        y_pred=y_pred,
        display_labels=["Negative", "Positive"],
        cmap=cmap,
        normalize="true",
        ax=ax
    )

    ax.set_title(f"{model_name} - {target_id}", fontsize=title_fontsize)
    ax.tick_params(axis='both', labelsize=label_fontsize)
    ax.set_xlabel("Predicted label", fontsize=label_fontsize)
    ax.set_ylabel("True label", fontsize=label_fontsize)

    # Turn off gridlines
    ax.grid(False)

    # Manually update the font size of cell texts
    for text in ax.texts:
        text.set_fontsize(title_fontsize + 4)

    fig.tight_layout()
    fig_path = os.path.join(save_dir, f"Confusion_Matrix_{target_id}.svg")
    fig.savefig(fig_path, format="svg", bbox_inches="tight")
    plt.close(fig)

    return y_pred

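# --- Usage sketch (illustrative): evaluate a freshly trained binary classifier. ---
# clf = xgb.XGBClassifier().fit(X_train, y_train)
# y_pred = evaluate_model_classification(clf, "XGBoost", "results/demo",
#                                        x_test_scaled=X_test, single_y_test=y_test,
#                                        target_id="my_target")
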
# Save the ROC curve and ROC AUC (classification)
def plot_roc_curve(
    true_labels: np.ndarray,
    probabilities_or_model: Union[np.ndarray, xgb.XGBClassifier, lgb.LGBMClassifier, object],
    model_name: str,
    target_name: str,
    save_directory: str,
    color: str = "darkorange",
    figure_size: tuple = (10, 10),
    linewidth: int = 2,
    title_fontsize: int = 24,
    label_fontsize: int = 24,
    input_features: Optional[np.ndarray] = None,
) -> plt.Figure:  # type: ignore
    """
    Plots the ROC curve and computes the AUC for binary classification. The positive class is assumed to be in the second column of the probabilities array.

    Parameters:
        true_labels: np.ndarray of shape (n_samples,), ground-truth binary labels (0 or 1).
        probabilities_or_model: either predicted probabilities (ndarray), or a trained model with a `.predict_proba()` method.
        model_name: str, used for the figure title.
        target_name: str, used for the figure title and filename.
        save_directory: str, path to the directory where the figure is saved.
        color: color of the ROC curve. Accepts any valid Matplotlib color specification. Examples:
            - Named colors: "darkorange", "blue", "red", "green", "black"
            - Hex codes: "#1f77b4", "#ff7f0e"
            - RGB tuples: (0.2, 0.4, 0.6)
            - Colormap value: plt.cm.viridis(0.6)
        figure_size: tuple for the figure size (width, height).
        linewidth: int, width of the plotted ROC line.
        title_fontsize: int, font size of the title.
        label_fontsize: int, font size for axis labels.
        input_features: np.ndarray of shape (n_samples, n_features), required if a model is passed.

    Returns:
        fig: matplotlib Figure object
    """
    # Determine predicted probabilities
    if isinstance(probabilities_or_model, np.ndarray):
        # Input is already probabilities
        if probabilities_or_model.ndim == 2:
            y_score = probabilities_or_model[:, 1]
        else:
            y_score = probabilities_or_model

    elif hasattr(probabilities_or_model, "predict_proba"):
        if input_features is None:
            raise ValueError("input_features must be provided when using a classifier.")

        try:
            classes = probabilities_or_model.classes_  # type: ignore
            positive_class_index = list(classes).index(1)
        except (AttributeError, ValueError):
            positive_class_index = 1

        y_score = probabilities_or_model.predict_proba(input_features)[:, positive_class_index]  # type: ignore

    else:
        raise TypeError("Unsupported type for 'probabilities_or_model'. Must be a NumPy array or a model that supports '.predict_proba()'.")

    # ROC and AUC
    fpr, tpr, _ = roc_curve(true_labels, y_score)
    auc_score = roc_auc_score(true_labels, y_score)

    # Plot
    fig, ax = plt.subplots(figsize=figure_size)
    ax.plot(fpr, tpr, color=color, lw=linewidth, label=f"AUC = {auc_score:.2f}")
    ax.plot([0, 1], [0, 1], color="gray", linestyle="--", lw=1)  # Chance-level baseline

    ax.set_title(f"{model_name} - {target_name}", fontsize=title_fontsize)
    ax.set_xlabel("False Positive Rate", fontsize=label_fontsize)
    ax.set_ylabel("True Positive Rate", fontsize=label_fontsize)
    ax.tick_params(axis='both', labelsize=label_fontsize)
    ax.legend(loc="lower right", fontsize=label_fontsize)
    ax.grid(True)

    # Save figure
    os.makedirs(save_directory, exist_ok=True)
    save_path = os.path.join(save_directory, f"ROC_{target_name}.svg")
    fig.savefig(save_path, bbox_inches="tight", format="svg")

    return fig

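# --- Usage sketch (illustrative): plot a ROC curve from raw probabilities. ---
# probabilities = clf.predict_proba(X_test)  # shape (n_samples, 2)
# fig = plot_roc_curve(y_test, probabilities, model_name="XGBoost",
#                      target_name="my_target", save_directory="results/demo")
# plt.close(fig)
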
# Evaluate a regression model and save metrics
def evaluate_model_regression(model, model_name: str,
                              save_dir: str,
                              x_test_scaled: np.ndarray, single_y_test: np.ndarray,
                              target_id: str,
                              figure_size: tuple = (12, 8),
                              alpha_transparency: float = 0.5,
                              title_fontsize: int = 24,
                              normal_fontsize: int = 24):
    """
    Evaluates a regression model, saving a metrics report, a residual plot, and a true-vs-predicted plot.

    Returns:
        y_pred: Predicted values
    """
    os.makedirs(save_dir, exist_ok=True)

    # Generate predictions
    y_pred = model.predict(x_test_scaled)

    # Calculate regression metrics
    mae = mean_absolute_error(single_y_test, y_pred)
    mse = mean_squared_error(single_y_test, y_pred)
    rmse = np.sqrt(mse)
    r2 = r2_score(single_y_test, y_pred)

    # Create formatted report
    report_path = os.path.join(save_dir, f"Regression_Report_{target_id}.txt")
    with open(report_path, "w") as f:
        f.write(f"{model_name} - {target_id} Regression Performance\n")
        f.write(f"Mean Absolute Error (MAE): {mae:.4f}\n")
        f.write(f"Mean Squared Error (MSE): {mse:.4f}\n")
        f.write(f"Root Mean Squared Error (RMSE): {rmse:.4f}\n")
        f.write(f"R² Score: {r2:.4f}\n")

    # Generate and save residual plot
    residuals = single_y_test - y_pred
    plt.figure(figsize=figure_size)
    plt.scatter(y_pred, residuals, alpha=alpha_transparency)
    plt.axhline(0, color='red', linestyle='--')
    plt.xlabel("Predicted Values", fontsize=normal_fontsize)
    plt.ylabel("Residuals", fontsize=normal_fontsize)
    plt.title(f"{model_name} - Residual Plot for {target_id}", fontsize=title_fontsize)
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, f"Residual_Plot_{target_id}.svg"), bbox_inches='tight', format="svg")
    plt.close()

    # Create true-vs-predicted plot
    plt.figure(figsize=figure_size)
    plt.scatter(single_y_test, y_pred, alpha=alpha_transparency)
    plt.plot([single_y_test.min(), single_y_test.max()],
             [single_y_test.min(), single_y_test.max()],
             'k--', lw=2)  # Identity line: perfect predictions fall on it
    plt.xlabel('True Values', fontsize=normal_fontsize)
    plt.ylabel('Predictions', fontsize=normal_fontsize)
    plt.title(f"{model_name} - True vs Predicted for {target_id}", fontsize=title_fontsize)
    plt.grid(True)
    plot_path = os.path.join(save_dir, f"Regression_Plot_{target_id}.svg")
    plt.savefig(plot_path, bbox_inches='tight', format="svg")
    plt.close()

    return y_pred

# Compute and plot SHAP values
def get_shap_values(model, model_name: str,
                    save_dir: str,
                    features_to_explain: np.ndarray,
                    feature_names: list[str],
                    target_id: str,
                    task: Literal["classification", "regression"],
                    max_display_features: int = 8,
                    figsize: tuple = (14, 20),
                    title_fontsize: int = 38,
                    label_fontsize: int = 38,
                    plot_type: Literal["bar", "dot"] = "dot"
                    ):
    """
    Universal SHAP explainer for regression and classification.
    - Use `X_train` (or a subsample of it) to see how the model explains the data it was trained on.
    - Use `X_test` (or a hold-out set) to see how the model explains unseen data.
    - Use the entire dataset to get the global view.

    Parameters:
    - 'task': 'regression' or 'classification'
    - 'features_to_explain': Should match the model's training data format, including scaling.
    - 'save_dir': Directory to save visualizations
    """
    def _create_shap_plot(shap_values, features, feature_names,
                          full_save_path: str, plot_type: str,
                          title: str):
        """Helper function to create and save SHAP plots"""
        # Set style, preferring seaborn looks when available
        preferred_styles = ['seaborn', 'seaborn-v0_8-darkgrid', 'seaborn-v0_8', 'default']
        for style in preferred_styles:
            if style in plt.style.available or style == 'default':
                plt.style.use(style)
                break

        plt.figure(figsize=figsize)

        # Set rc parameters for better readability
        plt.rc('font', size=label_fontsize)
        plt.rc('axes', titlesize=title_fontsize)
        plt.rc('axes', labelsize=label_fontsize)
        plt.rc('xtick', labelsize=label_fontsize)
        plt.rc('ytick', labelsize=label_fontsize)
        plt.rc('legend', fontsize=label_fontsize)
        plt.rc('figure', titlesize=title_fontsize)

        # Create the SHAP plot
        shap.summary_plot(
            shap_values=shap_values,
            features=features,
            feature_names=feature_names,
            plot_type=plot_type,
            show=False,
            plot_size=figsize,
            max_display=max_display_features,
            alpha=0.7,
            color=plt.get_cmap('viridis')
        )

        # Add professional styling
        ax = plt.gca()
        ax.set_xlabel("SHAP Value Impact", fontsize=title_fontsize, weight='bold')
        ax.set_ylabel("Features", fontsize=title_fontsize, weight='bold')
        plt.title(title, fontsize=title_fontsize, pad=20, weight='bold')

        # Manually fix tick fonts
        for tick in ax.get_xticklabels():
            tick.set_fontsize(label_fontsize)
            tick.set_rotation(45)
        for tick in ax.get_yticklabels():
            tick.set_fontsize(label_fontsize)

        # Handle the colorbar for dot plots
        if plot_type == "dot":
            cb = plt.gcf().axes[-1]
            cb.set_ylabel("", size=1)
            cb.tick_params(labelsize=label_fontsize - 2)

        # Save and clean up
        plt.savefig(
            full_save_path,
            bbox_inches='tight',
            facecolor='white',
            format="svg"
        )
        plt.close()
        rcdefaults()  # Reset rc parameters to default

    # START
    explainer = shap.TreeExplainer(model)
    shap_values = explainer.shap_values(features_to_explain)

    # Handle different model types
    if task == 'classification':
        # Determine whether the problem is multiclass
        try:
            is_multiclass = len(model.classes_) > 2
            class_names = model.classes_
        except AttributeError:
            is_multiclass = isinstance(shap_values, list) and len(shap_values) > 1
            class_names = list(range(len(shap_values))) if is_multiclass else [0, 1]

        if is_multiclass:
            # One summary plot per class
            for class_shap, class_name in zip(shap_values, class_names):
                _create_shap_plot(
                    shap_values=class_shap,
                    features=features_to_explain,
                    feature_names=feature_names,
                    full_save_path=os.path.join(save_dir, f"SHAP_{target_id}_Class{class_name}.svg"),
                    plot_type=plot_type,
                    title=f"{model_name} - {target_id} (Class {class_name})"
                )
        else:
            # Binary classification: older SHAP versions return a list of two arrays
            plot_vals = shap_values[1] if isinstance(shap_values, list) else shap_values
            _create_shap_plot(
                shap_values=plot_vals,
                features=features_to_explain,
                feature_names=feature_names,
                full_save_path=os.path.join(save_dir, f"SHAP_{target_id}.svg"),
                plot_type=plot_type,
                title=f"{model_name} - {target_id}"
            )

    else:  # Regression
        _create_shap_plot(
            shap_values=shap_values,
            features=features_to_explain,
            feature_names=feature_names,
            full_save_path=os.path.join(save_dir, f"SHAP_{target_id}.svg"),
            plot_type=plot_type,
            title=f"{model_name} - {target_id}"
        )

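# --- Usage sketch (illustrative): explain a trained model on its training data. ---
# get_shap_values(trained_model, "XGBoost", save_dir="results/demo",
#                 features_to_explain=X_train, feature_names=feature_names,
#                 target_id="my_target", task="classification")
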
# TRAIN AND TEST PIPELINE
def train_test_pipeline(model, model_name: str, dataset_id: str, task: Literal["classification", "regression"],
                        train_features: np.ndarray, train_target: np.ndarray,
                        test_features: np.ndarray, test_target: np.ndarray,
                        feature_names: list[str], target_id: str,
                        scaler_object: Union[StandardScaler, MinMaxScaler, MaxAbsScaler],
                        save_dir: str,
                        debug: bool = False, save_model: bool = False):
    '''
    1. Train the model.
    2. Evaluate the model.
    3. Compute SHAP values.

    Returns: Tuple(Trained model, Test-set predictions)
    '''
    print(f"\tModel: {model_name} for Target: {target_id}...")
    trained_model = _train_model(model=model, train_features=train_features, train_target=train_target)
    if debug:
        print(f"Trained model object: {type(trained_model)}")
    local_save_directory = _local_directories(model_name=model_name, dataset_id=dataset_id, save_dir=save_dir)

    if save_model:
        _save_model(trained_model=trained_model, model_name=model_name,
                    target_name=target_id, feature_names=feature_names,
                    save_directory=local_save_directory, scaler_object=scaler_object)

    if task == "classification":
        y_pred = evaluate_model_classification(model=trained_model, model_name=model_name, save_dir=local_save_directory,
                                               x_test_scaled=test_features, single_y_test=test_target, target_id=target_id)
        plot_roc_curve(true_labels=test_target,
                       probabilities_or_model=trained_model, model_name=model_name,
                       target_name=target_id, save_directory=local_save_directory,
                       input_features=test_features)
    elif task == "regression":
        y_pred = evaluate_model_regression(model=trained_model, model_name=model_name, save_dir=local_save_directory,
                                           x_test_scaled=test_features, single_y_test=test_target, target_id=target_id)
    else:
        raise ValueError(f"Unrecognized task '{task}' for model training.")
    if debug:
        print(f"Predicted vector: {type(y_pred)} with shape: {y_pred.shape}")

    get_shap_values(model=trained_model, model_name=model_name, save_dir=local_save_directory,
                    features_to_explain=train_features, feature_names=feature_names, target_id=target_id, task=task)
    print("\t...done.")
    return trained_model, y_pred

###### 5. Execution ######
def run_pipeline(datasets_dir: str, save_dir: str, target_columns: list[str],
                 task: Literal["classification", "regression"] = "regression",
                 resample_strategy: Literal["ADASYN", "SMOTE", "RANDOM", "UNDERSAMPLE", None] = None,
                 scaler: Literal["standard", "minmax", "maxabs"] = "minmax", save_model: bool = False,
                 test_size: float = 0.2, debug: bool = False,
                 L1_regularization: float = 0.5, L2_regularization: float = 0.5,
                 learning_rate: float = 0.005, random_state: int = 101):
    # Check paths
    _check_paths(datasets_dir, save_dir)
    # Yield each imputed dataset
    for dataframe, dataframe_name in yield_imputed_dataframe(datasets_dir):
        # Yield a features dataframe and a target series per target column
        for df_features, df_target, feature_names, target_name in dataset_yielder(df=dataframe, target_cols=target_columns):
            # Dataset pipeline
            X_train, y_train, X_test, y_test, scaler_object = dataset_pipeline(df_features=df_features, df_target=df_target, task=task,
                                                                               resample_strategy=resample_strategy, scaler=scaler,
                                                                               test_size=test_size, debug=debug, random_state=random_state)
            # Get models; data resampled upstream is treated as balanced
            models_dict = get_models(task=task, is_balanced=resample_strategy is not None,
                                     L1_regularization=L1_regularization, L2_regularization=L2_regularization, learning_rate=learning_rate)
            # Train models
            for model_name, model in models_dict.items():
                train_test_pipeline(model=model, model_name=model_name, dataset_id=dataframe_name, task=task,
                                    train_features=X_train, train_target=y_train,
                                    test_features=X_test, test_target=y_test,
                                    feature_names=feature_names, target_id=target_name, scaler_object=scaler_object,
                                    debug=debug, save_dir=save_dir, save_model=save_model)
    print("\nTraining and evaluation complete.")


def _check_paths(datasets_dir: str, save_dir: str):
    if not os.path.isdir(save_dir):
        os.makedirs(save_dir)
    if not os.path.isdir(datasets_dir):
        raise IOError(f"Datasets directory '{datasets_dir}' not found.\nCheck the path or run the MICE script first.")
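
# --- Usage sketch (illustrative): run the full pipeline end to end. ---
# The paths and target column names below are assumptions for the example.
# if __name__ == "__main__":
#     run_pipeline(datasets_dir="data/imputed", save_dir="results",
#                  target_columns=["target_a", "target_b"],
#                  task="classification", resample_strategy="SMOTE", save_model=True)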