py2ls 0.2.4.14__py3-none-any.whl → 0.2.4.16__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
py2ls/ml2ls copy.py ADDED
@@ -0,0 +1,2906 @@
1
+ from sklearn.ensemble import (
2
+ RandomForestClassifier,
3
+ GradientBoostingClassifier,
4
+ AdaBoostClassifier,
5
+ BaggingClassifier,
6
+ )
7
+ from sklearn.svm import SVC, SVR
8
+ from sklearn.calibration import CalibratedClassifierCV
9
+ from sklearn.model_selection import GridSearchCV, StratifiedKFold
10
+ from sklearn.linear_model import (
11
+ LassoCV,
12
+ LogisticRegression,
13
+ LinearRegression,
14
+ Lasso,
15
+ Ridge,
16
+ RidgeClassifierCV,
17
+ ElasticNet,
18
+ )
19
+ from sklearn.feature_selection import RFE
20
+ from sklearn.naive_bayes import GaussianNB
21
+ from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
22
+ import xgboost as xgb # Make sure you have xgboost installed
23
+
24
+ from sklearn.model_selection import train_test_split, cross_val_score
25
+ from sklearn.metrics import (
26
+ accuracy_score,
27
+ precision_score,
28
+ recall_score,
29
+ f1_score,
30
+ roc_auc_score,
31
+ confusion_matrix,
32
+ matthews_corrcoef,
33
+ roc_curve,
34
+ auc,
35
+ balanced_accuracy_score,
36
+ precision_recall_curve,
37
+ average_precision_score,
38
+ )
39
+ from imblearn.over_sampling import SMOTE
40
+ from sklearn.pipeline import Pipeline
41
+ from collections import defaultdict
42
+ from sklearn.preprocessing import StandardScaler, OneHotEncoder
43
+ from typing import Dict, Any, Optional, List, Union
44
+ import numpy as np
45
+ import pandas as pd
46
+ from . import ips
47
+ from . import plot
48
+ import matplotlib.pyplot as plt
49
+ import seaborn as sns
50
+
51
# NOTE(review): everything below runs at import time and changes process-wide
# state (matplotlib style, root-logger configuration, warning filters) for any
# code that merely imports this module. Consider moving it into an explicit
# setup function.

# Apply the "paper" matplotlib style file located under ips.get_cwd();
# presumably get_cwd() resolves to the package directory — TODO confirm.
plt.style.use(str(ips.get_cwd()) + "/data/styles/stylelib/paper.mplstyle")
import logging
import warnings

# Configure the root logger: INFO level with a timestamped message format.
# (logging.basicConfig does nothing if the root logger already has handlers.)
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
# getLogger() without a name returns the root logger, not a module logger.
logger = logging.getLogger()

# Ignore every UserWarning, process-wide.
warnings.filterwarnings("ignore", category=UserWarning)
62
+ from sklearn.tree import DecisionTreeClassifier
63
+ from sklearn.neighbors import KNeighborsClassifier
64
+
65
+
66
def features_knn(
    x_train: pd.DataFrame, y_train: pd.Series, knn_params: dict
) -> pd.DataFrame:
    """
    Rank features for a k-nearest-neighbours classifier via permutation importance.

    KNN assigns each sample the majority label of its nearest neighbours and has
    no intrinsic feature importances. The fitted model is therefore re-scored
    (accuracy) after shuffling one feature at a time; the mean drop in accuracy
    over 30 repeats is that feature's importance.

    When to use:
        Small to medium datasets with low feature dimensionality and
        well-separated clusters.

    Parameters:
    - x_train: Training features; its columns are used as feature names.
    - y_train: Training labels.
    - knn_params: Keyword arguments forwarded to KNeighborsClassifier.

    Returns:
    - DataFrame with columns "feature" and "importance", sorted by importance
      in descending order.
    """
    # permutation_importance is not among the module-level imports, so import
    # it here; otherwise this function raises NameError.
    from sklearn.inspection import permutation_importance

    knn = KNeighborsClassifier(**knn_params)
    knn.fit(x_train, y_train)
    importances = permutation_importance(
        knn, x_train, y_train, n_repeats=30, random_state=1, scoring="accuracy"
    )
    return pd.DataFrame(
        {"feature": x_train.columns, "importance": importances.importances_mean}
    ).sort_values(by="importance", ascending=False)
86
+
87
+
88
+ #! 1. Linear and Regularized Regression Methods
89
+ # 1.1 Lasso
90
def features_lasso(
    x_train: pd.DataFrame, y_train: pd.Series, lasso_params: dict
) -> pd.DataFrame:
    """
    Rank features by the magnitude of their LassoCV coefficients.

    Lasso (Least Absolute Shrinkage and Selection Operator) is a linear
    regression with an L1 penalty. It shrinks coefficients and sets the least
    useful ones exactly to zero, so it doubles as a feature selector. The
    penalty strength is chosen by cross-validation (LassoCV).

    Parameters:
    - x_train: Training features; its columns are used as feature names.
    - y_train: Training target (labels are treated as numeric values).
    - lasso_params: Keyword arguments forwarded to LassoCV.

    Returns:
    - DataFrame with columns "feature" and "importance" (absolute coefficient),
      restricted to non-zero coefficients and sorted in descending order.
    """
    lasso = LassoCV(**lasso_params)
    lasso.fit(x_train, y_train)
    importance_df = pd.DataFrame(
        {"feature": x_train.columns, "importance": np.abs(lasso.coef_)}
    )
    # Drop features whose coefficient was shrunk exactly to zero.
    return importance_df[importance_df["importance"] > 0].sort_values(
        by="importance", ascending=False
    )
108
+
109
+
110
+ # 1.2 Ridge regression
111
def features_ridge(
    x_train: pd.DataFrame, y_train: pd.Series, ridge_params: dict
) -> pd.DataFrame:
    """
    Rank features by the magnitude of their RidgeCV coefficients.

    Ridge regression applies an L2 penalty that shrinks coefficient magnitudes
    to limit overfitting, which is especially helpful when features are
    multicollinear. The penalty strength is chosen by cross-validation (RidgeCV).

    Parameters:
    - x_train: Training features; its columns are used as feature names.
    - y_train: Training target (labels are treated as numeric values).
    - ridge_params: Keyword arguments forwarded to RidgeCV.

    Returns:
    - DataFrame with columns "feature" and "importance" (absolute coefficient),
      restricted to non-zero coefficients and sorted in descending order.
    """
    from sklearn.linear_model import RidgeCV

    ridge = RidgeCV(**ridge_params)
    ridge.fit(x_train, y_train)
    importance_df = pd.DataFrame(
        {"feature": x_train.columns, "importance": np.abs(ridge.coef_)}
    )
    # L2 rarely yields exact zeros, but filter them for parity with Lasso/Enet.
    return importance_df[importance_df["importance"] > 0].sort_values(
        by="importance", ascending=False
    )
133
+
134
+
135
+ # 1.3 Elastic Net(Enet)
136
def features_enet(
    x_train: pd.DataFrame, y_train: pd.Series, enet_params: dict
) -> pd.DataFrame:
    """
    Rank features by the magnitude of their ElasticNetCV coefficients.

    Elastic Net combines the L1 (lasso) and L2 (ridge) penalties in one linear
    model. This is useful when features are highly correlated or when there are
    more features than samples. Penalty strength is chosen by cross-validation.

    Parameters:
    - x_train: Training features; its columns are used as feature names.
    - y_train: Training target (labels are treated as numeric values).
    - enet_params: Keyword arguments forwarded to ElasticNetCV.

    Returns:
    - DataFrame with columns "feature" and "importance" (absolute coefficient),
      restricted to non-zero coefficients and sorted in descending order.
    """
    from sklearn.linear_model import ElasticNetCV

    enet = ElasticNetCV(**enet_params)
    enet.fit(x_train, y_train)
    importance_df = pd.DataFrame(
        {"feature": x_train.columns, "importance": np.abs(enet.coef_)}
    )
    # The L1 part of the penalty can zero out coefficients; drop those features.
    return importance_df[importance_df["importance"] > 0].sort_values(
        by="importance", ascending=False
    )
156
+
157
+
158
+ # 1.4 Partial Least Squares Regression for Generalized Linear Models (plsRglm): Combines regression and
159
+ # feature reduction, useful for high-dimensional data with correlated features, such as genomics.
160
+
161
+ #! 2.Generalized Linear Models and Extensions
162
+ # 2.1
163
+
164
+
165
+ #!3.Tree-Based and Ensemble Methods
166
+ # 3.1 Random Forest(RF)
167
def features_rf(
    x_train: pd.DataFrame, y_train: pd.Series, rf_params: dict
) -> pd.DataFrame:
    """
    Fit a Random Forest classifier and return its sorted feature importances.

    A random forest is an ensemble of decision trees whose predictions are
    aggregated. It handles high-dimensional data well, is robust to overfitting
    thanks to averaging, and exposes impurity-based feature importances.
    Recommended for classification problems with many features (e.g. genes).

    Parameters:
    - x_train: Training features; its columns are used as feature names.
    - y_train: Training labels.
    - rf_params: Keyword arguments forwarded to RandomForestClassifier.

    Returns:
    - DataFrame with columns "feature" and "importance", sorted by importance
      in descending order.
    """
    rf = RandomForestClassifier(**rf_params)
    rf.fit(x_train, y_train)
    # The fitted attribute is feature_importances_ (the previous "featuress_"
    # spelling does not exist and raised AttributeError).
    return pd.DataFrame(
        {"feature": x_train.columns, "importance": rf.feature_importances_}
    ).sort_values(by="importance", ascending=False)
185
+
186
+
187
+ # 3.2 Gradient Boosting Trees
188
def features_gradient_boosting(
    x_train: pd.DataFrame, y_train: pd.Series, gb_params: dict
) -> pd.DataFrame:
    """
    Fit a Gradient Boosting classifier and return its sorted feature importances.

    Gradient boosting grows decision trees sequentially, each one correcting the
    errors of the ensemble built so far. It gives high predictive accuracy,
    captures complex relationships, and suits datasets with many features
    (e.g. genes).

    Parameters:
    - x_train: Training features; its columns are used as feature names.
    - y_train: Training labels.
    - gb_params: Keyword arguments forwarded to GradientBoostingClassifier.

    Returns:
    - DataFrame with columns "feature" and "importance", sorted by importance
      in descending order.
    """
    booster = GradientBoostingClassifier(**gb_params)
    booster.fit(x_train, y_train)
    ranking = pd.DataFrame(
        {"feature": x_train.columns, "importance": booster.feature_importances_}
    )
    return ranking.sort_values(by="importance", ascending=False)
208
+
209
+
210
+ # 3.3 XGBoost
211
def features_xgb(
    x_train: pd.DataFrame, y_train: pd.Series, xgb_params: dict
) -> pd.DataFrame:
    """
    Fit an XGBoost classifier and return its sorted feature importances.

    XGBoost is an optimized gradient-boosting implementation, typically faster
    than classic GBM, with strong predictive performance on structured data.

    Parameters:
    - x_train: Training features; its columns are used as feature names.
    - y_train: Training labels.
    - xgb_params: Keyword arguments forwarded to xgboost.XGBClassifier.

    Returns:
    - DataFrame with columns "feature" and "importance", sorted by importance
      in descending order.
    """
    import xgboost as xgb

    model = xgb.XGBClassifier(**xgb_params)
    model.fit(x_train, y_train)
    ranking = pd.DataFrame(
        {"feature": x_train.columns, "importance": model.feature_importances_}
    )
    return ranking.sort_values(by="importance", ascending=False)
224
+
225
+
226
+ # 3.4.decision tree
227
def features_decision_tree(
    x_train: pd.DataFrame, y_train: pd.Series, dt_params: dict
) -> pd.DataFrame:
    """
    Fit a single decision tree classifier and return its sorted feature importances.

    A decision tree captures non-linear decision boundaries and gives an
    interpretable, impurity-based importance score per feature. It can overfit
    small datasets, so it is best used for interpretable analysis on smaller or
    balanced data.

    Parameters:
    - x_train: Training features; its columns are used as feature names.
    - y_train: Training labels.
    - dt_params: Keyword arguments forwarded to DecisionTreeClassifier.

    Returns:
    - DataFrame with columns "feature" and "importance", sorted by importance
      in descending order.
    """
    tree = DecisionTreeClassifier(**dt_params)
    tree.fit(x_train, y_train)
    ranking = pd.DataFrame(
        {"feature": x_train.columns, "importance": tree.feature_importances_}
    )
    return ranking.sort_values(by="importance", ascending=False)
245
+
246
+
247
+ # 3.5 bagging
248
def features_bagging(
    x_train: pd.DataFrame, y_train: pd.Series, bagging_params: dict
) -> pd.DataFrame:
    """
    Fit a Bagging classifier and return per-feature importances.

    Bagging trains many base estimators (decision trees by default) on
    bootstrap samples and aggregates them. This reduces variance, helps
    high-variance models, and adds stability on noisy or high-dimensional data.

    Importances are the per-estimator `feature_importances_` averaged over all
    estimators, when the base estimator provides them. Otherwise permutation
    importance (accuracy, 30 repeats) of the whole ensemble is used.

    Parameters:
    - x_train: Training features; its columns are used as feature names.
    - y_train: Training labels.
    - bagging_params: Keyword arguments forwarded to BaggingClassifier.

    Returns:
    - DataFrame with columns "feature" and "importance", sorted by importance
      in descending order.
    """
    bagging = BaggingClassifier(**bagging_params)
    bagging.fit(x_train, y_train)

    if hasattr(bagging.estimators_[0], "feature_importances_"):
        # Each base estimator is fit only on the columns listed in
        # bagging.estimators_features_, so its importances are indexed relative
        # to that subset (which may be smaller than, or a reordering of, the
        # full column set). Scatter them back onto the full feature axis before
        # averaging. np.add.at accumulates correctly when bootstrap_features
        # draws the same column more than once.
        totals = np.zeros(x_train.shape[1], dtype=float)
        for estimator, columns in zip(
            bagging.estimators_, bagging.estimators_features_
        ):
            np.add.at(totals, np.asarray(columns), estimator.feature_importances_)
        importances = totals / len(bagging.estimators_)
    else:
        # Base estimator has no intrinsic importances: use permutation
        # importance. It is not among the module-level imports, so import it here.
        from sklearn.inspection import permutation_importance

        importances = permutation_importance(
            bagging, x_train, y_train, n_repeats=30, random_state=1, scoring="accuracy"
        ).importances_mean

    return pd.DataFrame(
        {"feature": x_train.columns, "importance": importances}
    ).sort_values(by="importance", ascending=False)
280
+
281
+
282
+ #! 4.Support Vector Machines
283
def features_svm(
    x_train: pd.DataFrame, y_train: pd.Series, rfe_params: dict
) -> pd.Index:
    """
    Select features with Recursive Feature Elimination (RFE) around an SVM.

    RFE repeatedly fits the SVM and drops the feature(s) with the smallest
    absolute weight until `n_features_to_select` remain. The survivors are the
    features most critical for the decision boundary. This is useful when
    there are many more features than samples.

    Kernel note: RFE ranks features by the estimator's `coef_`, and `SVC`
    exposes `coef_` only for `kernel="linear"`. Non-linear kernels ("rbf",
    "poly", "sigmoid") model non-linear boundaries but provide no per-feature
    weights. They are rejected up front rather than failing inside RFE after
    the first, potentially expensive, fit.

    Parameters:
    - x_train: Training features; its columns are used as feature names.
    - y_train: Training labels.
    - rfe_params: dict with optional keys
        "kernel" (default "linear"), and
        "n_features_to_select" (default None, i.e. RFE keeps half the features).

    Returns:
    - pd.Index with the names of the selected features.

    Raises:
    - ValueError: if the kernel is not "linear".
    """
    kernel = rfe_params.get("kernel", "linear")
    if kernel != "linear":
        raise ValueError(
            f"features_svm: RFE needs SVC.coef_, which only exists for "
            f"kernel='linear' (got kernel={kernel!r})."
        )
    n_to_select = rfe_params.get("n_features_to_select")

    svc = SVC(kernel=kernel)
    selector = RFE(svc, n_features_to_select=n_to_select)
    selector.fit(x_train, y_train)
    # support_ is a boolean mask over the columns that survived elimination.
    return x_train.columns[selector.support_]
323
+
324
+
325
+ #! 5.Bayesian and Probabilistic Methods
326
def features_naive_bayes(x_train: pd.DataFrame, y_train: pd.Series) -> pd.Index:
    """
    Select features using the class-conditional Gaussians of a Naive Bayes model.

    Gaussian Naive Bayes models each feature, within each class, as an
    independent normal distribution with a class-specific mean and variance.
    Its parameters are estimated here in closed form: class priors are class
    frequencies, and per-class means and variances use ddof=0. This is what a
    GaussianNB fit learns, without var_smoothing.

    A feature is informative when its class means are far apart relative to
    its within-class spread. Each feature j is scored with a Fisher-style ratio:

        score_j = sum_c p_c * (mu_cj - mu_j) ** 2 / sum_c p_c * var_cj,
        mu_j    = sum_c p_c * mu_cj

    A constant feature (0/0) scores 0. A feature that separates the classes
    with zero within-class variance scores +inf.

    (The previous implementation sorted *sample* indices by their maximum
    posterior probability and reused them as *column* indices, which has no
    relation to feature relevance.)

    Parameters:
    - x_train: Numeric training features; its columns are used as feature names.
    - y_train: Class labels (any hashable, sortable values; list/array/Series).

    Returns:
    - pd.Index with the names of the top `x_train.shape[1] // 2` features,
      best first (ties keep column order).
    """
    n_select = x_train.shape[1] // 2
    values = np.asarray(x_train, dtype=float)
    labels = np.asarray(y_train).ravel()
    if n_select == 0 or values.shape[0] == 0:
        return x_train.columns[:0]

    classes, counts = np.unique(labels, return_counts=True)
    priors = counts / counts.sum()
    class_means = np.vstack([values[labels == c].mean(axis=0) for c in classes])
    class_vars = np.vstack([values[labels == c].var(axis=0) for c in classes])

    overall_mean = priors @ class_means
    between = priors @ (class_means - overall_mean) ** 2
    within = priors @ class_vars

    with np.errstate(divide="ignore", invalid="ignore"):
        scores = between / within
    # 0/0 -> nan -> 0 (uninformative); x/0 with x > 0 stays +inf (perfect split).
    scores = np.nan_to_num(scores, nan=0.0, posinf=np.inf)

    order = np.argsort(-scores, kind="stable")[:n_select]
    return x_train.columns[order]
346
+
347
+
348
+ #! 6.Linear Discriminant Analysis (LDA)
349
def features_lda(x_train: pd.DataFrame, y_train: pd.Series) -> pd.DataFrame:
    """
    Rank features by the magnitude of their Linear Discriminant Analysis weights.

    LDA projects the data onto directions that maximize class separability. It
    is often used for dimensionality reduction before classification on
    high-dimensional data. Here the absolute discriminant coefficients serve as
    feature importances.

    Parameters:
    - x_train: Training features; its columns are used as feature names.
    - y_train: Training labels (binary or multiclass).

    Returns:
    - DataFrame with columns "feature" and "importance", restricted to
      non-zero importances and sorted in descending order.
      For multiclass problems, importance is the mean absolute coefficient
      across classes.
    """
    from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

    lda = LinearDiscriminantAnalysis()
    lda.fit(x_train, y_train)
    # coef_ holds a single row for binary problems and one row per class for
    # multiclass problems. Flattening the latter gives n_classes * n_features
    # values that no longer line up with the columns, so average the absolute
    # weights over rows instead. For a single row this is just |coef|.
    coef = np.atleast_2d(lda.coef_)
    importance_df = pd.DataFrame(
        {"feature": x_train.columns, "importance": np.abs(coef).mean(axis=0)}
    )
    return importance_df[importance_df["importance"] > 0].sort_values(
        by="importance", ascending=False
    )
367
+
368
+
369
def features_adaboost(
    x_train: pd.DataFrame, y_train: pd.Series, adaboost_params: dict
) -> pd.DataFrame:
    """
    Fit an AdaBoost classifier and return its sorted feature importances.

    AdaBoost combines many weak learners into a strong classifier by
    re-weighting the samples that earlier learners misclassified. It is suited
    to boosting weak models on classification problems with many features
    (e.g. genes).

    Parameters:
    - x_train: Training features; its columns are used as feature names.
    - y_train: Training labels.
    - adaboost_params: Keyword arguments forwarded to AdaBoostClassifier.

    Returns:
    - DataFrame with columns "feature" and "importance", sorted by importance
      in descending order.
    """
    booster = AdaBoostClassifier(**adaboost_params)
    booster.fit(x_train, y_train)
    ranking = pd.DataFrame(
        {"feature": x_train.columns, "importance": booster.feature_importances_}
    )
    return ranking.sort_values(by="importance", ascending=False)
387
+
388
+
389
+ import torch
390
+ import torch.nn as nn
391
+ import torch.optim as optim
392
+ from torch.utils.data import DataLoader, TensorDataset
393
+ from skorch import NeuralNetClassifier # sklearn compatible
394
+
395
+
396
class DNNClassifier(nn.Module):
    """
    Feed-forward classifier with one residual (skip) connection and a softmax head.

    Data flow:
        input -> [Linear, ReLU, Dropout, Linear, ReLU]
              + Linear projection of the input           (skip connection)
              -> [Linear, ReLU, Dropout]
              -> Linear -> Softmax(dim=1)

    Parameters:
    - input_dim: number of input features.
    - hidden_dim: width of every hidden layer (default 128).
    - output_dim: number of output classes (default 2).
    - dropout_rate: dropout probability (default 0.5).

    forward() returns per-class probabilities (softmax over dim 1), not logits.
    """

    def __init__(self, input_dim, hidden_dim=128, output_dim=2, dropout_rate=0.5):
        super().__init__()
        # Attribute names and the layer order inside each Sequential are fixed:
        # they determine state_dict keys and the order of parameter init.
        self.hidden_layer1 = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        )
        self.hidden_layer2 = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
        )
        # Projects the raw input to hidden_dim so it can be added to the output
        # of hidden_layer1.
        self.residual = nn.Linear(input_dim, hidden_dim)
        self.output_layer = nn.Sequential(
            nn.Linear(hidden_dim, output_dim),
            nn.Softmax(dim=1),
        )

    def forward(self, x):
        skip = self.residual(x)
        merged = self.hidden_layer1(x) + skip
        return self.output_layer(self.hidden_layer2(merged))
426
+
427
+
428
def validate_classifier(
    clf,
    x_train: pd.DataFrame,
    y_train: pd.Series,
    x_test: pd.DataFrame,
    y_test: pd.Series,
    metrics: Optional[List[str]] = None,
    cv_folds: int = 5,
) -> dict:
    """
    Cross-validate a classifier on the training data, then fit it on all of the
    training data and score it on the test set.

    Parameters:
    - clf: The (unfitted) classifier to validate. It is fitted in place on
      x_train / y_train.
    - x_train, y_train: Training features and labels.
    - x_test, y_test: Test features and labels.
    - metrics: sklearn scoring names to evaluate. Default:
      ["accuracy", "precision", "recall", "f1", "roc_auc"].
      On the test set, a metric is scored with the module-level function
      named `<metric>_score`; metrics without such a function are omitted
      from cv_test_scores.
    - cv_folds: Number of StratifiedKFold folds.

    Returns:
    - {"cv_train_scores": {metric: mean CV score},
       "cv_test_scores": {metric: test-set score}}.
      A metric that fails to compute is reported as NaN (and logged).
    """
    if metrics is None:
        metrics = ["accuracy", "precision", "recall", "f1", "roc_auc"]

    skf = StratifiedKFold(n_splits=cv_folds)

    # --- Cross-validation on the training data ---
    cv_train_scores = {}
    for metric in metrics:
        try:
            if metric == "roc_auc" and len(set(y_train)) == 2:
                scores = cross_val_score(
                    clf, x_train, y_train, cv=skf, scoring="roc_auc"
                )
                # Ignore folds that produced NaN; all-NaN stays NaN.
                cv_train_scores[metric] = (
                    np.nanmean(scores) if not np.isnan(scores).all() else float("nan")
                )
            else:
                scores = cross_val_score(clf, x_train, y_train, cv=skf, scoring=metric)
                cv_train_scores[metric] = scores.mean()
        except Exception as exc:
            logger.warning(
                "CV scoring '%s' failed for %s: %s", metric, type(clf).__name__, exc
            )
            cv_train_scores[metric] = float("nan")

    clf.fit(x_train, y_train)

    # --- Evaluation on the held-out test set ---
    cv_test_scores = {}
    y_pred = None  # computed lazily, once, for all label-based metrics
    for metric in metrics:
        if metric == "roc_auc" and len(set(y_test)) == 2:
            try:
                # Mirror sklearn's roc_auc scorer: use class-1 probabilities when
                # available, otherwise decision_function scores (e.g.
                # RidgeClassifierCV, SVC without probability=True).
                if hasattr(clf, "predict_proba"):
                    y_score = clf.predict_proba(x_test)[:, 1]
                else:
                    y_score = clf.decision_function(x_test)
                cv_test_scores[metric] = roc_auc_score(y_test, y_score)
            except AttributeError:
                cv_test_scores[metric] = float("nan")
            continue

        # Look up e.g. accuracy_score / f1_score among the module's imports.
        score_func = globals().get(f"{metric}_score")
        if not score_func:
            logger.warning("No test-set scoring function for metric '%s'", metric)
            continue
        try:
            if y_pred is None:
                y_pred = clf.predict(x_test)
            cv_test_scores[metric] = score_func(y_test, y_pred)
        except Exception as exc:
            logger.warning(
                "Test scoring '%s' failed for %s: %s", metric, type(clf).__name__, exc
            )
            cv_test_scores[metric] = float("nan")

    return {"cv_train_scores": cv_train_scores, "cv_test_scores": cv_test_scores}
495
+
496
+
497
def get_models(
    random_state=1,
    cls=None,
):
    """
    Build a dictionary of unfitted classifiers selected by name.

    Parameters:
    - random_state: Seed passed to every model that accepts one, so repeated
      calls give reproducible models.
    - cls: List of requested model names. Each name is resolved against the
      built-in model names with ips.strcmp (presumably a fuzzy string match;
      its first returned element is used). Defaults to all 13 models.

    Returns:
    - dict mapping the resolved model name to an unfitted estimator.
      Also prints the selected names.
    """
    from sklearn.ensemble import (
        RandomForestClassifier,
        GradientBoostingClassifier,
        AdaBoostClassifier,
        BaggingClassifier,
    )
    from sklearn.svm import SVC
    from sklearn.linear_model import LogisticRegression, RidgeClassifierCV
    from sklearn.naive_bayes import GaussianNB
    from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
    import xgboost as xgb
    from sklearn.tree import DecisionTreeClassifier
    from sklearn.neighbors import KNeighborsClassifier

    if cls is None:
        cls = [
            "lasso",
            "ridge",
            "Elastic Net(Enet)",
            "gradient Boosting",
            "Random forest (rf)",
            "XGBoost (xgb)",
            "Support Vector Machine(svm)",
            "naive bayes",
            "Linear Discriminant Analysis (lda)",
            "adaboost",
            "DecisionTree",
            "KNeighbors",
            "Bagging",
        ]

    model_all = {
        "Lasso": LogisticRegression(
            penalty="l1", solver="saga", random_state=random_state
        ),
        "Ridge": RidgeClassifierCV(),
        # Elastic-net-penalized *logistic* regression, mirroring the "Lasso"
        # entry. The sklearn ElasticNet regressor used previously outputs
        # continuous values, so every classification metric failed on it.
        # l1_ratio=0.5 matches ElasticNet's default mix.
        "Elastic Net (Enet)": LogisticRegression(
            penalty="elasticnet",
            solver="saga",
            l1_ratio=0.5,
            random_state=random_state,
        ),
        "Gradient Boosting": GradientBoostingClassifier(random_state=random_state),
        "Random Forest (RF)": RandomForestClassifier(random_state=random_state),
        "XGBoost (XGB)": xgb.XGBClassifier(random_state=random_state),
        # random_state seeds the internal CV used for probability calibration.
        "Support Vector Machine (SVM)": SVC(
            kernel="rbf", probability=True, random_state=random_state
        ),
        "Naive Bayes": GaussianNB(),
        "Linear Discriminant Analysis (LDA)": LinearDiscriminantAnalysis(),
        "AdaBoost": AdaBoostClassifier(random_state=random_state, algorithm="SAMME"),
        # These two previously ignored random_state and were non-deterministic.
        "DecisionTree": DecisionTreeClassifier(random_state=random_state),
        "KNeighbors": KNeighborsClassifier(n_neighbors=5),
        "Bagging": BaggingClassifier(random_state=random_state),
    }

    model_names = list(model_all.keys())
    res_cls = {}
    print("Using default models:")
    for requested in cls:
        cls_name = ips.strcmp(requested, model_names)[0]
        res_cls[cls_name] = model_all[cls_name]
        print(f"- {cls_name}")
    return res_cls
558
+
559
+
560
+ def get_features(
561
+ X: Union[pd.DataFrame, np.ndarray], # n_samples X n_features
562
+ y: Union[pd.Series, np.ndarray, list], # n_samples X n_features
563
+ test_size: float = 0.2,
564
+ random_state: int = 1,
565
+ n_features: int = 10,
566
+ fill_missing=True,
567
+ rf_params: Optional[Dict] = None,
568
+ rfe_params: Optional[Dict] = None,
569
+ lasso_params: Optional[Dict] = None,
570
+ ridge_params: Optional[Dict] = None,
571
+ enet_params: Optional[Dict] = None,
572
+ gb_params: Optional[Dict] = None,
573
+ adaboost_params: Optional[Dict] = None,
574
+ xgb_params: Optional[Dict] = None,
575
+ dt_params: Optional[Dict] = None,
576
+ bagging_params: Optional[Dict] = None,
577
+ knn_params: Optional[Dict] = None,
578
+ cls: list = [
579
+ "lasso",
580
+ "ridge",
581
+ "Elastic Net(Enet)",
582
+ "gradient Boosting",
583
+ "Random forest (rf)",
584
+ "XGBoost (xgb)",
585
+ "Support Vector Machine(svm)",
586
+ "naive bayes",
587
+ "Linear Discriminant Analysis (lda)",
588
+ "adaboost",
589
+ "DecisionTree",
590
+ "KNeighbors",
591
+ "Bagging",
592
+ ],
593
+ metrics: Optional[List[str]] = None,
594
+ cv_folds: int = 5,
595
+ strict: bool = False,
596
+ n_shared: int = 2, # 只要有两个方法有重合,就纳入common genes
597
+ use_selected_features: bool = True,
598
+ plot_: bool = True,
599
+ dir_save: str = "./",
600
+ ) -> dict:
601
+ """
602
+ Master function to perform feature selection and validate models.
603
+ """
604
+ from sklearn.compose import ColumnTransformer
605
+ from sklearn.preprocessing import StandardScaler, OneHotEncoder
606
+
607
+ # Ensure X and y are DataFrames/Series for consistency
608
+ if isinstance(X, np.ndarray):
609
+ X = pd.DataFrame(X)
610
+ if isinstance(y, (np.ndarray, list)):
611
+ y = pd.Series(y)
612
+
613
+ # fill na
614
+ if fill_missing:
615
+ ips.df_fillna(data=X, method="knn", inplace=True, axis=0)
616
+ if isinstance(y, str) and y in X.columns:
617
+ y_col_name = y
618
+ y = X[y]
619
+ y = ips.df_encoder(pd.DataFrame(y), method="dummy")
620
+ X = X.drop(y_col_name, axis=1)
621
+ else:
622
+ y = ips.df_encoder(pd.DataFrame(y), method="dummy").values.ravel()
623
+ y = y.loc[X.index] # Align y with X after dropping rows with missing values in X
624
+ y = y.ravel() if isinstance(y, np.ndarray) else y.values.ravel()
625
+
626
+ if X.shape[0] != len(y):
627
+ raise ValueError("X and y must have the same number of samples (rows).")
628
+
629
+ # #! # Check for non-numeric columns in X and apply one-hot encoding if needed
630
+ # Check if any column in X is non-numeric
631
+ if any(not np.issubdtype(dtype, np.number) for dtype in X.dtypes):
632
+ X = pd.get_dummies(X, drop_first=True)
633
+ print(X.shape)
634
+
635
+ # #!alternative: # Identify categorical and numerical columns
636
+ # categorical_cols = X.select_dtypes(include=["object", "category"]).columns
637
+ # numerical_cols = X.select_dtypes(include=["number"]).columns
638
+
639
+ # # Define preprocessing pipeline
640
+ # preprocessor = ColumnTransformer(
641
+ # transformers=[
642
+ # ("num", StandardScaler(), numerical_cols),
643
+ # ("cat", OneHotEncoder(drop="first", handle_unknown="ignore"), categorical_cols),
644
+ # ]
645
+ # )
646
+ # # Preprocess the data
647
+ # X = preprocessor.fit_transform(X)
648
+
649
+ # Split data into training and test sets
650
+ x_train, x_test, y_train, y_test = train_test_split(
651
+ X, y, test_size=test_size, random_state=random_state
652
+ )
653
+ # Standardize features
654
+ scaler = StandardScaler()
655
+ x_train_scaled = scaler.fit_transform(x_train)
656
+ x_test_scaled = scaler.transform(x_test)
657
+
658
+ # Convert back to DataFrame for consistency
659
+ x_train = pd.DataFrame(x_train_scaled, columns=x_train.columns)
660
+ x_test = pd.DataFrame(x_test_scaled, columns=x_test.columns)
661
+
662
+ rf_defaults = {"n_estimators": 100, "random_state": random_state}
663
+ rfe_defaults = {"kernel": "linear", "n_features_to_select": n_features}
664
+ lasso_defaults = {"alphas": np.logspace(-4, 4, 100), "cv": 10}
665
+ ridge_defaults = {"alphas": np.logspace(-4, 4, 100), "cv": 10}
666
+ enet_defaults = {"alphas": np.logspace(-4, 4, 100), "cv": 10}
667
+ xgb_defaults = {
668
+ "n_estimators": 100,
669
+ "use_label_encoder": False,
670
+ "eval_metric": "logloss",
671
+ "random_state": random_state,
672
+ }
673
+ gb_defaults = {"n_estimators": 100, "random_state": random_state}
674
+ adaboost_defaults = {"n_estimators": 50, "random_state": random_state}
675
+ dt_defaults = {"max_depth": None, "random_state": random_state}
676
+ bagging_defaults = {"n_estimators": 50, "random_state": random_state}
677
+ knn_defaults = {"n_neighbors": 5}
678
+ rf_params, rfe_params = rf_params or rf_defaults, rfe_params or rfe_defaults
679
+ lasso_params, ridge_params = (
680
+ lasso_params or lasso_defaults,
681
+ ridge_params or ridge_defaults,
682
+ )
683
+ enet_params, xgb_params = enet_params or enet_defaults, xgb_params or xgb_defaults
684
+ gb_params, adaboost_params = (
685
+ gb_params or gb_defaults,
686
+ adaboost_params or adaboost_defaults,
687
+ )
688
+ dt_params = dt_params or dt_defaults
689
+ bagging_params = bagging_params or bagging_defaults
690
+ knn_params = knn_params or knn_defaults
691
+
692
+ cls_ = [
693
+ "lasso",
694
+ "ridge",
695
+ "Elastic Net(Enet)",
696
+ "Gradient Boosting",
697
+ "Random Forest (rf)",
698
+ "XGBoost (xgb)",
699
+ "Support Vector Machine(svm)",
700
+ "Naive Bayes",
701
+ "Linear Discriminant Analysis (lda)",
702
+ "adaboost",
703
+ ]
704
+ cls = [ips.strcmp(i, cls_)[0] for i in cls]
705
+
706
+ # Lasso Feature Selection
707
+ lasso_importances = (
708
+ features_lasso(x_train, y_train, lasso_params)
709
+ if "lasso" in cls
710
+ else pd.DataFrame()
711
+ )
712
+ lasso_selected_features = (
713
+ lasso_importances.head(n_features)["feature"].values if "lasso" in cls else []
714
+ )
715
+ # Ridge
716
+ ridge_importances = (
717
+ features_ridge(x_train, y_train, ridge_params)
718
+ if "ridge" in cls
719
+ else pd.DataFrame()
720
+ )
721
+ selected_ridge_features = (
722
+ ridge_importances.head(n_features)["feature"].values if "ridge" in cls else []
723
+ )
724
+ # Elastic Net
725
+ enet_importances = (
726
+ features_enet(x_train, y_train, enet_params)
727
+ if "Enet" in cls
728
+ else pd.DataFrame()
729
+ )
730
+ selected_enet_features = (
731
+ enet_importances.head(n_features)["feature"].values if "Enet" in cls else []
732
+ )
733
+ # Random Forest Feature Importance
734
+ rf_importances = (
735
+ features_rf(x_train, y_train, rf_params)
736
+ if "Random Forest" in cls
737
+ else pd.DataFrame()
738
+ )
739
+ top_rf_features = (
740
+ rf_importances.head(n_features)["feature"].values
741
+ if "Random Forest" in cls
742
+ else []
743
+ )
744
+ # Gradient Boosting Feature Importance
745
+ gb_importances = (
746
+ features_gradient_boosting(x_train, y_train, gb_params)
747
+ if "Gradient Boosting" in cls
748
+ else pd.DataFrame()
749
+ )
750
+ top_gb_features = (
751
+ gb_importances.head(n_features)["feature"].values
752
+ if "Gradient Boosting" in cls
753
+ else []
754
+ )
755
+ # xgb
756
+ xgb_importances = (
757
+ features_xgb(x_train, y_train, xgb_params) if "xgb" in cls else pd.DataFrame()
758
+ )
759
+ top_xgb_features = (
760
+ xgb_importances.head(n_features)["feature"].values if "xgb" in cls else []
761
+ )
762
+
763
+ # SVM with RFE
764
+ selected_svm_features = (
765
+ features_svm(x_train, y_train, rfe_params) if "svm" in cls else []
766
+ )
767
+ # Naive Bayes
768
+ selected_naive_bayes_features = (
769
+ features_naive_bayes(x_train, y_train) if "Naive Bayes" in cls else []
770
+ )
771
+ # lda: linear discriminant analysis
772
+ lda_importances = features_lda(x_train, y_train) if "lda" in cls else pd.DataFrame()
773
+ selected_lda_features = (
774
+ lda_importances.head(n_features)["feature"].values if "lda" in cls else []
775
+ )
776
+ # AdaBoost Feature Importance
777
+ adaboost_importances = (
778
+ features_adaboost(x_train, y_train, adaboost_params)
779
+ if "AdaBoost" in cls
780
+ else pd.DataFrame()
781
+ )
782
+ top_adaboost_features = (
783
+ adaboost_importances.head(n_features)["feature"].values
784
+ if "AdaBoost" in cls
785
+ else []
786
+ )
787
+ # Decision Tree Feature Importance
788
+ dt_importances = (
789
+ features_decision_tree(x_train, y_train, dt_params)
790
+ if "Decision Tree" in cls
791
+ else pd.DataFrame()
792
+ )
793
+ top_dt_features = (
794
+ dt_importances.head(n_features)["feature"].values
795
+ if "Decision Tree" in cls
796
+ else []
797
+ )
798
+ # Bagging Feature Importance
799
+ bagging_importances = (
800
+ features_bagging(x_train, y_train, bagging_params)
801
+ if "Bagging" in cls
802
+ else pd.DataFrame()
803
+ )
804
+ top_bagging_features = (
805
+ bagging_importances.head(n_features)["feature"].values
806
+ if "Bagging" in cls
807
+ else []
808
+ )
809
+ # KNN Feature Importance via Permutation
810
+ knn_importances = (
811
+ features_knn(x_train, y_train, knn_params) if "KNN" in cls else pd.DataFrame()
812
+ )
813
+ top_knn_features = (
814
+ knn_importances.head(n_features)["feature"].values if "KNN" in cls else []
815
+ )
816
+
817
+ #! Find common features
818
+ common_features = ips.shared(
819
+ lasso_selected_features,
820
+ selected_ridge_features,
821
+ selected_enet_features,
822
+ top_rf_features,
823
+ top_gb_features,
824
+ top_xgb_features,
825
+ selected_svm_features,
826
+ selected_naive_bayes_features,
827
+ selected_lda_features,
828
+ top_adaboost_features,
829
+ top_dt_features,
830
+ top_bagging_features,
831
+ top_knn_features,
832
+ strict=strict,
833
+ n_shared=n_shared,
834
+ verbose=False,
835
+ )
836
+
837
+ # Use selected features or all features for model validation
838
+ x_train_selected = (
839
+ x_train[list(common_features)] if use_selected_features else x_train
840
+ )
841
+ x_test_selected = x_test[list(common_features)] if use_selected_features else x_test
842
+
843
+ if metrics is None:
844
+ metrics = ["accuracy", "precision", "recall", "f1", "roc_auc"]
845
+
846
+ # Prepare results DataFrame for selected features
847
+ features_df = pd.DataFrame(
848
+ {
849
+ "type": ["Lasso"] * len(lasso_selected_features)
850
+ + ["Ridge"] * len(selected_ridge_features)
851
+ + ["Random Forest"] * len(top_rf_features)
852
+ + ["Gradient Boosting"] * len(top_gb_features)
853
+ + ["Enet"] * len(selected_enet_features)
854
+ + ["xgb"] * len(top_xgb_features)
855
+ + ["SVM"] * len(selected_svm_features)
856
+ + ["Naive Bayes"] * len(selected_naive_bayes_features)
857
+ + ["Linear Discriminant Analysis"] * len(selected_lda_features)
858
+ + ["AdaBoost"] * len(top_adaboost_features)
859
+ + ["Decision Tree"] * len(top_dt_features)
860
+ + ["Bagging"] * len(top_bagging_features)
861
+ + ["KNN"] * len(top_knn_features),
862
+ "feature": np.concatenate(
863
+ [
864
+ lasso_selected_features,
865
+ selected_ridge_features,
866
+ top_rf_features,
867
+ top_gb_features,
868
+ selected_enet_features,
869
+ top_xgb_features,
870
+ selected_svm_features,
871
+ selected_naive_bayes_features,
872
+ selected_lda_features,
873
+ top_adaboost_features,
874
+ top_dt_features,
875
+ top_bagging_features,
876
+ top_knn_features,
877
+ ]
878
+ ),
879
+ }
880
+ )
881
+
882
+ #! Validate trained each classifier
883
+ models = get_models(random_state=random_state, cls=cls)
884
+ cv_train_results, cv_test_results = [], []
885
+ for name, clf in models.items():
886
+ if not x_train_selected.empty:
887
+ cv_scores = validate_classifier(
888
+ clf,
889
+ x_train_selected,
890
+ y_train,
891
+ x_test_selected,
892
+ y_test,
893
+ metrics=metrics,
894
+ cv_folds=cv_folds,
895
+ )
896
+
897
+ cv_train_score_df = pd.DataFrame(cv_scores["cv_train_scores"], index=[name])
898
+ cv_test_score_df = pd.DataFrame(cv_scores["cv_test_scores"], index=[name])
899
+ cv_train_results.append(cv_train_score_df)
900
+ cv_test_results.append(cv_test_score_df)
901
+ if all([cv_train_results, cv_test_results]):
902
+ cv_train_results_df = (
903
+ pd.concat(cv_train_results)
904
+ .reset_index()
905
+ .rename(columns={"index": "Classifier"})
906
+ )
907
+ cv_test_results_df = (
908
+ pd.concat(cv_test_results)
909
+ .reset_index()
910
+ .rename(columns={"index": "Classifier"})
911
+ )
912
+ #! Store results in the main results dictionary
913
+ results = {
914
+ "selected_features": features_df,
915
+ "cv_train_scores": cv_train_results_df,
916
+ "cv_test_scores": rank_models(cv_test_results_df, plot_=plot_),
917
+ "common_features": list(common_features),
918
+ }
919
+ if all([plot_, dir_save]):
920
+ from datetime import datetime
921
+
922
+ now_ = datetime.now().strftime("%y%m%d_%H%M%S")
923
+ ips.figsave(dir_save + f"features{now_}.pdf")
924
+ else:
925
+ results = {
926
+ "selected_features": pd.DataFrame(),
927
+ "cv_train_scores": pd.DataFrame(),
928
+ "cv_test_scores": pd.DataFrame(),
929
+ "common_features": [],
930
+ }
931
+ print(f"Warning: 没有找到共同的genes, when n_shared={n_shared}")
932
+ return results
933
+
934
+
935
+ #! # usage:
936
+ # # Get features and common features
937
+ # results = get_features(X, y)
938
+ # common_features = results["common_features"]
939
def validate_features(
    x_train: pd.DataFrame,
    y_train: pd.Series,
    x_true: pd.DataFrame,
    y_true: pd.Series,
    common_features: set = None,
    models: Optional[Dict[str, Any]] = None,
    metrics: Optional[list] = None,
    random_state: int = 1,
    smote: bool = False,
    n_jobs: int = -1,
    plot_: bool = True,
    class_weight: str = "balanced",
) -> pd.DataFrame:
    """
    Validate models using selected features on the validation dataset.

    Each model is tuned with GridSearchCV (5-fold stratified CV, optimising
    ROC AUC) on the training data and then scored on the validation data.

    Parameters:
    - x_train (pd.DataFrame): Training feature dataset.
    - y_train (pd.Series): Training target variable.
    - x_true (pd.DataFrame): Validation feature dataset.
    - y_true (pd.Series): Validation target variable; its index must equal
      the index of ``x_true``.
    - common_features (iterable, optional): Features to use for validation.
      They are intersected with the columns of both datasets. If None, all
      columns shared by ``x_train`` and ``x_true`` are used.
    - models (dict, optional): Mapping of model name -> unfitted classifier.
      Names without a built-in hyper-parameter grid are fitted with their
      own parameters (empty grid).
    - metrics (list, optional): Metrics to compute. Supported: "accuracy",
      "precision", "recall", "f1", "roc_auc", "mcc", "specificity" (binary
      only), "balanced_accuracy", "pr_auc".
    - random_state (int): Random state for reproducibility.
    - smote (bool): Whether to consider SMOTE oversampling of the training set.
    - n_jobs (int): Number of parallel jobs used by GridSearchCV.
    - plot_ (bool): Currently unused.
    - class_weight (str or dict): Class weights used by the default models.

    Returns:
    - pd.DataFrame: one row per model with columns "best_params", "scores",
      "roc_curve", "pr_curve" and "confusion_matrix".

    Raises:
    - ValueError: if the validation features and target indices differ.
    """
    from tqdm import tqdm

    # Without an explicit selection, fall back to every shared column.
    if common_features is None:
        common_features = x_train.columns
    common_features = ips.shared(
        common_features, x_train.columns, x_true.columns, strict=True, verbose=False
    )

    # Filter the training and validation datasets for the common features
    x_train_selected = x_train[common_features]
    x_true_selected = x_true[common_features]

    # Index equality is required. Re-indexing y_true afterwards is unnecessary
    # and would duplicate rows when the index has repeated labels.
    if not x_true_selected.index.equals(y_true.index):
        raise ValueError(
            "Index mismatch between validation features and target. Ensure data alignment."
        )

    # Optional SMOTE oversampling of the training set.
    # NOTE(review): SMOTE is only applied when the majority class is below 80%
    # of y_train, i.e. it is skipped for strongly imbalanced data. This looks
    # inverted relative to the intent; confirm before relying on it.
    if smote and y_train.value_counts(normalize=True).max() < 0.8:
        sampler = SMOTE(random_state=random_state)
        x_train_resampled, y_train_resampled = sampler.fit_resample(
            x_train_selected, y_train
        )
    else:
        x_train_resampled, y_train_resampled = x_train_selected, y_train

    # Default models if not provided
    if models is None:
        models = {
            "Random Forest": RandomForestClassifier(
                class_weight=class_weight, random_state=random_state
            ),
            "SVM": SVC(probability=True, class_weight=class_weight),
            "Logistic Regression": LogisticRegression(
                class_weight=class_weight, random_state=random_state
            ),
            "Gradient Boosting": GradientBoostingClassifier(random_state=random_state),
            "AdaBoost": AdaBoostClassifier(
                random_state=random_state, algorithm="SAMME"
            ),
            "Lasso": LogisticRegression(
                penalty="l1", solver="saga", random_state=random_state
            ),
            "Ridge": LogisticRegression(
                penalty="l2", solver="saga", random_state=random_state
            ),
            "Elastic Net": LogisticRegression(
                penalty="elasticnet",
                solver="saga",
                l1_ratio=0.5,
                random_state=random_state,
            ),
            "XGBoost": xgb.XGBClassifier(eval_metric="logloss"),
            "Naive Bayes": GaussianNB(),
            "LDA": LinearDiscriminantAnalysis(),
        }

    # Hyperparameter grids for tuning
    param_grids = {
        "Random Forest": {
            "n_estimators": [100, 200, 300, 400, 500],
            "max_depth": [None, 3, 5, 10, 20],
            "min_samples_split": [2, 5, 10],
            "min_samples_leaf": [1, 2, 4],
            "class_weight": [None, "balanced"],
        },
        "SVM": {
            "C": [0.01, 0.1, 1, 10, 100, 1000],
            "gamma": [0.001, 0.01, 0.1, "scale", "auto"],
            "kernel": ["linear", "rbf", "poly"],
        },
        "Logistic Regression": {
            "C": [0.01, 0.1, 1, 10, 100],
            "solver": ["liblinear", "saga", "newton-cg", "lbfgs"],
            "penalty": ["l1", "l2"],
            "max_iter": [100, 200, 300],
        },
        "Gradient Boosting": {
            "n_estimators": [100, 200, 300, 400, 500],
            "learning_rate": np.logspace(-3, 0, 4),
            "max_depth": [3, 5, 7, 9],
            "min_samples_split": [2, 5, 10],
        },
        "AdaBoost": {
            "n_estimators": [50, 100, 200, 300, 500],
            "learning_rate": np.logspace(-3, 0, 4),
        },
        "Lasso": {"C": np.logspace(-3, 1, 10), "max_iter": [100, 200, 300]},
        "Ridge": {"C": np.logspace(-3, 1, 10), "max_iter": [100, 200, 300]},
        "Elastic Net": {
            "C": np.logspace(-3, 1, 10),
            "l1_ratio": [0.1, 0.5, 0.9],
            "max_iter": [100, 200, 300],
        },
        "XGBoost": {
            "n_estimators": [100, 200],
            "max_depth": [3, 5, 7],
            "learning_rate": [0.01, 0.1, 0.2],
            "subsample": [0.8, 1.0],
            "colsample_bytree": [0.8, 1.0],
        },
        "Naive Bayes": {},
        "LDA": {"solver": ["svd", "lsqr", "eigen"]},
    }
    # Default metrics if not provided
    if metrics is None:
        metrics = [
            "accuracy",
            "precision",
            "recall",
            "f1",
            "roc_auc",
            "mcc",
            "specificity",
            "balanced_accuracy",
            "pr_auc",
        ]

    results = {}

    # Validate each classifier with GridSearchCV
    for name, clf in tqdm(
        models.items(),
        desc="for metric in metrics",
        colour="green",
        bar_format="{l_bar}{bar} {n_fmt}/{total_fmt}",
    ):
        print(f"\nValidating {name} on the validation dataset:")

        # Custom model names have no predefined grid: fit with their own params.
        param_grid = param_grids.get(name, {})

        # Classifiers without predict_proba are wrapped in CalibratedClassifierCV
        # so that probability-based metrics (roc_auc, pr_auc) can be computed.
        if not hasattr(clf, "predict_proba"):
            print(
                f"Using CalibratedClassifierCV for {name} due to lack of probability estimates."
            )
            # The classifier is not fitted yet, so calibrate with internal CV.
            # cv="prefit" would require an already-fitted estimator.
            calibrated_clf = CalibratedClassifierCV(clf, method="sigmoid")
            # Hyper-parameters now belong to the wrapped estimator. The
            # parameter name is "estimator" in recent sklearn and
            # "base_estimator" in older releases.
            inner = (
                "estimator"
                if "estimator" in calibrated_clf.get_params(deep=False)
                else "base_estimator"
            )
            param_grid = {f"{inner}__{key}": val for key, val in param_grid.items()}
        else:
            calibrated_clf = clf

        # Stratified K-Fold for cross-validation
        skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=random_state)

        gs = GridSearchCV(
            estimator=calibrated_clf,
            param_grid=param_grid,
            scoring="roc_auc",  # Optimize for ROC AUC
            cv=skf,
            n_jobs=n_jobs,
            verbose=1,
        )
        gs.fit(x_train_resampled, y_train_resampled)
        best_clf = gs.best_estimator_

        y_pred = best_clf.predict(x_true_selected)
        # Scores for ROC/PR curves, when the model can provide them
        if hasattr(best_clf, "predict_proba"):
            y_pred_proba = best_clf.predict_proba(x_true_selected)[:, 1]
        elif hasattr(best_clf, "decision_function"):
            y_pred_proba = best_clf.decision_function(x_true_selected)
            # Rescale decision scores to [0, 1]; constant scores carry no
            # ranking information, so map them to 0.5 instead of dividing by 0.
            score_min = y_pred_proba.min()
            score_range = y_pred_proba.max() - score_min
            if score_range > 0:
                y_pred_proba = (y_pred_proba - score_min) / score_range
            else:
                y_pred_proba = np.full_like(y_pred_proba, 0.5, dtype=float)
        else:
            y_pred_proba = None  # No probability output for certain models

        # Calculate metrics
        validation_scores = {}
        for metric in metrics:
            if metric == "accuracy":
                validation_scores[metric] = accuracy_score(y_true, y_pred)
            elif metric == "precision":
                validation_scores[metric] = precision_score(
                    y_true, y_pred, average="weighted"
                )
            elif metric == "recall":
                validation_scores[metric] = recall_score(
                    y_true, y_pred, average="weighted"
                )
            elif metric == "f1":
                validation_scores[metric] = f1_score(y_true, y_pred, average="weighted")
            elif metric == "roc_auc" and y_pred_proba is not None:
                validation_scores[metric] = roc_auc_score(y_true, y_pred_proba)
            elif metric == "mcc":
                validation_scores[metric] = matthews_corrcoef(y_true, y_pred)
            elif metric == "specificity":
                # Specificity = TN / (TN + FP); only defined for a 2x2 matrix.
                cm_ = confusion_matrix(y_true, y_pred)
                if cm_.shape == (2, 2):
                    tn, fp = cm_[0, 0], cm_[0, 1]
                    validation_scores[metric] = tn / (tn + fp) if (tn + fp) > 0 else np.nan
            elif metric == "balanced_accuracy":
                validation_scores[metric] = balanced_accuracy_score(y_true, y_pred)
            elif metric == "pr_auc" and y_pred_proba is not None:
                validation_scores[metric] = average_precision_score(
                    y_true, y_pred_proba
                )

        # ROC and precision-recall curves
        # https://scikit-learn.org/stable/auto_examples/model_selection/plot_roc.html
        if y_pred_proba is not None:
            fpr, tpr, _ = roc_curve(y_true, y_pred_proba)
            lower_ci, upper_ci = cal_auc_ci(y_true, y_pred_proba, verbose=False)
            roc_info = {
                "fpr": fpr.tolist(),
                "tpr": tpr.tolist(),
                "auc": auc(fpr, tpr),
                "ci95": (lower_ci, upper_ci),
            }
            precision_, recall_, _ = precision_recall_curve(y_true, y_pred_proba)
            pr_info = {
                "precision": precision_,
                "recall": recall_,
                "avg_precision": average_precision_score(y_true, y_pred_proba),
            }
        else:
            roc_info, pr_info = None, None

        results[name] = {
            "best_params": gs.best_params_,
            "scores": validation_scores,
            "roc_curve": roc_info,
            "pr_curve": pr_info,
            "confusion_matrix": confusion_matrix(y_true, y_pred),
        }

    return pd.DataFrame.from_dict(results, orient="index")
1211
+
1212
+
1213
+ #! usage validate_features()
1214
+ # Validate models using the validation dataset (X_val, y_val)
1215
+ # validation_results = validate_features(X, y, X_val, y_val, common_features)
1216
+
1217
+
1218
+ # # If you want to access validation scores
1219
+ # print(validation_results)
1220
def plot_validate_features(res_val):
    """
    Plot the results of 'validate_features()'.

    Draws the ROC curves of all models on one panel and their
    precision-recall curves on a second panel, one colour per model.

    Parameters:
    - res_val (pd.DataFrame): output of ``validate_features()``. Rows are
      model names; the "roc_curve" and "pr_curve" columns hold the curve
      dictionaries. Models whose curve entry is None (no probability or
      decision scores were available) are skipped.
    """
    model_names = ips.flatten(res_val["pr_curve"].index)
    colors = plot.get_color(len(model_names))
    # Many models: no fill under the curves and a two-column legend.
    if res_val.shape[0] > 5:
        alpha = 0
        figsize = [8, 10]
        subplot_layout = [1, 2]
        ncols = 2
    else:
        alpha = 0.03
        figsize = [10, 6]
        subplot_layout = [1, 1]
        ncols = 1
    nexttile = plot.subplot(figsize=figsize)

    # ROC panel
    ax = nexttile(subplot_layout[0], subplot_layout[1])
    for i, model_name in enumerate(model_names):
        roc = res_val["roc_curve"][model_name]
        if roc is None:
            continue
        lower_ci, upper_ci = roc["ci95"]
        plot_roc_curve(
            roc["fpr"],
            roc["tpr"],
            roc["auc"],
            lower_ci,
            upper_ci,
            model_name=model_name,
            lw=1.5,
            color=colors[i],
            alpha=alpha,
            ax=ax,
        )
    plot.figsets(
        sp=2,
        legend=dict(
            loc="upper right",
            ncols=ncols,
            fontsize=8,
            bbox_to_anchor=[1.5, 0.6],
            markerscale=0.8,
        ),
    )

    # Precision-recall panel
    ax = nexttile(subplot_layout[0], subplot_layout[1])
    for i, model_name in enumerate(model_names):
        pr = res_val["pr_curve"][model_name]
        if pr is None:
            continue
        plot_pr_curve(
            recall=pr["recall"],
            precision=pr["precision"],
            avg_precision=pr["avg_precision"],
            model_name=model_name,
            color=colors[i],
            lw=1.5,
            alpha=alpha,
            ax=ax,
        )
    plot.figsets(
        sp=2,
        legend=dict(loc="upper right", ncols=1, fontsize=8, bbox_to_anchor=[1.5, 0.5]),
    )
1284
+ # plot.split_legend(ax,n=2, loc=["upper left", "lower left"],bbox=[[1,0.5],[1,0.5]],ncols=2,labelcolor="k",fontsize=8)
1285
+
1286
+
1287
def plot_validate_features_single(res_val, figsize=None):
    """
    Plot one row of three panels per model from 'validate_features()' output:
    the ROC curve, the binary precision-recall curve, and the confusion
    matrix (raw counts).

    Parameters:
    - res_val (pd.DataFrame): output of ``validate_features()``, indexed by
      model name with "roc_curve", "pr_curve" and "confusion_matrix" columns.
    - figsize: optional figure size forwarded to ``plot.subplot``.
    """
    model_names = ips.flatten(res_val["pr_curve"].index)
    # Only forward figsize when the caller actually supplied one.
    subplot_kwargs = {} if figsize is None else {"figsize": figsize}
    nexttile = plot.subplot(len(model_names), 3, **subplot_kwargs)

    for name in model_names:
        roc_entry = res_val["roc_curve"][name]
        ci_low, ci_high = roc_entry["ci95"]

        # Column 1: ROC curve
        plot_roc_curve(
            roc_entry["fpr"],
            roc_entry["tpr"],
            roc_entry["auc"],
            ci_low,
            ci_high,
            model_name=name,
            ax=nexttile(),
        )
        plot.figsets(title=name, sp=2)

        # Column 2: precision-recall curve
        pr_entry = res_val["pr_curve"][name]
        plot_pr_binary(
            recall=pr_entry["recall"],
            precision=pr_entry["precision"],
            avg_precision=pr_entry["avg_precision"],
            model_name=name,
            ax=nexttile(),
        )
        plot.figsets(title=name, sp=2)

        # Column 3: confusion matrix with raw counts
        plot_cm(res_val["confusion_matrix"][name], ax=nexttile(), normalize=False)
        plot.figsets(title=name, sp=2)
1317
+
1318
+
1319
def cal_auc_ci(
    y_true, y_pred, n_bootstraps=1000, ci=0.95, random_state=1, verbose=True
):
    """
    Estimate a bootstrap confidence interval for the ROC AUC.

    The (y_true, y_pred) pairs are resampled with replacement ``n_bootstraps``
    times. Resamples containing a single class are skipped because ROC AUC is
    undefined for them. The interval is the percentile interval
    ``[(1 - ci) / 2, 1 - (1 - ci) / 2]`` of the bootstrap AUC distribution,
    so ``ci=0.95`` yields the 2.5th and 97.5th percentiles.

    Parameters:
    - y_true (array-like): true binary labels.
    - y_pred (array-like): scores or probabilities for the positive class.
    - n_bootstraps (int): number of bootstrap resamples.
    - ci (float): confidence level in (0, 1).
    - random_state (int): seed of the resampling RNG.
    - verbose (bool): print the AUC and the interval.

    Returns:
    - (lower, upper): bounds of the confidence interval.

    Raises:
    - ValueError: if ``ci`` is not in (0, 1) or no resample contained both
      classes (e.g. y_true has a single class).
    """
    if not 0 < ci < 1:
        raise ValueError(f"ci must be in (0, 1), got {ci}")
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    if verbose:
        print("auroc score:", roc_auc_score(y_true, y_pred))

    rng = np.random.RandomState(random_state)
    n_samples = len(y_pred)
    bootstrapped_scores = []
    for _ in range(n_bootstraps):
        # Bootstrap by sampling the prediction indices with replacement
        indices = rng.randint(0, n_samples, n_samples)
        if len(np.unique(y_true[indices])) < 2:
            # ROC AUC needs at least one positive and one negative sample.
            continue
        bootstrapped_scores.append(roc_auc_score(y_true[indices], y_pred[indices]))

    if not bootstrapped_scores:
        raise ValueError(
            "No bootstrap resample contained both classes; cannot estimate the AUC CI."
        )

    # Two-sided percentile interval. The previous code used the (1 - ci) and
    # ci quantiles, which gave a 90% interval for ci=0.95.
    tail = (1.0 - ci) / 2.0
    confidence_lower, confidence_upper = np.quantile(
        bootstrapped_scores, [tail, 1.0 - tail]
    )
    if verbose:
        print(
            "Confidence interval for the score: [{:0.3f} - {:0.3f}]".format(
                confidence_lower, confidence_upper
            )
        )
    return confidence_lower, confidence_upper
1356
+
1357
+
1358
def plot_roc_curve(
    fpr=None,
    tpr=None,
    mean_auc=None,
    lower_ci=None,
    upper_ci=None,
    model_name=None,
    color="#FF8F00",
    lw=2,
    alpha=0.1,
    ci_display=True,
    title="ROC Curve",
    xlabel="1−Specificity",
    ylabel="Sensitivity",
    legend_loc="lower right",
    diagonal_color="0.5",
    figsize=(5, 5),
    ax=None,
    **kwargs,
):
    """
    Plot a ROC curve with a shaded area under it and the chance diagonal.

    Parameters:
    - fpr, tpr (array-like): false/true positive rates of the curve.
    - mean_auc (float, optional): AUC shown in the legend label. If None, the
      curve gets no label.
    - lower_ci, upper_ci (float, optional): CI bounds added to the label when
      ``ci_display`` is True and both bounds are given.
    - model_name (str, optional): legend name, default "ROC curve".
    - color, lw, alpha: line colour, line width, fill transparency.
    - title, xlabel, ylabel, legend_loc, diagonal_color: cosmetics.
    - figsize: used only when a new figure is created (``ax`` is None).
    - ax: matplotlib Axes to draw on; created if None.
    - **kwargs: forwarded to ``ax.plot`` for the curve.

    Returns:
    - ax: the Axes drawn on.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=figsize)
    if mean_auc is not None:
        model_name = "ROC curve" if model_name is None else model_name
        label = f"{model_name} (AUC = {mean_auc:.3f})"
        # The CI can only be formatted when both bounds are known.
        if ci_display and lower_ci is not None and upper_ci is not None:
            label += f"\n95% CI: {lower_ci:.3f} - {upper_ci:.3f}"
    else:
        label = None

    # Plot ROC curve and the diagonal reference line
    ax.fill_between(fpr, tpr, alpha=alpha, color=color)
    ax.plot([0, 1], [0, 1], color=diagonal_color, clip_on=False, linestyle="--")
    ax.plot(fpr, tpr, color=color, lw=lw, label=label, clip_on=False, **kwargs)
    # Setting plot limits, labels, and title
    ax.set_xlim([-0.01, 1.0])
    ax.set_ylim([0.0, 1.0])
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend(loc=legend_loc)
    return ax
1401
+
1402
+
1403
+ # * usage: ml2ls.plot_roc_curve(fpr, tpr, mean_auc, lower_ci, upper_ci)
1404
+ # for model_name in flatten(validation_results["roc_curve"].keys())[2:]:
1405
+ # fpr = validation_results["roc_curve"][model_name]["fpr"]
1406
+ # tpr = validation_results["roc_curve"][model_name]["tpr"]
1407
+ # (lower_ci, upper_ci) = validation_results["roc_curve"][model_name]["ci95"]
1408
+ # mean_auc = validation_results["roc_curve"][model_name]["auc"]
1409
+
1410
+ # # Plotting
1411
+ # ml2ls.plot_roc_curve(fpr, tpr, mean_auc, lower_ci, upper_ci)
1412
+ # figsets(title=model_name)
1413
+
1414
def plot_pr_curve(
    recall=None,
    precision=None,
    avg_precision=None,
    model_name=None,
    lw=2,
    figsize=(5, 5),
    title="Precision-Recall Curve",
    xlabel="Recall",
    ylabel="Precision",
    alpha=0.1,
    color="#FF8F00",
    legend_loc="lower left",
    ax=None,
    **kwargs,
):
    """
    Plot a precision-recall curve with a shaded area under it.

    Parameters:
    - recall, precision (array-like): points of the PR curve.
    - avg_precision (float, optional): average precision shown in the legend
      as "AP=..."; omitted from the label when None.
    - model_name (str, optional): legend name, default "PR curve".
    - lw, color, alpha: line width, colour and fill transparency.
    - figsize: used only when a new figure is created (``ax`` is None).
    - title, xlabel, ylabel, legend_loc: cosmetics.
    - ax: matplotlib Axes to draw on; created if None.
    - **kwargs: forwarded to ``ax.plot``.

    Returns:
    - ax: the Axes drawn on.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=figsize)
    model_name = "PR curve" if model_name is None else model_name
    label = (
        model_name if avg_precision is None else f"{model_name} (AP={avg_precision:.2f})"
    )
    # Plot Precision-Recall curve
    ax.plot(
        recall,
        precision,
        lw=lw,
        color=color,
        label=label,
        clip_on=False,
        **kwargs,
    )
    # Fill area under the curve
    ax.fill_between(recall, precision, alpha=alpha, color=color)

    # Customize axes
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_xlim([-0.01, 1.0])
    ax.set_ylim([0.0, 1.0])
    ax.grid(False)
    ax.legend(loc=legend_loc)
    return ax
1455
+
1456
+ # * usage: ml2ls.plot_pr_curve()
1457
+ # for md_name in flatten(validation_results["pr_curve"].keys()):
1458
+ # ml2ls.plot_pr_curve(
1459
+ # recall=validation_results["pr_curve"][md_name]["recall"],
1460
+ # precision=validation_results["pr_curve"][md_name]["precision"],
1461
+ # avg_precision=validation_results["pr_curve"][md_name]["avg_precision"],
1462
+ # model_name=md_name,
1463
+ # lw=2,
1464
+ # alpha=0.1,
1465
+ # color="r",
1466
+ # )
1467
+
1468
def plot_pr_binary(
    recall=None,
    precision=None,
    avg_precision=None,
    model_name=None,
    lw=2,
    figsize=(5, 5),
    title="Precision-Recall Curve",
    xlabel="Recall",
    ylabel="Precision",
    alpha=0.1,
    color="#FF8F00",
    legend_loc="lower left",
    ax=None,
    show_avg_precision=False,
    **kwargs,
):
    """
    Plot a binary precision-recall curve with iso-F1 contours.

    The iso-F1 lines (F1 = 0.2, 0.4, 0.6, 0.8) are clipped so that they are
    drawn only below the PR curve. The curve is linearly interpolated over
    recall to obtain that boundary.

    Parameters:
    - recall, precision (array-like): points of the PR curve.
    - avg_precision (float, optional): shown in the legend as "AP=..." and,
      if ``show_avg_precision``, drawn as a horizontal line.
    - model_name (str, optional): legend name, default "Binary PR Curve".
    - lw, color, alpha: line width, colour and fill transparency.
    - figsize: used only when a new figure is created (``ax`` is None).
    - title, xlabel, ylabel, legend_loc: cosmetics.
    - ax: matplotlib Axes to draw on; created if None. All artists, including
      annotations, go on this Axes.
    - **kwargs: forwarded to ``ax.plot`` for the PR curve.

    Returns:
    - ax: the Axes drawn on.
    """
    from scipy.interpolate import interp1d

    if ax is None:
        _, ax = plt.subplots(figsize=figsize)
    model_name = "Binary PR Curve" if model_name is None else model_name
    label = (
        model_name if avg_precision is None else f"{model_name} (AP={avg_precision:.2f})"
    )

    # Plot Precision-Recall curve
    ax.plot(
        recall,
        precision,
        lw=lw,
        color=color,
        label=label,
        clip_on=False,
        **kwargs,
    )
    # Fill area under the curve
    ax.fill_between(recall, precision, alpha=alpha, color=color)

    # Iso-F1 contours: precision = f * r / (2r - f) for a fixed F1 value f.
    pr_boundary = interp1d(recall, precision, kind="linear", fill_value="extrapolate")
    x_grid = np.linspace(0.01, 1, 10000)
    boundary = pr_boundary(x_grid)  # identical for every f-score, so computed once
    for f_score in np.linspace(0.2, 0.8, num=4):
        with np.errstate(divide="ignore", invalid="ignore"):
            y_vals = f_score * x_grid / (2 * x_grid - f_score)
        y_clipped = np.clip(np.minimum(y_vals, boundary), 1e-3, None)  # avoid zero
        # Keep only points strictly below the PR curve and above the floor.
        valid = (y_clipped < boundary) & (y_clipped > 1e-3)
        x_seg, y_seg = x_grid[valid], y_clipped[valid]
        if len(x_seg) > 0:  # annotate only if a line segment exists
            ax.plot(x_seg, y_seg, color="gray", alpha=1)
            ax.annotate(
                f"$f_1={f_score:0.1f}$",
                xy=(0.8, y_seg[-int(len(y_seg) * 0.35)] + 0.02),
            )

    # Plot the average precision line
    if show_avg_precision and avg_precision is not None:
        ax.axhline(
            y=avg_precision,
            color="red",
            ls="--",
            lw=lw,
            label=f"Avg. precision={avg_precision:.2f}",
        )
    # Customize axes
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_xlim([-0.01, 1.0])
    ax.set_ylim([0.0, 1.0])
    ax.grid(False)
    ax.legend(loc=legend_loc)
    return ax
1552
+
1553
def plot_cm(
    cm,
    labels_name=None,
    thresh=0.8,
    axis_labels=None,
    cmap="Reds",
    normalize=True,
    xlabel="Predicted Label",
    ylabel="Actual Label",
    fontsize=12,
    figsize=(5, 5),
    ax=None,
):
    """
    Plot a confusion matrix as a heatmap with per-cell annotations.

    Colours always show the row-normalised percentages (each actual class
    sums to 100), with the colour scale fixed at 0-100. Cell text shows
    percentages when ``normalize`` is True, otherwise raw counts. A 2-class
    matrix is annotated with TN/FP/FN/TP tags.

    Parameters:
    - cm (array-like): square confusion matrix (rows = actual, cols = predicted).
    - labels_name (list, optional): class names. They set the tick count and
      the default tick labels.
    - thresh (float): fraction above which cell text is drawn in white.
    - axis_labels (list, optional): tick labels. Defaults to ``labels_name``,
      then ["No", "Yes"] for 2 classes, then class indices.
    - cmap, xlabel, ylabel, fontsize: cosmetics.
    - normalize (bool): annotate with percentages (True) or counts (False).
    - figsize: used only when a new figure is created (``ax`` is None).
    - ax: matplotlib Axes to draw on; created if None.

    Returns:
    - ax: the Axes drawn on.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=figsize)

    cm = np.asarray(cm)
    n_classes = cm.shape[0]
    # Row-normalise to percentages; rows without samples stay 0 instead of NaN.
    row_totals = cm.sum(axis=1, keepdims=True).astype(float)
    cm_normalized = np.round(
        np.divide(
            cm.astype(float),
            row_totals,
            out=np.zeros(cm.shape, dtype=float),
            where=row_totals != 0,
        )
        * 100,
        2,
    )
    cm_value = cm_normalized if normalize else cm.astype("int")

    # Plot the heatmap
    cax = ax.imshow(cm_normalized, interpolation="nearest", cmap=cmap)
    plt.colorbar(cax, ax=ax, fraction=0.046, pad=0.04)
    cax.set_clim(0, 100)

    # Tick positions/labels follow labels_name, else the matrix size.
    n_ticks = len(labels_name) if labels_name is not None else n_classes
    tick_positions = np.arange(n_ticks)
    if axis_labels is None:
        if labels_name is not None:
            axis_labels = labels_name
        elif n_classes == 2:
            axis_labels = ["No", "Yes"]
        else:
            axis_labels = [str(i) for i in range(n_classes)]
    ax.set_xticks(tick_positions)
    ax.set_xticklabels(axis_labels)
    ax.set_yticks(tick_positions)
    ax.set_yticklabels(axis_labels)
    ax.set_ylabel(ylabel)
    ax.set_xlabel(xlabel)

    if n_ticks == 2:
        # Binary layout (rows = actual, columns = predicted):
        #              pred 0 | pred 1
        #   actual 0:    TN   |   FP
        #   actual 1:    FN   |   TP
        for row, col, tag in [(0, 0, "TN"), (0, 1, "FP"), (1, 0, "FN"), (1, 1, "TP")]:
            text = (
                f"{tag}:{cm_normalized[row, col]:.2f}%"
                if normalize
                else f"{tag}:{cm_value[row, col]}"
            )
            ax.text(
                col,
                row,
                text,
                ha="center",
                va="center",
                color="white" if cm_normalized[row, col] > thresh * 100 else "black",
                fontsize=fontsize,
            )
    else:
        # Multiclass: annotate every cell (percentages or counts).
        for i in range(n_classes):
            for j in range(n_classes):
                val = cm_normalized[i, j]
                text = f"{val:.2f}%" if normalize else f"{cm_value[i, j]}"
                ax.text(
                    j,
                    i,
                    text,
                    ha="center",
                    va="center",
                    color="white" if val > thresh * 100 else "black",
                    fontsize=fontsize,
                )

    plot.figsets(ax=ax, boxloc="none")
    return ax
1674
+
1675
+
1676
+ def rank_models(
1677
+ cv_test_scores,
1678
+ rm_outlier=False,
1679
+ metric_weights=None,
1680
+ plot_=True,
1681
+ ):
1682
+ """
1683
+ Selects the best model based on a multi-metric scoring approach, with outlier handling, optional visualization,
1684
+ and additional performance metrics.
1685
+
1686
+ Parameters:
1687
+ - cv_test_scores (pd.DataFrame): DataFrame with cross-validation results across multiple metrics.
1688
+ Assumes columns are 'Classifier', 'accuracy', 'precision', 'recall', 'f1', 'roc_auc'.
1689
+ - metric_weights (dict): Dictionary specifying weights for each metric (e.g., {'accuracy': 0.2, 'precision': 0.3, ...}).
1690
+ If None, default weights are applied equally across available metrics.
1691
+ a. equal_weights(standard approch): 所有的metrics同等重要
1692
+ e.g., {"accuracy": 0.2, "precision": 0.2, "recall": 0.2, "f1": 0.2, "roc_auc": 0.2}
1693
+ b. accuracy_focosed: classification correctness (e.g., in balanced datasets), accuracy might be weighted more heavily.
1694
+ e.g., {"accuracy": 0.4, "precision": 0.2, "recall": 0.2, "f1": 0.1, "roc_auc": 0.1}
1695
+ c. Precision and Recall Emphasis: In cases where false positives and false negatives are particularly important (such as
1696
+ in medical applications or fraud detection), precision and recall may be weighted more heavily.
1697
+ e.g., {"accuracy": 0.2, "precision": 0.3, "recall": 0.3, "f1": 0.1, "roc_auc": 0.1}
1698
+ d. F1-Focused: When balance between precision and recall is crucial (e.g., in imbalanced datasets)
1699
+ e.g., {"accuracy": 0.2, "precision": 0.2, "recall": 0.2, "f1": 0.3, "roc_auc": 0.1}
1700
+ e. ROC-AUC Emphasis: In some cases, ROC AUC may be prioritized, particularly in classification tasks where class imbalance
1701
+ is present, as ROC AUC accounts for the model's performance across all classification thresholds.
1702
+ e.g., {"accuracy": 0.1, "precision": 0.2, "recall": 0.2, "f1": 0.3, "roc_auc": 0.3}
1703
+
1704
+ - normalize (bool): Whether to normalize scores of each metric to range [0, 1].
1705
+ - visualize (bool): If True, generates visualizations (e.g., bar plot, radar chart).
1706
+ - outlier_threshold (float): The threshold to detect outliers using the IQR method. Default is 1.5.
1707
+ - cv_folds (int): The number of cross-validation folds used.
1708
+
1709
+ Returns:
1710
+ - best_model (str): Name of the best model based on the combined metric scores.
1711
+ - scored_df (pd.DataFrame): DataFrame with an added 'combined_score' column used for model selection.
1712
+ - visualizations (dict): A dictionary containing visualizations if `visualize=True`.
1713
+ """
1714
+ from sklearn.preprocessing import MinMaxScaler
1715
+ import seaborn as sns
1716
+ import matplotlib.pyplot as plt
1717
+ from py2ls import plot
1718
+
1719
+ # Check for missing metrics and set default weights if not provided
1720
+ available_metrics = cv_test_scores.columns[1:] # Exclude 'Classifier' column
1721
+ if metric_weights is None:
1722
+ metric_weights = {
1723
+ metric: 1 / len(available_metrics) for metric in available_metrics
1724
+ } # Equal weight if not specified
1725
+ elif metric_weights == "a":
1726
+ metric_weights = {
1727
+ "accuracy": 0.2,
1728
+ "precision": 0.2,
1729
+ "recall": 0.2,
1730
+ "f1": 0.2,
1731
+ "roc_auc": 0.2,
1732
+ }
1733
+ elif metric_weights == "b":
1734
+ metric_weights = {
1735
+ "accuracy": 0.4,
1736
+ "precision": 0.2,
1737
+ "recall": 0.2,
1738
+ "f1": 0.1,
1739
+ "roc_auc": 0.1,
1740
+ }
1741
+ elif metric_weights == "c":
1742
+ metric_weights = {
1743
+ "accuracy": 0.2,
1744
+ "precision": 0.3,
1745
+ "recall": 0.3,
1746
+ "f1": 0.1,
1747
+ "roc_auc": 0.1,
1748
+ }
1749
+ elif metric_weights == "d":
1750
+ metric_weights = {
1751
+ "accuracy": 0.2,
1752
+ "precision": 0.2,
1753
+ "recall": 0.2,
1754
+ "f1": 0.3,
1755
+ "roc_auc": 0.1,
1756
+ }
1757
+ elif metric_weights == "e":
1758
+ metric_weights = {
1759
+ "accuracy": 0.1,
1760
+ "precision": 0.2,
1761
+ "recall": 0.2,
1762
+ "f1": 0.3,
1763
+ "roc_auc": 0.3,
1764
+ }
1765
+ else:
1766
+ metric_weights = {
1767
+ metric: 1 / len(available_metrics) for metric in available_metrics
1768
+ }
1769
+
1770
+ # Normalize weights if they don’t sum to 1
1771
+ total_weight = sum(metric_weights.values())
1772
+ metric_weights = {
1773
+ metric: weight / total_weight for metric, weight in metric_weights.items()
1774
+ }
1775
+ if rm_outlier:
1776
+ cv_test_scores_ = ips.df_outlier(cv_test_scores)
1777
+ else:
1778
+ cv_test_scores_ = cv_test_scores
1779
+
1780
+ # Normalize the scores of metrics if normalize is True
1781
+ scaler = MinMaxScaler()
1782
+ normalized_scores = pd.DataFrame(
1783
+ scaler.fit_transform(cv_test_scores_[available_metrics]),
1784
+ columns=available_metrics,
1785
+ )
1786
+ cv_test_scores_ = pd.concat(
1787
+ [cv_test_scores_[["Classifier"]], normalized_scores], axis=1
1788
+ )
1789
+
1790
+ # Calculate weighted scores for each model
1791
+ cv_test_scores_["combined_score"] = sum(
1792
+ cv_test_scores_[metric] * weight for metric, weight in metric_weights.items()
1793
+ )
1794
+ top_models = cv_test_scores_.sort_values(by="combined_score", ascending=False)
1795
+ cv_test_scores = cv_test_scores.loc[top_models.index]
1796
+ top_models.reset_index(drop=True, inplace=True)
1797
+ cv_test_scores.reset_index(drop=True, inplace=True)
1798
+
1799
+ if plot_:
1800
+
1801
+ def generate_bar_plot(ax, cv_test_scores):
1802
+ ax = plot.plotxy(
1803
+ y="Classifier", x="combined_score", data=cv_test_scores, kind="bar"
1804
+ )
1805
+ plt.title("Classifier Performance")
1806
+ plt.tight_layout()
1807
+ return plt
1808
+
1809
+ nexttile = plot.subplot(2, 2, figsize=[10, 7])
1810
+ generate_bar_plot(nexttile(), top_models.dropna())
1811
+ plot.radar(
1812
+ ax=nexttile(projection="polar"),
1813
+ data=cv_test_scores.set_index("Classifier"),
1814
+ ylim=[0.5, 1],
1815
+ color=plot.get_color(10),
1816
+ alpha=0.05,
1817
+ circular=1,
1818
+ )
1819
+ return cv_test_scores
1820
+
1821
+
1822
+ # # Example Usage:
1823
+ # metric_weights = {
1824
+ # "accuracy": 0.2,
1825
+ # "precision": 0.3,
1826
+ # "recall": 0.2,
1827
+ # "f1": 0.2,
1828
+ # "roc_auc": 0.1,
1829
+ # }
1830
+ # cv_test_scores = res["cv_test_scores"].copy()
1831
+ # best_model = rank_models(
1832
+ # cv_test_scores, metric_weights=metric_weights, normalize=True, plot_=True
1833
+ # )
1834
+
1835
+ # figsave("classifier_performance.pdf")
1836
+
1837
+
1838
+ def predict(
1839
+ x_train: pd.DataFrame,
1840
+ y_train: pd.Series,
1841
+ x_true: pd.DataFrame = None,
1842
+ y_true: Optional[pd.Series] = None,
1843
+ common_features: set = None,
1844
+ purpose: str = "classification", # 'classification' or 'regression'
1845
+ cls: Optional[Dict[str, Any]] = None,
1846
+ metrics: Optional[List[str]] = None,
1847
+ random_state: int = 1,
1848
+ smote: bool = False,
1849
+ n_jobs: int = -1,
1850
+ plot_: bool = True,
1851
+ dir_save: str = "./",
1852
+ test_size: float = 0.2, # specific only when x_true is None
1853
+ cv_folds: int = 5, # more cv_folds 得更加稳定,auc可能更低
1854
+ cv_level: str = "l", # "s":'low',"m":'medium',"l":"high"
1855
+ class_weight: str = "balanced",
1856
+ verbose: bool = False,
1857
+ ) -> pd.DataFrame:
1858
+ """
1859
+ 第一种情况是内部拆分,第二种是直接预测,第三种是外部验证。
1860
+ Usage:
1861
+ (1). predict(x_train, y_train,...) 对 x_train 进行拆分训练/测试集,并在测试集上进行验证.
1862
+ predict 函数会根据 test_size 参数,将 x_train 和 y_train 拆分出内部测试集。然后模型会在拆分出的训练集上进行训练,并在测试集上验证效果。
1863
+ (2). predict(x_train, y_train, x_true,...)使用 x_train 和 y_train 训练并对 x_true 进行预测
1864
+ 由于传入了 x_true,函数会跳过 x_train 的拆分,直接使用全部的 x_train 和 y_train 进行训练。然后对 x_true 进行预测,但由于没有提供 y_true,
1865
+ 因此无法与真实值进行对比。
1866
+ (3). predict(x_train, y_train, x_true, y_true,...)使用 x_train 和 y_train 训练,并验证 x_true 与真实标签 y_true.
1867
+ predict 函数会在 x_train 和 y_train 上进行训练,并将 x_true 作为测试集。由于提供了 y_true,函数可以将预测结果与 y_true 进行对比,从而
1868
+ 计算验证指标,完成对 x_true 的真正验证。
1869
+ trains and validates a variety of machine learning models for both classification and regression tasks.
1870
+ It supports hyperparameter tuning with grid search and includes additional features like cross-validation,
1871
+ feature scaling, and handling of class imbalance through SMOTE.
1872
+
1873
+ Parameters:
1874
+ - x_train (pd.DataFrame):Training feature data, structured with each row as an observation and each column as a feature.
1875
+ - y_train (pd.Series):Target variable for the training dataset.
1876
+ - x_true (pd.DataFrame, optional):Test feature data. If not provided, the function splits x_train based on test_size.
1877
+ - y_true (pd.Series, optional):Test target values. If not provided, y_train is split into training and testing sets.
1878
+ - common_features (set, optional):Specifies a subset of features common across training and test data.
1879
+ - purpose (str, default = "classification"):Defines whether the task is "classification" or "regression". Determines which
1880
+ metrics and models are applied.
1881
+ - cls (dict, optional):Dictionary to specify custom classifiers/regressors. Defaults to a set of common models if not provided.
1882
+ - metrics (list, optional):List of evaluation metrics (like accuracy, F1 score) used for model evaluation.
1883
+ - random_state (int, default = 1):Random seed to ensure reproducibility.
1884
+ - smote (bool, default = False):Applies Synthetic Minority Oversampling Technique (SMOTE) to address class imbalance if enabled.
1885
+ - n_jobs (int, default = -1):Number of parallel jobs for computation. Set to -1 to use all available cores.
1886
+ - plot_ (bool, default = True):If True, generates plots of the model evaluation metrics.
1887
+ - test_size (float, default = 0.2):Test data proportion if x_true is not provided.
1888
+ - cv_folds (int, default = 5):Number of cross-validation folds.
1889
+ - cv_level (str, default = "l"):Sets the detail level of cross-validation. "s" for low, "m" for medium, and "l" for high.
1890
+ - class_weight (str, default = "balanced"):Balances class weights in classification tasks.
1891
+ - verbose (bool, default = False):If True, prints detailed output during model training.
1892
+ - dir_save (str, default = "./"):Directory path to save plot outputs and results.
1893
+
1894
+ Key Steps in the Function:
1895
+ Model Initialization: Depending on purpose, initializes either classification or regression models.
1896
+ Feature Selection: Ensures training and test sets have matching feature columns.
1897
+ SMOTE Application: Balances classes if smote is enabled and the task is classification.
1898
+ Cross-Validation and Hyperparameter Tuning: Utilizes GridSearchCV for model tuning based on cv_level.
1899
+ Evaluation and Plotting: Outputs evaluation metrics like AUC, confusion matrices, and optional plotting of performance metrics.
1900
+ """
1901
+ from tqdm import tqdm
1902
+ from sklearn.ensemble import (
1903
+ RandomForestClassifier,
1904
+ RandomForestRegressor,
1905
+ ExtraTreesClassifier,
1906
+ ExtraTreesRegressor,
1907
+ BaggingClassifier,
1908
+ BaggingRegressor,
1909
+ AdaBoostClassifier,
1910
+ AdaBoostRegressor,
1911
+ )
1912
+ from sklearn.svm import SVC, SVR
1913
+ from sklearn.tree import DecisionTreeRegressor
1914
+ from sklearn.linear_model import (
1915
+ LogisticRegression,
1916
+ ElasticNet,
1917
+ ElasticNetCV,
1918
+ LinearRegression,
1919
+ Lasso,
1920
+ RidgeClassifierCV,
1921
+ Perceptron,
1922
+ SGDClassifier,
1923
+ )
1924
+ from sklearn.neighbors import KNeighborsClassifier, KNeighborsRegressor
1925
+ from sklearn.naive_bayes import GaussianNB, BernoulliNB
1926
+ from sklearn.ensemble import GradientBoostingClassifier, GradientBoostingRegressor
1927
+ import xgboost as xgb
1928
+ import lightgbm as lgb
1929
+ import catboost as cb
1930
+ from sklearn.neural_network import MLPClassifier, MLPRegressor
1931
+ from sklearn.model_selection import GridSearchCV, StratifiedKFold, KFold
1932
+ from sklearn.discriminant_analysis import (
1933
+ LinearDiscriminantAnalysis,
1934
+ QuadraticDiscriminantAnalysis,
1935
+ )
1936
+ from sklearn.preprocessing import PolynomialFeatures
1937
+
1938
+ # 拼写检查
1939
+ purpose = ips.strcmp(purpose, ["classification", "regression"])[0]
1940
+ print(f"{purpose} processing...")
1941
+ # Default models or regressors if not provided
1942
+ if purpose == "classification":
1943
+ model_ = {
1944
+ "Random Forest": RandomForestClassifier(
1945
+ random_state=random_state, class_weight=class_weight
1946
+ ),
1947
+ # SVC (Support Vector Classification)
1948
+ "SVM": SVC(
1949
+ kernel="rbf",
1950
+ probability=True,
1951
+ class_weight=class_weight,
1952
+ random_state=random_state,
1953
+ ),
1954
+ # fit the best model without enforcing sparsity, which means it does not directly perform feature selection.
1955
+ "Logistic Regression": LogisticRegression(
1956
+ class_weight=class_weight, random_state=random_state
1957
+ ),
1958
+ # Logistic Regression with L1 Regularization (Lasso)
1959
+ "Lasso Logistic Regression": LogisticRegression(
1960
+ penalty="l1", solver="saga", random_state=random_state
1961
+ ),
1962
+ "Gradient Boosting": GradientBoostingClassifier(random_state=random_state),
1963
+ "XGBoost": xgb.XGBClassifier(
1964
+ eval_metric="logloss",
1965
+ random_state=random_state,
1966
+ ),
1967
+ "KNN": KNeighborsClassifier(n_neighbors=5),
1968
+ "Naive Bayes": GaussianNB(),
1969
+ "Linear Discriminant Analysis": LinearDiscriminantAnalysis(),
1970
+ "AdaBoost": AdaBoostClassifier(
1971
+ algorithm="SAMME", random_state=random_state
1972
+ ),
1973
+ # "LightGBM": lgb.LGBMClassifier(random_state=random_state, class_weight=class_weight),
1974
+ "CatBoost": cb.CatBoostClassifier(verbose=0, random_state=random_state),
1975
+ "Extra Trees": ExtraTreesClassifier(
1976
+ random_state=random_state, class_weight=class_weight
1977
+ ),
1978
+ "Bagging": BaggingClassifier(random_state=random_state),
1979
+ "Neural Network": MLPClassifier(max_iter=500, random_state=random_state),
1980
+ "DecisionTree": DecisionTreeClassifier(),
1981
+ "Quadratic Discriminant Analysis": QuadraticDiscriminantAnalysis(),
1982
+ "Ridge": RidgeClassifierCV(
1983
+ class_weight=class_weight, store_cv_results=True
1984
+ ),
1985
+ "Perceptron": Perceptron(random_state=random_state),
1986
+ "Bernoulli Naive Bayes": BernoulliNB(),
1987
+ "SGDClassifier": SGDClassifier(random_state=random_state),
1988
+ }
1989
+ elif purpose == "regression":
1990
+ model_ = {
1991
+ "Random Forest": RandomForestRegressor(random_state=random_state),
1992
+ "SVM": SVR(), # SVR (Support Vector Regression)
1993
+ # "Lasso": Lasso(random_state=random_state), # 它和LassoCV相同(必须要提供alpha参数),
1994
+ "LassoCV": LassoCV(
1995
+ cv=cv_folds, random_state=random_state
1996
+ ), # LassoCV自动找出最适alpha,优于Lasso
1997
+ "Gradient Boosting": GradientBoostingRegressor(random_state=random_state),
1998
+ "XGBoost": xgb.XGBRegressor(eval_metric="rmse", random_state=random_state),
1999
+ "Linear Regression": LinearRegression(),
2000
+ "Lasso": Lasso(random_state=random_state),
2001
+ "AdaBoost": AdaBoostRegressor(random_state=random_state),
2002
+ # "LightGBM": lgb.LGBMRegressor(random_state=random_state),
2003
+ "CatBoost": cb.CatBoostRegressor(verbose=0, random_state=random_state),
2004
+ "Extra Trees": ExtraTreesRegressor(random_state=random_state),
2005
+ "Bagging": BaggingRegressor(random_state=random_state),
2006
+ "Neural Network": MLPRegressor(max_iter=500, random_state=random_state),
2007
+ "ElasticNet": ElasticNet(random_state=random_state),
2008
+ "Ridge": Ridge(),
2009
+ "KNN": KNeighborsRegressor(),
2010
+ }
2011
+ # indicate cls:
2012
+ if ips.run_once_within(30): # 10 min
2013
+ print(f"supported models: {list(model_.keys())}")
2014
+ if cls is None:
2015
+ models = model_
2016
+ else:
2017
+ if not isinstance(cls, list):
2018
+ cls = [cls]
2019
+ models = {}
2020
+ for cls_ in cls:
2021
+ cls_ = ips.strcmp(cls_, list(model_.keys()))[0]
2022
+ models[cls_] = model_[cls_]
2023
+ if "LightGBM" in models:
2024
+ x_train = ips.df_special_characters_cleaner(x_train)
2025
+ x_true = (
2026
+ ips.df_special_characters_cleaner(x_true) if x_true is not None else None
2027
+ )
2028
+
2029
+ if isinstance(y_train, str) and y_train in x_train.columns:
2030
+ y_train_col_name = y_train
2031
+ y_train = x_train[y_train]
2032
+ # y_train = ips.df_encoder(pd.DataFrame(y_train), method="dummy")
2033
+ x_train = x_train.drop(y_train_col_name, axis=1)
2034
+ # else:
2035
+ # y_train = ips.df_encoder(pd.DataFrame(y_train), method="dummy").values.ravel()
2036
+ y_train=pd.DataFrame(y_train)
2037
+ y_train_=ips.df_encoder(y_train, method="dummy")
2038
+ is_binary = False if y_train_.shape[1] >1 else True
2039
+ print(is_binary)
2040
+ if is_binary:
2041
+ y_train = ips.df_encoder(pd.DataFrame(y_train), method="dummy").values.ravel()
2042
+ if x_true is None:
2043
+ x_train, x_true, y_train, y_true = train_test_split(
2044
+ x_train,
2045
+ y_train,
2046
+ test_size=test_size,
2047
+ random_state=random_state,
2048
+ stratify=y_train if purpose == "classification" else None,
2049
+ )
2050
+ if isinstance(y_train, str) and y_train in x_train.columns:
2051
+ y_train_col_name = y_train
2052
+ y_train = x_train[y_train]
2053
+ y_train = ips.df_encoder(pd.DataFrame(y_train), method="dummy")
2054
+ x_train = x_train.drop(y_train_col_name, axis=1)
2055
+ else:
2056
+ y_train = ips.df_encoder(
2057
+ pd.DataFrame(y_train), method="dummy"
2058
+ ).values.ravel()
2059
+
2060
+ if y_true is not None:
2061
+ if isinstance(y_true, str) and y_true in x_true.columns:
2062
+ y_true_col_name = y_true
2063
+ y_true = x_true[y_true]
2064
+ # y_true = ips.df_encoder(pd.DataFrame(y_true), method="dummy")
2065
+ y_true = pd.DataFrame(y_true)
2066
+ x_true = x_true.drop(y_true_col_name, axis=1)
2067
+ # else:
2068
+ # y_true = ips.df_encoder(pd.DataFrame(y_true), method="dummy").values.ravel()
2069
+
2070
+ # to convert the 2D to 1D: 2D column-vector format (like [[1], [0], [1], ...]) instead of a 1D array ([1, 0, 1, ...]
2071
+
2072
+ # y_train=y_train.values.ravel() if y_train is not None else None
2073
+ # y_true=y_true.values.ravel() if y_true is not None else None
2074
+ y_train = (
2075
+ y_train.ravel() if isinstance(y_train, np.ndarray) else y_train.values.ravel()
2076
+ )
2077
+ print(len(y_train),len(y_true))
2078
+ y_true = y_true.ravel() if isinstance(y_true, np.ndarray) else y_true.values.ravel()
2079
+ print(len(y_train),len(y_true))
2080
+ # Ensure common features are selected
2081
+ if common_features is not None:
2082
+ x_train, x_true = x_train[common_features], x_true[common_features]
2083
+ else:
2084
+ share_col_names = ips.shared(x_train.columns, x_true.columns, verbose=verbose)
2085
+ x_train, x_true = x_train[share_col_names], x_true[share_col_names]
2086
+
2087
+ x_train, x_true = ips.df_scaler(x_train), ips.df_scaler(x_true)
2088
+ x_train, x_true = ips.df_encoder(x_train, method="dummy"), ips.df_encoder(x_true, method="dummy")
2089
+ # Handle class imbalance using SMOTE (only for classification)
2090
+ if (
2091
+ smote
2092
+ and purpose == "classification"
2093
+ and y_train.value_counts(normalize=True).max() < 0.8
2094
+ ):
2095
+ from imblearn.over_sampling import SMOTE
2096
+
2097
+ smote_sampler = SMOTE(random_state=random_state)
2098
+ x_train, y_train = smote_sampler.fit_resample(x_train, y_train)
2099
+
2100
+ # Hyperparameter grids for tuning
2101
+ if cv_level in ["low", "simple", "s", "l"]:
2102
+ param_grids = {
2103
+ "Random Forest": (
2104
+ {
2105
+ "n_estimators": [100], # One basic option
2106
+ "max_depth": [None, 10],
2107
+ "min_samples_split": [2],
2108
+ "min_samples_leaf": [1],
2109
+ "class_weight": [None],
2110
+ }
2111
+ if purpose == "classification"
2112
+ else {
2113
+ "n_estimators": [100], # One basic option
2114
+ "max_depth": [None, 10],
2115
+ "min_samples_split": [2],
2116
+ "min_samples_leaf": [1],
2117
+ "max_features": [None],
2118
+ "bootstrap": [True], # Only one option for simplicity
2119
+ }
2120
+ ),
2121
+ "SVM": {
2122
+ "C": [1],
2123
+ "gamma": ["scale"],
2124
+ "kernel": ["rbf"],
2125
+ },
2126
+ "Lasso": {
2127
+ "alpha": [0.1],
2128
+ },
2129
+ "LassoCV": {
2130
+ "alphas": [[0.1]],
2131
+ },
2132
+ "Logistic Regression": {
2133
+ "C": [1],
2134
+ "solver": ["lbfgs"],
2135
+ "penalty": ["l2"],
2136
+ "max_iter": [500],
2137
+ },
2138
+ "Gradient Boosting": {
2139
+ "n_estimators": [100],
2140
+ "learning_rate": [0.1],
2141
+ "max_depth": [3],
2142
+ "min_samples_split": [2],
2143
+ "subsample": [0.8],
2144
+ },
2145
+ "XGBoost": {
2146
+ "n_estimators": [100],
2147
+ "max_depth": [3],
2148
+ "learning_rate": [0.1],
2149
+ "subsample": [0.8],
2150
+ "colsample_bytree": [0.8],
2151
+ },
2152
+ "KNN": (
2153
+ {
2154
+ "n_neighbors": [3],
2155
+ "weights": ["uniform"],
2156
+ "algorithm": ["auto"],
2157
+ "p": [2],
2158
+ }
2159
+ if purpose == "classification"
2160
+ else {
2161
+ "n_neighbors": [3],
2162
+ "weights": ["uniform"],
2163
+ "metric": ["euclidean"],
2164
+ "leaf_size": [30],
2165
+ "p": [2],
2166
+ }
2167
+ ),
2168
+ "Naive Bayes": {
2169
+ "var_smoothing": [1e-9],
2170
+ },
2171
+ "SVR": {
2172
+ "C": [1],
2173
+ "gamma": ["scale"],
2174
+ "kernel": ["rbf"],
2175
+ },
2176
+ "Linear Regression": {
2177
+ "fit_intercept": [True],
2178
+ },
2179
+ "Extra Trees": {
2180
+ "n_estimators": [100],
2181
+ "max_depth": [None, 10],
2182
+ "min_samples_split": [2],
2183
+ "min_samples_leaf": [1],
2184
+ },
2185
+ "CatBoost": {
2186
+ "iterations": [100],
2187
+ "learning_rate": [0.1],
2188
+ "depth": [3],
2189
+ "l2_leaf_reg": [1],
2190
+ },
2191
+ "LightGBM": {
2192
+ "n_estimators": [100],
2193
+ "num_leaves": [31],
2194
+ "max_depth": [10],
2195
+ "min_data_in_leaf": [20],
2196
+ "min_gain_to_split": [0.01],
2197
+ "scale_pos_weight": [10],
2198
+ },
2199
+ "Bagging": {
2200
+ "n_estimators": [50],
2201
+ "max_samples": [0.7],
2202
+ "max_features": [0.7],
2203
+ },
2204
+ "Neural Network": {
2205
+ "hidden_layer_sizes": [(50,)],
2206
+ "activation": ["relu"],
2207
+ "solver": ["adam"],
2208
+ "alpha": [0.0001],
2209
+ },
2210
+ "Decision Tree": {
2211
+ "max_depth": [None, 10],
2212
+ "min_samples_split": [2],
2213
+ "min_samples_leaf": [1],
2214
+ "criterion": ["gini"],
2215
+ },
2216
+ "AdaBoost": {
2217
+ "n_estimators": [50],
2218
+ "learning_rate": [0.5],
2219
+ },
2220
+ "Linear Discriminant Analysis": {
2221
+ "solver": ["svd"],
2222
+ "shrinkage": [None],
2223
+ },
2224
+ "Quadratic Discriminant Analysis": {
2225
+ "reg_param": [0.0],
2226
+ "priors": [None],
2227
+ "tol": [1e-4],
2228
+ },
2229
+ "Ridge": (
2230
+ {"class_weight": [None, "balanced"]}
2231
+ if purpose == "classification"
2232
+ else {
2233
+ "alpha": [0.1, 1, 10],
2234
+ }
2235
+ ),
2236
+ "Perceptron": {
2237
+ "alpha": [1e-3],
2238
+ "penalty": ["l2"],
2239
+ "max_iter": [1000],
2240
+ "eta0": [1.0],
2241
+ },
2242
+ "Bernoulli Naive Bayes": {
2243
+ "alpha": [0.1, 1, 10],
2244
+ "binarize": [0.0],
2245
+ "fit_prior": [True],
2246
+ },
2247
+ "SGDClassifier": {
2248
+ "eta0": [0.01],
2249
+ "loss": ["hinge"],
2250
+ "penalty": ["l2"],
2251
+ "alpha": [1e-3],
2252
+ "max_iter": [1000],
2253
+ "tol": [1e-3],
2254
+ "random_state": [random_state],
2255
+ "learning_rate": ["constant"],
2256
+ },
2257
+ }
2258
+ elif cv_level in ["high", "advanced", "h"]:
2259
+ param_grids = {
2260
+ "Random Forest": (
2261
+ {
2262
+ "n_estimators": [100, 200, 500, 700, 1000],
2263
+ "max_depth": [None, 3, 5, 10, 15, 20, 30],
2264
+ "min_samples_split": [2, 5, 10, 20],
2265
+ "min_samples_leaf": [1, 2, 4],
2266
+ "class_weight": (
2267
+ [None, "balanced"] if purpose == "classification" else {}
2268
+ ),
2269
+ }
2270
+ if purpose == "classification"
2271
+ else {
2272
+ "n_estimators": [100, 200, 500, 700, 1000],
2273
+ "max_depth": [None, 3, 5, 10, 15, 20, 30],
2274
+ "min_samples_split": [2, 5, 10, 20],
2275
+ "min_samples_leaf": [1, 2, 4],
2276
+ "max_features": [
2277
+ "auto",
2278
+ "sqrt",
2279
+ "log2",
2280
+ ], # Number of features to consider when looking for the best split
2281
+ "bootstrap": [
2282
+ True,
2283
+ False,
2284
+ ], # Whether bootstrap samples are used when building trees
2285
+ }
2286
+ ),
2287
+ "SVM": {
2288
+ "C": [0.001, 0.01, 0.1, 1, 10, 100, 1000],
2289
+ "gamma": ["scale", "auto", 0.001, 0.01, 0.1],
2290
+ "kernel": ["linear", "rbf", "poly"],
2291
+ },
2292
+ "Logistic Regression": {
2293
+ "C": [0.001, 0.01, 0.1, 1, 10, 100, 1000],
2294
+ "solver": ["liblinear", "saga", "newton-cg", "lbfgs"],
2295
+ "penalty": ["l1", "l2", "elasticnet"],
2296
+ "max_iter": [100, 200, 300, 500],
2297
+ },
2298
+ "Lasso": {
2299
+ "alpha": [0.0001, 0.001, 0.01, 0.1, 1.0, 10.0, 100.0],
2300
+ "max_iter": [500, 1000, 2000, 5000],
2301
+ "tol": [1e-4, 1e-5, 1e-6],
2302
+ "selection": ["cyclic", "random"],
2303
+ },
2304
+ "LassoCV": {
2305
+ "alphas": [[0.0001, 0.001, 0.01, 0.1, 1.0, 10.0, 100.0]],
2306
+ "max_iter": [500, 1000, 2000, 5000],
2307
+ "cv": [3, 5, 10],
2308
+ "tol": [1e-4, 1e-5, 1e-6],
2309
+ },
2310
+ "Gradient Boosting": {
2311
+ "n_estimators": [100, 200, 300, 400, 500, 700, 1000],
2312
+ "learning_rate": [0.001, 0.01, 0.1, 0.2, 0.3, 0.5],
2313
+ "max_depth": [3, 5, 7, 9, 15],
2314
+ "min_samples_split": [2, 5, 10, 20],
2315
+ "subsample": [0.8, 1.0],
2316
+ },
2317
+ "XGBoost": {
2318
+ "n_estimators": [100, 200, 500, 700],
2319
+ "max_depth": [3, 5, 7, 10],
2320
+ "learning_rate": [0.01, 0.1, 0.2, 0.3],
2321
+ "subsample": [0.8, 1.0],
2322
+ "colsample_bytree": [0.8, 0.9, 1.0],
2323
+ },
2324
+ "KNN": (
2325
+ {
2326
+ "n_neighbors": [1, 3, 5, 10, 15, 20],
2327
+ "weights": ["uniform", "distance"],
2328
+ "algorithm": ["auto", "ball_tree", "kd_tree", "brute"],
2329
+ "p": [1, 2], # 1 for Manhattan, 2 for Euclidean distance
2330
+ }
2331
+ if purpose == "classification"
2332
+ else {
2333
+ "n_neighbors": [3, 5, 7, 9, 11], # Number of neighbors
2334
+ "weights": [
2335
+ "uniform",
2336
+ "distance",
2337
+ ], # Weight function used in prediction
2338
+ "metric": [
2339
+ "euclidean",
2340
+ "manhattan",
2341
+ "minkowski",
2342
+ ], # Distance metric
2343
+ "leaf_size": [
2344
+ 20,
2345
+ 30,
2346
+ 40,
2347
+ 50,
2348
+ ], # Leaf size for KDTree or BallTree algorithms
2349
+ "p": [
2350
+ 1,
2351
+ 2,
2352
+ ], # Power parameter for the Minkowski metric (1 = Manhattan, 2 = Euclidean)
2353
+ }
2354
+ ),
2355
+ "Naive Bayes": {
2356
+ "var_smoothing": [1e-10, 1e-9, 1e-8, 1e-7],
2357
+ },
2358
+ "AdaBoost": {
2359
+ "n_estimators": [50, 100, 200, 300, 500],
2360
+ "learning_rate": [0.001, 0.01, 0.1, 0.5, 1.0],
2361
+ },
2362
+ "SVR": {
2363
+ "C": [0.01, 0.1, 1, 10, 100, 1000],
2364
+ "gamma": [0.001, 0.01, 0.1, "scale", "auto"],
2365
+ "kernel": ["linear", "rbf", "poly"],
2366
+ },
2367
+ "Linear Regression": {
2368
+ "fit_intercept": [True, False],
2369
+ },
2370
+ "Lasso": {
2371
+ "alpha": [0.001, 0.01, 0.1, 1.0, 10.0, 100.0],
2372
+ "max_iter": [1000, 2000], # Higher iteration limit for fine-tuning
2373
+ },
2374
+ "Extra Trees": {
2375
+ "n_estimators": [100, 200, 500, 700, 1000],
2376
+ "max_depth": [None, 5, 10, 15, 20, 30],
2377
+ "min_samples_split": [2, 5, 10, 20],
2378
+ "min_samples_leaf": [1, 2, 4],
2379
+ },
2380
+ "CatBoost": {
2381
+ "iterations": [100, 200, 500],
2382
+ "learning_rate": [0.001, 0.01, 0.1, 0.2],
2383
+ "depth": [3, 5, 7, 10],
2384
+ "l2_leaf_reg": [1, 3, 5, 7, 10],
2385
+ "border_count": [32, 64, 128],
2386
+ },
2387
+ "LightGBM": {
2388
+ "n_estimators": [100, 200, 500, 700, 1000],
2389
+ "learning_rate": [0.001, 0.01, 0.1, 0.2],
2390
+ "num_leaves": [31, 50, 100, 200],
2391
+ "max_depth": [-1, 5, 10, 20, 30],
2392
+ "min_child_samples": [5, 10, 20],
2393
+ "subsample": [0.8, 1.0],
2394
+ "colsample_bytree": [0.8, 0.9, 1.0],
2395
+ },
2396
+ "Neural Network": {
2397
+ "hidden_layer_sizes": [(50,), (100,), (100, 50), (200, 100)],
2398
+ "activation": ["relu", "tanh", "logistic"],
2399
+ "solver": ["adam", "sgd", "lbfgs"],
2400
+ "alpha": [0.0001, 0.001, 0.01],
2401
+ "learning_rate": ["constant", "adaptive"],
2402
+ },
2403
+ "Decision Tree": {
2404
+ "max_depth": [None, 5, 10, 20, 30],
2405
+ "min_samples_split": [2, 5, 10, 20],
2406
+ "min_samples_leaf": [1, 2, 5, 10],
2407
+ "criterion": ["gini", "entropy"],
2408
+ "splitter": ["best", "random"],
2409
+ },
2410
+ "Linear Discriminant Analysis": {
2411
+ "solver": ["svd", "lsqr", "eigen"],
2412
+ "shrinkage": [
2413
+ None,
2414
+ "auto",
2415
+ 0.1,
2416
+ 0.5,
2417
+ 1.0,
2418
+ ], # shrinkage levels for 'lsqr' and 'eigen'
2419
+ },
2420
+ "Ridge": (
2421
+ {"class_weight": [None, "balanced"]}
2422
+ if purpose == "classification"
2423
+ else {
2424
+ "alpha": [0.1, 1, 10, 100, 1000],
2425
+ "solver": ["auto", "svd", "cholesky", "lsqr", "lbfgs"],
2426
+ "fit_intercept": [
2427
+ True,
2428
+ False,
2429
+ ], # Whether to calculate the intercept
2430
+ "normalize": [
2431
+ True,
2432
+ False,
2433
+ ], # If True, the regressors X will be normalized
2434
+ }
2435
+ ),
2436
+ }
2437
+ else: # median level
2438
+ param_grids = {
2439
+ "Random Forest": (
2440
+ {
2441
+ "n_estimators": [100, 200, 500],
2442
+ "max_depth": [None, 10, 20, 30],
2443
+ "min_samples_split": [2, 5, 10],
2444
+ "min_samples_leaf": [1, 2, 4],
2445
+ "class_weight": [None, "balanced"],
2446
+ }
2447
+ if purpose == "classification"
2448
+ else {
2449
+ "n_estimators": [100, 200, 500],
2450
+ "max_depth": [None, 10, 20, 30],
2451
+ "min_samples_split": [2, 5, 10],
2452
+ "min_samples_leaf": [1, 2, 4],
2453
+ "max_features": [
2454
+ "auto",
2455
+ "sqrt",
2456
+ "log2",
2457
+ ], # Number of features to consider when looking for the best split
2458
+ "bootstrap": [
2459
+ True,
2460
+ False,
2461
+ ], # Whether bootstrap samples are used when building trees
2462
+ }
2463
+ ),
2464
+ "SVM": {
2465
+ "C": [0.1, 1, 10, 100], # Regularization strength
2466
+ "gamma": ["scale", "auto"], # Common gamma values
2467
+ "kernel": ["rbf", "linear", "poly"],
2468
+ },
2469
+ "Logistic Regression": {
2470
+ "C": [0.1, 1, 10, 100], # Regularization strength
2471
+ "solver": ["lbfgs", "liblinear", "saga"], # Common solvers
2472
+ "penalty": ["l2"], # L2 penalty is most common
2473
+ "max_iter": [
2474
+ 500,
2475
+ 1000,
2476
+ 2000,
2477
+ ], # Increased max_iter for better convergence
2478
+ },
2479
+ "Lasso": {
2480
+ "alpha": [0.001, 0.01, 0.1, 1.0, 10.0, 100.0],
2481
+ "max_iter": [500, 1000, 2000],
2482
+ },
2483
+ "LassoCV": {
2484
+ "alphas": [[0.001, 0.01, 0.1, 1.0, 10.0, 100.0]],
2485
+ "max_iter": [500, 1000, 2000],
2486
+ },
2487
+ "Gradient Boosting": {
2488
+ "n_estimators": [100, 200, 500],
2489
+ "learning_rate": [0.01, 0.1, 0.2],
2490
+ "max_depth": [3, 5, 7],
2491
+ "min_samples_split": [2, 5, 10],
2492
+ "subsample": [0.8, 1.0],
2493
+ },
2494
+ "XGBoost": {
2495
+ "n_estimators": [100, 200, 500],
2496
+ "max_depth": [3, 5, 7],
2497
+ "learning_rate": [0.01, 0.1, 0.2],
2498
+ "subsample": [0.8, 1.0],
2499
+ "colsample_bytree": [0.8, 1.0],
2500
+ },
2501
+ "KNN": (
2502
+ {
2503
+ "n_neighbors": [3, 5, 7, 10],
2504
+ "weights": ["uniform", "distance"],
2505
+ "algorithm": ["auto", "ball_tree", "kd_tree", "brute"],
2506
+ "p": [1, 2],
2507
+ }
2508
+ if purpose == "classification"
2509
+ else {
2510
+ "n_neighbors": [3, 5, 7, 9, 11], # Number of neighbors
2511
+ "weights": [
2512
+ "uniform",
2513
+ "distance",
2514
+ ], # Weight function used in prediction
2515
+ "metric": [
2516
+ "euclidean",
2517
+ "manhattan",
2518
+ "minkowski",
2519
+ ], # Distance metric
2520
+ "leaf_size": [
2521
+ 20,
2522
+ 30,
2523
+ 40,
2524
+ 50,
2525
+ ], # Leaf size for KDTree or BallTree algorithms
2526
+ "p": [
2527
+ 1,
2528
+ 2,
2529
+ ], # Power parameter for the Minkowski metric (1 = Manhattan, 2 = Euclidean)
2530
+ }
2531
+ ),
2532
+ "Naive Bayes": {
2533
+ "var_smoothing": [1e-9, 1e-8, 1e-7],
2534
+ },
2535
+ "SVR": {
2536
+ "C": [0.1, 1, 10, 100],
2537
+ "gamma": ["scale", "auto"],
2538
+ "kernel": ["rbf", "linear"],
2539
+ },
2540
+ "Linear Regression": {
2541
+ "fit_intercept": [True, False],
2542
+ },
2543
+ "Lasso": {
2544
+ "alpha": [0.1, 1.0, 10.0],
2545
+ "max_iter": [1000, 2000], # Sufficient iterations for convergence
2546
+ },
2547
+ "Extra Trees": {
2548
+ "n_estimators": [100, 200, 500],
2549
+ "max_depth": [None, 10, 20, 30],
2550
+ "min_samples_split": [2, 5, 10],
2551
+ "min_samples_leaf": [1, 2, 4],
2552
+ },
2553
+ "CatBoost": {
2554
+ "iterations": [100, 200],
2555
+ "learning_rate": [0.01, 0.1],
2556
+ "depth": [3, 6, 10],
2557
+ "l2_leaf_reg": [1, 3, 5, 7],
2558
+ },
2559
+ "LightGBM": {
2560
+ "n_estimators": [100, 200, 500],
2561
+ "learning_rate": [0.01, 0.1],
2562
+ "num_leaves": [31, 50, 100],
2563
+ "max_depth": [-1, 10, 20],
2564
+ "min_data_in_leaf": [20], # Minimum samples in each leaf
2565
+ "min_gain_to_split": [0.01], # Minimum gain to allow a split
2566
+ "scale_pos_weight": [10], # Address class imbalance
2567
+ },
2568
+ "Bagging": {
2569
+ "n_estimators": [10, 50, 100],
2570
+ "max_samples": [0.5, 0.7, 1.0],
2571
+ "max_features": [0.5, 0.7, 1.0],
2572
+ },
2573
+ "Neural Network": {
2574
+ "hidden_layer_sizes": [(50,), (100,), (100, 50)],
2575
+ "activation": ["relu", "tanh"],
2576
+ "solver": ["adam", "sgd"],
2577
+ "alpha": [0.0001, 0.001],
2578
+ },
2579
+ "Decision Tree": {
2580
+ "max_depth": [None, 10, 20],
2581
+ "min_samples_split": [2, 10],
2582
+ "min_samples_leaf": [1, 4],
2583
+ "criterion": ["gini", "entropy"],
2584
+ },
2585
+ "AdaBoost": {
2586
+ "n_estimators": [50, 100],
2587
+ "learning_rate": [0.5, 1.0],
2588
+ },
2589
+ "Linear Discriminant Analysis": {
2590
+ "solver": ["svd", "lsqr", "eigen"],
2591
+ "shrinkage": [None, "auto"],
2592
+ },
2593
+ "Quadratic Discriminant Analysis": {
2594
+ "reg_param": [0.0, 0.1, 0.5, 1.0], # Regularization parameter
2595
+ "priors": [None, [0.5, 0.5], [0.3, 0.7]], # Class priors
2596
+ "tol": [
2597
+ 1e-4,
2598
+ 1e-3,
2599
+ 1e-2,
2600
+ ], # Tolerance value for the convergence of the algorithm
2601
+ },
2602
+ "Perceptron": {
2603
+ "alpha": [1e-4, 1e-3, 1e-2], # Regularization parameter
2604
+ "penalty": ["l2", "l1", "elasticnet"], # Regularization penalty
2605
+ "max_iter": [1000, 2000], # Maximum number of iterations
2606
+ "eta0": [1.0, 0.1], # Learning rate for gradient descent
2607
+ "tol": [1e-3, 1e-4, 1e-5], # Tolerance for stopping criteria
2608
+ "random_state": [random_state], # Random state for reproducibility
2609
+ },
2610
+ "Bernoulli Naive Bayes": {
2611
+ "alpha": [0.1, 1.0, 10.0], # Additive (Laplace) smoothing parameter
2612
+ "binarize": [
2613
+ 0.0,
2614
+ 0.5,
2615
+ 1.0,
2616
+ ], # Threshold for binarizing the input features
2617
+ "fit_prior": [
2618
+ True,
2619
+ False,
2620
+ ], # Whether to learn class prior probabilities
2621
+ },
2622
+ "SGDClassifier": {
2623
+ "eta0": [0.01, 0.1, 1.0],
2624
+ "loss": [
2625
+ "hinge",
2626
+ "log",
2627
+ "modified_huber",
2628
+ "squared_hinge",
2629
+ "perceptron",
2630
+ ], # Loss function
2631
+ "penalty": ["l2", "l1", "elasticnet"], # Regularization penalty
2632
+ "alpha": [1e-4, 1e-3, 1e-2], # Regularization strength
2633
+ "l1_ratio": [0.15, 0.5, 0.85], # L1 ratio for elasticnet penalty
2634
+ "max_iter": [1000, 2000], # Maximum number of iterations
2635
+ "tol": [1e-3, 1e-4], # Tolerance for stopping criteria
2636
+ "random_state": [random_state], # Random state for reproducibility
2637
+ "learning_rate": [
2638
+ "constant",
2639
+ "optimal",
2640
+ "invscaling",
2641
+ "adaptive",
2642
+ ], # Learning rate schedule
2643
+ },
2644
+ "Ridge": (
2645
+ {"class_weight": [None, "balanced"]}
2646
+ if purpose == "classification"
2647
+ else {
2648
+ "alpha": [0.1, 1, 10, 100],
2649
+ "solver": [
2650
+ "auto",
2651
+ "svd",
2652
+ "cholesky",
2653
+ "lsqr",
2654
+ ], # Solver for optimization
2655
+ }
2656
+ ),
2657
+ }
2658
+
2659
+ results = {}
2660
+ # Use StratifiedKFold for classification and KFold for regression
2661
+ cv = (
2662
+ StratifiedKFold(n_splits=cv_folds, shuffle=True, random_state=random_state)
2663
+ if purpose == "classification"
2664
+ else KFold(n_splits=cv_folds, shuffle=True, random_state=random_state)
2665
+ )
2666
+
2667
+ # Train and validate each model
2668
+ for name, clf in tqdm(
2669
+ models.items(),
2670
+ desc="models",
2671
+ colour="green",
2672
+ bar_format="{l_bar}{bar} {n_fmt}/{total_fmt}",
2673
+ ):
2674
+ if verbose:
2675
+ print(f"\nTraining and validating {name}:")
2676
+
2677
+ # Grid search with KFold or StratifiedKFold
2678
+ gs = GridSearchCV(
2679
+ clf,
2680
+ param_grid=param_grids.get(name, {}),
2681
+ scoring=(
2682
+ "roc_auc" if purpose == "classification" else "neg_mean_squared_error"
2683
+ ),
2684
+ cv=cv,
2685
+ n_jobs=n_jobs,
2686
+ verbose=verbose,
2687
+ )
2688
+ gs.fit(x_train, y_train)
2689
+ best_clf = gs.best_estimator_
2690
+ # make sure x_train and x_test has the same name
2691
+ x_true = x_true.reindex(columns=x_train.columns, fill_value=0)
2692
+ y_pred = best_clf.predict(x_true)
2693
+
2694
+ # y_pred_proba
2695
+ if hasattr(best_clf, "predict_proba"):
2696
+ y_pred_proba = best_clf.predict_proba(x_true)[:, 1]
2697
+ elif hasattr(best_clf, "decision_function"):
2698
+ # If predict_proba is not available, use decision_function (e.g., for SVM)
2699
+ y_pred_proba = best_clf.decision_function(x_true)
2700
+ # Ensure y_pred_proba is within 0 and 1 bounds
2701
+ y_pred_proba = (y_pred_proba - y_pred_proba.min()) / (
2702
+ y_pred_proba.max() - y_pred_proba.min()
2703
+ )
2704
+ else:
2705
+ y_pred_proba = None # No probability output for certain models
2706
+
2707
+ validation_scores = {}
2708
+ if y_true is not None:
2709
+ validation_scores = cal_metrics(
2710
+ y_true,
2711
+ y_pred,
2712
+ y_pred_proba=y_pred_proba,
2713
+ is_binary=is_binary,
2714
+ purpose=purpose,
2715
+ average="weighted",
2716
+ )
2717
+
2718
+ # Calculate ROC curve
2719
+ # https://scikit-learn.org/stable/auto_examples/model_selection/plot_roc.html
2720
+ if y_pred_proba is not None:
2721
+ # fpr, tpr, roc_auc = dict(), dict(), dict()
2722
+ fpr, tpr, _ = roc_curve(y_true, y_pred_proba)
2723
+ lower_ci, upper_ci = cal_auc_ci(y_true, y_pred_proba, verbose=False)
2724
+ roc_auc = auc(fpr, tpr)
2725
+ roc_info = {
2726
+ "fpr": fpr.tolist(),
2727
+ "tpr": tpr.tolist(),
2728
+ "auc": roc_auc,
2729
+ "ci95": (lower_ci, upper_ci),
2730
+ }
2731
+ # precision-recall curve
2732
+ precision_, recall_, _ = precision_recall_curve(y_true, y_pred_proba)
2733
+ avg_precision_ = average_precision_score(y_true, y_pred_proba)
2734
+ pr_info = {
2735
+ "precision": precision_,
2736
+ "recall": recall_,
2737
+ "avg_precision": avg_precision_,
2738
+ }
2739
+ else:
2740
+ roc_info, pr_info = None, None
2741
+ if purpose == "classification":
2742
+ results[name] = {
2743
+ "best_clf": gs.best_estimator_,
2744
+ "best_params": gs.best_params_,
2745
+ "auc_indiv": [
2746
+ gs.cv_results_[f"split{i}_test_score"][gs.best_index_]
2747
+ for i in range(cv_folds)
2748
+ ],
2749
+ "scores": validation_scores,
2750
+ "roc_curve": roc_info,
2751
+ "pr_curve": pr_info,
2752
+ "confusion_matrix": confusion_matrix(y_true, y_pred),
2753
+ "predictions": y_pred.tolist(),
2754
+ "predictions_proba": (
2755
+ y_pred_proba.tolist() if y_pred_proba is not None else None
2756
+ ),
2757
+ }
2758
+ else: # "regression"
2759
+ results[name] = {
2760
+ "best_clf": gs.best_estimator_,
2761
+ "best_params": gs.best_params_,
2762
+ "scores": validation_scores, # e.g., neg_MSE, R², etc.
2763
+ "predictions": y_pred.tolist(),
2764
+ "predictions_proba": (
2765
+ y_pred_proba.tolist() if y_pred_proba is not None else None
2766
+ ),
2767
+ }
2768
+
2769
+ else:
2770
+ results[name] = {
2771
+ "best_clf": gs.best_estimator_,
2772
+ "best_params": gs.best_params_,
2773
+ "scores": validation_scores,
2774
+ "predictions": y_pred.tolist(),
2775
+ "predictions_proba": (
2776
+ y_pred_proba.tolist() if y_pred_proba is not None else None
2777
+ ),
2778
+ }
2779
+
2780
+ # Convert results to DataFrame
2781
+ df_results = pd.DataFrame.from_dict(results, orient="index")
2782
+
2783
+ # sort
2784
+ if y_true is not None and purpose == "classification":
2785
+ df_scores = pd.DataFrame(
2786
+ df_results["scores"].tolist(), index=df_results["scores"].index
2787
+ ).sort_values(by="roc_auc", ascending=False)
2788
+ df_results = df_results.loc[df_scores.index]
2789
+
2790
+ if plot_:
2791
+ from datetime import datetime
2792
+
2793
+ now_ = datetime.now().strftime("%y%m%d_%H%M%S")
2794
+ nexttile = plot.subplot(figsize=[12, 10])
2795
+ plot.heatmap(df_scores, kind="direct", ax=nexttile())
2796
+ plot.figsets(xangle=30)
2797
+ if dir_save:
2798
+ ips.figsave(dir_save + f"scores_sorted_heatmap{now_}.pdf")
2799
+ if df_scores.shape[0] > 1: # draw cluster
2800
+ plot.heatmap(df_scores, kind="direct", cluster=True)
2801
+ plot.figsets(xangle=30)
2802
+ if dir_save:
2803
+ ips.figsave(dir_save + f"scores_clus{now_}.pdf")
2804
+ if all([plot_, y_true is not None, purpose == "classification"]):
2805
+ try:
2806
+ if len(models) > 3:
2807
+ plot_validate_features(df_results)
2808
+ else:
2809
+ plot_validate_features_single(df_results, figsize=(12, 4 * len(models)))
2810
+ if dir_save:
2811
+ ips.figsave(dir_save + f"validate_features{now_}.pdf")
2812
+ except Exception as e:
2813
+ print(f"Error: 在画图的过程中出现了问题:{e}")
2814
+ return df_results
2815
+
2816
+
2817
def cal_metrics(
    y_true,
    y_pred,
    y_pred_proba=None,
    is_binary=True,
    purpose="regression",
    average="weighted",
):
    """
    Calculate regression or classification metrics based on the purpose.

    Parameters:
    - y_true: Array of true values (labels for classification, targets for regression).
    - y_pred: Array of predicted labels for classification or predicted values for regression.
    - y_pred_proba: Predicted scores for classification (optional). A 1-D array is
      scored as the positive-class score of a binary problem; a 2-D array with two
      columns is reduced to its second column; a wider 2-D array is scored
      one-vs-rest.
    - is_binary: bool. If True and the confusion matrix is 2x2, specificity is
      the plain TN / (TN + FP); otherwise a one-vs-rest specificity is computed
      per class and combined according to `average`.
    - purpose: str, "regression" or "classification".
    - average: str, averaging method for multi-class classification ("binary", "micro", "macro", "weighted", etc.).

    Returns:
    - validation_scores: dict of computed metrics.
      Regression keys: mse, rmse, mae, r2, mape, explained_variance, mbd, msle
      (msle is a message string when negative values are present).
      Classification keys: accuracy, precision, recall, f1, mcc, specificity,
      balanced_accuracy, plus roc_auc and pr_auc when y_pred_proba is given
      (NaN, with a warning, if scikit-learn cannot compute them).

    Raises:
    - ValueError: if `purpose` is neither "regression" nor "classification".
    """
    from sklearn.metrics import (
        mean_squared_error,
        mean_absolute_error,
        mean_absolute_percentage_error,
        explained_variance_score,
        r2_score,
        mean_squared_log_error,
        accuracy_score,
        precision_score,
        recall_score,
        f1_score,
        matthews_corrcoef,
        confusion_matrix,
        balanced_accuracy_score,
    )

    if purpose == "regression":
        y_true = np.asarray(y_true).ravel()
        y_pred = np.asarray(y_pred).ravel()
        mse = mean_squared_error(y_true, y_pred)
        validation_scores = {
            "mse": mse,
            "rmse": np.sqrt(mse),
            "mae": mean_absolute_error(y_true, y_pred),
            "r2": r2_score(y_true, y_pred),
            "mape": mean_absolute_percentage_error(y_true, y_pred),
            "explained_variance": explained_variance_score(y_true, y_pred),
            "mbd": np.mean(y_pred - y_true),  # Mean Bias Deviation
        }
        # MSLE takes logs, so it is only computed when nothing is negative.
        if np.all(y_true >= 0) and np.all(y_pred >= 0):
            validation_scores["msle"] = mean_squared_log_error(y_true, y_pred)
        else:
            validation_scores["msle"] = "Cannot be calculated due to negative values"
        return validation_scores

    if purpose == "classification":
        validation_scores = {
            "accuracy": accuracy_score(y_true, y_pred),
            "precision": precision_score(y_true, y_pred, average=average),
            "recall": recall_score(y_true, y_pred, average=average),
            "f1": f1_score(y_true, y_pred, average=average),
            "mcc": matthews_corrcoef(y_true, y_pred),
            "specificity": None,
            "balanced_accuracy": balanced_accuracy_score(y_true, y_pred),
        }
        # Specificity is derived from the confusion matrix; the helper also
        # covers multiclass input and matrices that are not 2x2 (e.g. only one
        # class present), which previously crashed.
        validation_scores["specificity"] = _specificity_from_cm(
            confusion_matrix(y_true, y_pred), binary=is_binary, average=average
        )
        if y_pred_proba is not None:
            # Adds "roc_auc" (ROC-AUC) and "pr_auc" (average precision).
            validation_scores.update(
                _probability_scores(y_true, y_pred_proba, average=average)
            )
        return validation_scores

    raise ValueError(
        "Invalid purpose specified. Choose 'regression' or 'classification'."
    )


def _specificity_from_cm(cm, binary=True, average="weighted"):
    """Return the specificity (true-negative rate) encoded in a confusion matrix.

    With binary=True and a 2x2 matrix this is TN / (TN + FP), treating the first
    row/column as the negative class. Otherwise a one-vs-rest specificity is
    computed for every class and combined according to `average`:
    "micro" pools the TN/FP counts, "weighted" weights classes by their support
    (row sums), and any other value takes the unweighted mean.
    Zero denominators yield 0.
    """
    cm = np.asarray(cm)
    if binary and cm.shape == (2, 2):
        tn, fp = cm[0, 0], cm[0, 1]
        return tn / (tn + fp) if (tn + fp) > 0 else 0

    cm = cm.astype(float)
    tp = np.diag(cm)
    fp = cm.sum(axis=0) - tp  # predicted as class k, truly another class
    fn = cm.sum(axis=1) - tp  # truly class k, predicted as another class
    tn = cm.sum() - tp - fp - fn
    negatives = tn + fp

    if average == "micro":
        pooled = negatives.sum()
        return float(tn.sum() / pooled) if pooled > 0 else 0

    per_class = np.divide(
        tn, negatives, out=np.zeros_like(tn), where=negatives > 0
    )
    if average == "weighted":
        support = cm.sum(axis=1)
        if support.sum() > 0:
            return float(np.average(per_class, weights=support))
        return 0
    return float(per_class.mean()) if per_class.size else 0


def _probability_scores(y_true, y_pred_proba, average="weighted"):
    """Return {"roc_auc": ..., "pr_auc": ...} for predicted scores.

    1-D scores (or the second column of a 2-column array) are scored as a binary
    problem. Wider 2-D arrays are scored one-vs-rest. For PR-AUC, y_true is
    one-hot encoded over its sorted unique labels; this assumes the score
    columns follow that same order (TODO confirm for each caller).
    A metric that scikit-learn rejects (raises ValueError, e.g. only one class in
    y_true or a label/score shape mismatch) is reported as NaN with a warning
    instead of aborting the caller.
    """
    import warnings

    from sklearn.metrics import average_precision_score, roc_auc_score
    from sklearn.preprocessing import label_binarize

    scores = np.asarray(y_pred_proba)
    if scores.ndim == 2 and scores.shape[1] == 2:
        # Two-column predict_proba output: keep the positive-class column.
        scores = scores[:, 1]
    # The multiclass/multilabel variants accept only these averaging modes.
    multi_avg = "weighted" if average == "weighted" else "macro"

    result = {}
    try:
        if scores.ndim == 1:
            result["roc_auc"] = roc_auc_score(y_true, scores)
        else:
            result["roc_auc"] = roc_auc_score(
                y_true, scores, multi_class="ovr", average=multi_avg
            )
    except ValueError as err:
        warnings.warn(f"roc_auc could not be computed: {err}")
        result["roc_auc"] = np.nan

    try:
        if scores.ndim == 1:
            result["pr_auc"] = average_precision_score(y_true, scores)
        else:
            y_onehot = label_binarize(y_true, classes=np.unique(y_true))
            result["pr_auc"] = average_precision_score(
                y_onehot, scores, average=multi_avg
            )
    except ValueError as err:
        warnings.warn(f"pr_auc could not be computed: {err}")
        result["pr_auc"] = np.nan

    return result